1use super::DockerConfigAuth;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::PathBuf;
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum AuthSource {
15 #[default]
17 Anonymous,
18
19 Basic { username: String, password: String },
21
22 DockerConfig,
24
25 EnvVar {
27 username_var: String,
28 password_var: String,
29 },
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34pub struct RegistryAuthConfig {
35 pub registry: String,
37
38 pub source: AuthSource,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
44pub struct AuthConfig {
45 #[serde(default)]
47 pub registries: Vec<RegistryAuthConfig>,
48
49 #[serde(default)]
51 pub default: AuthSource,
52
53 pub docker_config_path: Option<PathBuf>,
55}
56
57impl Default for AuthConfig {
58 fn default() -> Self {
59 Self {
60 registries: Vec::new(),
61 default: AuthSource::DockerConfig,
62 docker_config_path: None,
63 }
64 }
65}
66
/// Maps image references to `oci_client` registry credentials according to an
/// [`AuthConfig`].
pub struct AuthResolver {
    /// Configuration this resolver was built from. Its `default` source is used for
    /// registries without an explicit entry.
    config: AuthConfig,
    /// Docker config loaded once at construction, and only when some source is
    /// `AuthSource::DockerConfig`. `None` if it was not needed or failed to load.
    docker_config: Option<DockerConfigAuth>,
    /// Registry host -> auth source, built from `config.registries` (later duplicates
    /// overwrite earlier ones).
    registry_map: HashMap<String, AuthSource>,
}
73
74impl AuthResolver {
75 #[must_use]
77 pub fn new(config: AuthConfig) -> Self {
78 let registry_map: HashMap<String, AuthSource> = config
80 .registries
81 .iter()
82 .map(|r| (r.registry.clone(), r.source.clone()))
83 .collect();
84
85 let needs_docker_config = config.default == AuthSource::DockerConfig
87 || registry_map
88 .values()
89 .any(|s| matches!(s, AuthSource::DockerConfig));
90
91 let docker_config = if needs_docker_config {
92 Self::load_docker_config(config.docker_config_path.as_ref())
93 } else {
94 None
95 };
96
97 Self {
98 config,
99 docker_config,
100 registry_map,
101 }
102 }
103
104 #[must_use]
109 pub fn resolve(&self, image: &str) -> oci_client::secrets::RegistryAuth {
110 let registry = Self::extract_registry(image);
111 let source = self
112 .registry_map
113 .get(®istry)
114 .unwrap_or(&self.config.default);
115
116 self.resolve_source(source, ®istry)
117 }
118
119 fn resolve_source(
121 &self,
122 source: &AuthSource,
123 registry: &str,
124 ) -> oci_client::secrets::RegistryAuth {
125 match source {
126 AuthSource::Anonymous => oci_client::secrets::RegistryAuth::Anonymous,
127
128 AuthSource::Basic { username, password } => {
129 oci_client::secrets::RegistryAuth::Basic(username.clone(), password.clone())
130 }
131
132 AuthSource::DockerConfig => {
133 if let Some(ref docker_config) = self.docker_config {
134 if let Some((username, password)) = docker_config.get_credentials(registry) {
135 return oci_client::secrets::RegistryAuth::Basic(username, password);
136 }
137 }
138 oci_client::secrets::RegistryAuth::Anonymous
140 }
141
142 AuthSource::EnvVar {
143 username_var,
144 password_var,
145 } => {
146 let username = std::env::var(username_var).unwrap_or_default();
147 let password = std::env::var(password_var).unwrap_or_default();
148
149 if !username.is_empty() && !password.is_empty() {
150 oci_client::secrets::RegistryAuth::Basic(username, password)
151 } else {
152 oci_client::secrets::RegistryAuth::Anonymous
153 }
154 }
155 }
156 }
157
158 fn extract_registry(image: &str) -> String {
165 let image_without_digest = image.split('@').next().unwrap_or(image);
167
168 let parts: Vec<&str> = image_without_digest.split('/').collect();
170
171 if parts.len() == 1 {
173 return "docker.io".to_string();
174 }
175
176 let first_part = parts[0];
178 if first_part.contains('.') || first_part.contains(':') || first_part == "localhost" {
179 first_part.to_string()
180 } else {
181 "docker.io".to_string()
183 }
184 }
185
186 fn load_docker_config(path: Option<&PathBuf>) -> Option<DockerConfigAuth> {
188 let config = if let Some(path) = path {
189 DockerConfigAuth::load_from_path(path).ok()
190 } else {
191 DockerConfigAuth::load().ok()
192 };
193
194 if config.is_none() {
195 tracing::debug!("Failed to load Docker config, using anonymous auth as fallback");
196 }
197
198 config
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use oci_client::secrets::RegistryAuth;

    /// Removes the listed environment variables when dropped, so the process environment
    /// is cleaned up even if an assertion in the owning test panics.
    struct EnvVarGuard(&'static [&'static str]);

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            for &name in self.0 {
                std::env::remove_var(name);
            }
        }
    }

    /// Asserts that `auth` is `Basic` with the expected username and password.
    /// `context` is included in the panic message when it is not `Basic`.
    fn assert_basic(auth: RegistryAuth, expected_user: &str, expected_pass: &str, context: &str) {
        match auth {
            RegistryAuth::Basic(username, password) => {
                assert_eq!(username, expected_user);
                assert_eq!(password, expected_pass);
            }
            _ => panic!("Expected Basic auth ({context})"),
        }
    }

    #[test]
    fn test_extract_registry() {
        let cases = [
            ("ubuntu", "docker.io"),
            ("ubuntu:latest", "docker.io"),
            ("library/ubuntu", "docker.io"),
            ("docker.io/library/ubuntu", "docker.io"),
            ("ghcr.io/owner/repo", "ghcr.io"),
            ("ghcr.io/owner/repo:tag", "ghcr.io"),
            ("localhost/image", "localhost"),
            ("localhost:5000/image", "localhost:5000"),
            ("myregistry.com/path/to/image:v1.0", "myregistry.com"),
        ];

        for (image, expected) in cases {
            assert_eq!(
                AuthResolver::extract_registry(image),
                expected,
                "image: {image}"
            );
        }
    }

    #[test]
    fn test_extract_registry_ignores_digest() {
        assert_eq!(
            AuthResolver::extract_registry("ubuntu@sha256:abc123"),
            "docker.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("ghcr.io/owner/repo@sha256:abc123"),
            "ghcr.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("localhost:5000/image:tag@sha256:abc123"),
            "localhost:5000"
        );
    }

    #[test]
    fn test_anonymous_auth() {
        let config = AuthConfig {
            default: AuthSource::Anonymous,
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);
        let auth = resolver.resolve("ubuntu:latest");

        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_basic_auth() {
        let config = AuthConfig {
            default: AuthSource::Basic {
                username: "user".to_string(),
                password: "pass".to_string(),
            },
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);
        assert_basic(resolver.resolve("ubuntu:latest"), "user", "pass", "default source");
    }

    #[test]
    fn test_per_registry_auth() {
        let config = AuthConfig {
            registries: vec![RegistryAuthConfig {
                registry: "ghcr.io".to_string(),
                source: AuthSource::Basic {
                    username: "ghcr_user".to_string(),
                    password: "ghcr_pass".to_string(),
                },
            }],
            default: AuthSource::Anonymous,
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);

        // The per-registry entry applies to its own host...
        assert_basic(
            resolver.resolve("ghcr.io/owner/repo:tag"),
            "ghcr_user",
            "ghcr_pass",
            "ghcr.io",
        );

        // ...and every other registry falls back to the default source.
        let auth = resolver.resolve("ubuntu:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_duplicate_registry_entries_last_wins() {
        let entry = |user: &str| RegistryAuthConfig {
            registry: "ghcr.io".to_string(),
            source: AuthSource::Basic {
                username: user.to_string(),
                password: "pass".to_string(),
            },
        };
        let config = AuthConfig {
            registries: vec![entry("first"), entry("second")],
            default: AuthSource::Anonymous,
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);
        assert_basic(
            resolver.resolve("ghcr.io/owner/repo"),
            "second",
            "pass",
            "duplicate ghcr.io entries",
        );
    }

    #[test]
    fn test_env_var_auth() {
        const USER_VAR: &str = "AUTH_RESOLVER_TEST_USERNAME";
        const PASS_VAR: &str = "AUTH_RESOLVER_TEST_PASSWORD";
        let _guard = EnvVarGuard(&[USER_VAR, PASS_VAR]);
        std::env::set_var(USER_VAR, "env_user");
        std::env::set_var(PASS_VAR, "env_pass");

        let config = AuthConfig {
            default: AuthSource::EnvVar {
                username_var: USER_VAR.to_string(),
                password_var: PASS_VAR.to_string(),
            },
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);
        assert_basic(
            resolver.resolve("ubuntu:latest"),
            "env_user",
            "env_pass",
            "env vars",
        );
    }

    #[test]
    fn test_env_var_auth_requires_both_values() {
        const USER_VAR: &str = "AUTH_RESOLVER_TEST_PARTIAL_USERNAME";
        const PASS_VAR: &str = "AUTH_RESOLVER_TEST_PARTIAL_PASSWORD";
        let _guard = EnvVarGuard(&[USER_VAR, PASS_VAR]);
        std::env::set_var(USER_VAR, "env_user");
        std::env::set_var(PASS_VAR, "");

        let config = AuthConfig {
            default: AuthSource::EnvVar {
                username_var: USER_VAR.to_string(),
                password_var: PASS_VAR.to_string(),
            },
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);
        let auth = resolver.resolve("ubuntu:latest");

        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_env_var_auth_fallback() {
        // Names chosen to be unset in any realistic environment.
        let config = AuthConfig {
            default: AuthSource::EnvVar {
                username_var: "AUTH_RESOLVER_TEST_UNSET_USERNAME".to_string(),
                password_var: "AUTH_RESOLVER_TEST_UNSET_PASSWORD".to_string(),
            },
            ..Default::default()
        };

        let resolver = AuthResolver::new(config);
        let auth = resolver.resolve("ubuntu:latest");

        assert!(matches!(auth, RegistryAuth::Anonymous));
    }
}