1use std::collections::{HashMap, HashSet};
16
17use crate::context::ExecutionContext;
18use crate::environment::EnvironmentConfig;
19use crate::mounts::Mount;
20use crate::resources::ResourceConfig;
21use crate::runtime::RuntimeOverrides;
22use crate::secrets::SecretsConfig;
23use crate::ContextError;
24
/// Resolves [`ExecutionContext`] inheritance chains, memoizing results by context id.
///
/// `F` is the loader used to fetch a parent context by id when it is not
/// already cached.
pub struct ContextResolver<F> {
    /// Fetches a context by id; invoked for parents that are not yet cached.
    loader: F,
    /// Fully resolved contexts, keyed by context id.
    cache: HashMap<String, ExecutionContext>,
    /// Ids whose resolution is currently in progress; used to detect cycles.
    resolving: HashSet<String>,
}
37
38impl<F> ContextResolver<F>
39where
40 F: Fn(&str) -> Result<ExecutionContext, ContextError>,
41{
42 pub fn new(loader: F) -> Self {
44 Self {
45 loader,
46 cache: HashMap::new(),
47 resolving: HashSet::new(),
48 }
49 }
50
51 pub fn resolve(&mut self, context: &ExecutionContext) -> Result<ExecutionContext, ContextError> {
59 if let Some(cached) = self.cache.get(&context.id) {
61 return Ok(cached.clone());
62 }
63
64 if self.resolving.contains(&context.id) {
66 return Err(ContextError::CircularInheritance(format!(
67 "Circular inheritance detected involving context '{}'",
68 context.id
69 )));
70 }
71
72 self.resolving.insert(context.id.clone());
74
75 let resolved = if let Some(ref parent_id) = context.inherits_from {
76 let resolved_parent = if let Some(cached_parent) = self.cache.get(parent_id) {
78 cached_parent.clone()
79 } else {
80 let parent = (self.loader)(parent_id).map_err(|_| {
82 ContextError::ParentNotFound(format!(
83 "Parent context '{}' not found for context '{}'",
84 parent_id, context.id
85 ))
86 })?;
87
88 self.resolve(&parent)?
89 };
90
91 self.merge_contexts(&resolved_parent, context)
93 } else {
94 context.clone()
96 };
97
98 self.resolving.remove(&context.id);
100
101 self.cache.insert(context.id.clone(), resolved.clone());
103
104 Ok(resolved)
105 }
106
107 pub fn clear_cache(&mut self) {
109 self.cache.clear();
110 }
111
112 pub fn invalidate(&mut self, context_id: &str) {
114 self.cache.remove(context_id);
115 self.cache.clear();
118 }
119
120 fn merge_contexts(&self, parent: &ExecutionContext, child: &ExecutionContext) -> ExecutionContext {
122 ExecutionContext {
123 id: child.id.clone(),
125 name: child.name.clone(),
126 description: child.description.clone().or_else(|| parent.description.clone()),
127 inherits_from: child.inherits_from.clone(),
128
129 mounts: merge_mounts(&parent.mounts, &child.mounts),
131 environment: merge_environments(&parent.environment, &child.environment),
132 secrets: merge_secrets(&parent.secrets, &child.secrets),
133 resources: merge_resources(&parent.resources, &child.resources),
134 runtime_overrides: merge_runtime_overrides(
135 parent.runtime_overrides.as_ref(),
136 child.runtime_overrides.as_ref(),
137 ),
138
139 metadata: child.metadata.clone(),
141 }
142 }
143}
144
145pub fn resolve_context<F>(
149 context: &ExecutionContext,
150 loader: F,
151) -> Result<ExecutionContext, ContextError>
152where
153 F: Fn(&str) -> Result<ExecutionContext, ContextError>,
154{
155 let mut resolver = ContextResolver::new(loader);
156 resolver.resolve(context)
157}
158
159pub fn merge_mounts(parent: &[Mount], child: &[Mount]) -> Vec<Mount> {
163 let mut result: HashMap<String, Mount> = parent
164 .iter()
165 .map(|m| (m.id.clone(), m.clone()))
166 .collect();
167
168 for mount in child {
170 result.insert(mount.id.clone(), mount.clone());
171 }
172
173 result.into_values().collect()
174}
175
176pub fn merge_environments(parent: &EnvironmentConfig, child: &EnvironmentConfig) -> EnvironmentConfig {
181 let mut variables = parent.variables.clone();
182 for (key, value) in &child.variables {
183 variables.insert(key.clone(), value.clone());
184 }
185
186 let mut env_files = parent.env_files.clone();
188 for file in &child.env_files {
189 if !env_files.iter().any(|f| f.path == file.path) {
190 env_files.push(file.clone());
191 }
192 }
193
194 let mut passthrough_prefixes: Vec<String> = parent.passthrough_prefixes.clone();
195 for prefix in &child.passthrough_prefixes {
196 if !passthrough_prefixes.contains(prefix) {
197 passthrough_prefixes.push(prefix.clone());
198 }
199 }
200
201 let mut passthrough_vars: Vec<String> = parent.passthrough_vars.clone();
202 for var in &child.passthrough_vars {
203 if !passthrough_vars.contains(var) {
204 passthrough_vars.push(var.clone());
205 }
206 }
207
208 EnvironmentConfig {
209 variables,
210 env_files,
211 passthrough_prefixes,
212 passthrough_vars,
213 }
214}
215
216pub fn merge_secrets(parent: &SecretsConfig, child: &SecretsConfig) -> SecretsConfig {
221 let mut secrets = parent.secrets.clone();
222 for (key, def) in &child.secrets {
223 secrets.insert(key.clone(), def.clone());
224 }
225
226 let mut providers = child.providers.clone();
228 for provider in &parent.providers {
229 let parent_name = provider.name();
231 if !providers.iter().any(|p| p.name() == parent_name) {
232 providers.push(provider.clone());
233 }
234 }
235
236 SecretsConfig { secrets, providers }
237}
238
239pub fn merge_resources(parent: &ResourceConfig, child: &ResourceConfig) -> ResourceConfig {
243 ResourceConfig {
244 cpu: child.cpu.clone().or_else(|| parent.cpu.clone()),
245 memory: child.memory.clone().or_else(|| parent.memory.clone()),
246 network: merge_network_config(&parent.network, &child.network),
247 filesystem: merge_filesystem_config(&parent.filesystem, &child.filesystem),
248 execution: merge_execution_limits(&parent.execution, &child.execution),
249 }
250}
251
252fn merge_network_config(
254 parent: &crate::resources::NetworkConfig,
255 child: &crate::resources::NetworkConfig,
256) -> crate::resources::NetworkConfig {
257 crate::resources::NetworkConfig {
258 enabled: child.enabled || parent.enabled,
260 mode: child.mode.clone().or_else(|| parent.mode.clone()),
261 allowed_hosts: match (&parent.allowed_hosts, &child.allowed_hosts) {
263 (Some(p), Some(c)) => {
264 let mut hosts = p.clone();
265 for h in c {
266 if !hosts.contains(h) {
267 hosts.push(h.clone());
268 }
269 }
270 Some(hosts)
271 }
272 (None, Some(c)) => Some(c.clone()),
273 (Some(p), None) => Some(p.clone()),
274 (None, None) => None,
275 },
276 blocked_hosts: match (&parent.blocked_hosts, &child.blocked_hosts) {
277 (Some(p), Some(c)) => {
278 let mut hosts = p.clone();
279 for h in c {
280 if !hosts.contains(h) {
281 hosts.push(h.clone());
282 }
283 }
284 Some(hosts)
285 }
286 (None, Some(c)) => Some(c.clone()),
287 (Some(p), None) => Some(p.clone()),
288 (None, None) => None,
289 },
290 dns: child.dns.clone().or_else(|| parent.dns.clone()),
291 }
292}
293
294fn merge_filesystem_config(
296 parent: &crate::resources::FilesystemConfig,
297 child: &crate::resources::FilesystemConfig,
298) -> crate::resources::FilesystemConfig {
299 let mut writable_paths = parent.writable_paths.clone();
301 for path in &child.writable_paths {
302 if !writable_paths.contains(path) {
303 writable_paths.push(path.clone());
304 }
305 }
306
307 crate::resources::FilesystemConfig {
308 read_only_root: child.read_only_root || parent.read_only_root,
309 writable_paths,
310 max_file_size: child
311 .max_file_size
312 .clone()
313 .or_else(|| parent.max_file_size.clone()),
314 max_disk_usage: child
315 .max_disk_usage
316 .clone()
317 .or_else(|| parent.max_disk_usage.clone()),
318 }
319}
320
321fn merge_execution_limits(
323 parent: &crate::resources::ExecutionLimits,
324 child: &crate::resources::ExecutionLimits,
325) -> crate::resources::ExecutionLimits {
326 crate::resources::ExecutionLimits {
327 timeout_seconds: child.timeout_seconds.or(parent.timeout_seconds),
328 max_concurrent: child.max_concurrent.or(parent.max_concurrent),
329 rate_limit: child.rate_limit.clone().or_else(|| parent.rate_limit.clone()),
330 }
331}
332
333fn merge_runtime_overrides(
335 parent: Option<&RuntimeOverrides>,
336 child: Option<&RuntimeOverrides>,
337) -> Option<RuntimeOverrides> {
338 match (parent, child) {
339 (None, None) => None,
340 (Some(p), None) => Some(p.clone()),
341 (None, Some(c)) => Some(c.clone()),
342 (Some(p), Some(c)) => Some(RuntimeOverrides {
343 wasm: merge_wasm_overrides(p.wasm.as_ref(), c.wasm.as_ref()),
344 docker: merge_docker_overrides(p.docker.as_ref(), c.docker.as_ref()),
345 native: merge_native_overrides(p.native.as_ref(), c.native.as_ref()),
346 }),
347 }
348}
349
350fn merge_wasm_overrides(
352 parent: Option<&crate::runtime::WasmOverrides>,
353 child: Option<&crate::runtime::WasmOverrides>,
354) -> Option<crate::runtime::WasmOverrides> {
355 match (parent, child) {
356 (None, None) => None,
357 (Some(p), None) => Some(p.clone()),
358 (None, Some(c)) => Some(c.clone()),
359 (Some(p), Some(c)) => {
360 let mut wasi_capabilities = p.wasi_capabilities.clone();
361 for (key, value) in &c.wasi_capabilities {
362 wasi_capabilities.insert(key.clone(), *value);
363 }
364
365 Some(crate::runtime::WasmOverrides {
366 stack_size: c.stack_size.or(p.stack_size),
367 wasi_capabilities,
368 fuel_limit: c.fuel_limit.or(p.fuel_limit),
369 epoch_interruption: c.epoch_interruption.or(p.epoch_interruption),
370 max_memory_pages: c.max_memory_pages.or(p.max_memory_pages),
371 debug_info: c.debug_info || p.debug_info,
372 })
373 }
374 }
375}
376
377fn merge_docker_overrides(
379 parent: Option<&crate::runtime::DockerOverrides>,
380 child: Option<&crate::runtime::DockerOverrides>,
381) -> Option<crate::runtime::DockerOverrides> {
382 match (parent, child) {
383 (None, None) => None,
384 (Some(p), None) => Some(p.clone()),
385 (None, Some(c)) => Some(c.clone()),
386 (Some(p), Some(c)) => {
387 let mut extra_args = p.extra_args.clone();
389 extra_args.extend(c.extra_args.clone());
390
391 let mut security_opt = p.security_opt.clone();
393 for opt in &c.security_opt {
394 if !security_opt.contains(opt) {
395 security_opt.push(opt.clone());
396 }
397 }
398
399 let mut sysctls = p.sysctls.clone();
401 for (key, value) in &c.sysctls {
402 sysctls.insert(key.clone(), value.clone());
403 }
404
405 let mut labels = p.labels.clone();
407 for (key, value) in &c.labels {
408 labels.insert(key.clone(), value.clone());
409 }
410
411 let mut cap_add = p.cap_add.clone();
413 for cap in &c.cap_add {
414 if !cap_add.contains(cap) {
415 cap_add.push(cap.clone());
416 }
417 }
418
419 let mut cap_drop = p.cap_drop.clone();
420 for cap in &c.cap_drop {
421 if !cap_drop.contains(cap) {
422 cap_drop.push(cap.clone());
423 }
424 }
425
426 Some(crate::runtime::DockerOverrides {
427 image: c.image.clone().or_else(|| p.image.clone()),
428 extra_args,
429 entrypoint: c.entrypoint.clone().or_else(|| p.entrypoint.clone()),
430 command: c.command.clone().or_else(|| p.command.clone()),
431 user: c.user.clone().or_else(|| p.user.clone()),
432 gpus: c.gpus.clone().or_else(|| p.gpus.clone()),
433 platform: c.platform.clone().or_else(|| p.platform.clone()),
434 privileged: c.privileged || p.privileged,
435 security_opt,
436 sysctls,
437 labels,
438 restart: c.restart.clone().or_else(|| p.restart.clone()),
439 rm: c.rm && p.rm, init: c.init || p.init,
441 hostname: c.hostname.clone().or_else(|| p.hostname.clone()),
442 ipc: c.ipc.clone().or_else(|| p.ipc.clone()),
443 pid: c.pid.clone().or_else(|| p.pid.clone()),
444 cap_add,
445 cap_drop,
446 })
447 }
448 }
449}
450
451fn merge_native_overrides(
453 parent: Option<&crate::runtime::NativeOverrides>,
454 child: Option<&crate::runtime::NativeOverrides>,
455) -> Option<crate::runtime::NativeOverrides> {
456 match (parent, child) {
457 (None, None) => None,
458 (Some(p), None) => Some(p.clone()),
459 (None, Some(c)) => Some(c.clone()),
460 (Some(p), Some(c)) => {
461 let mut path_additions = p.path_additions.clone();
463 for path in &c.path_additions {
464 if !path_additions.contains(path) {
465 path_additions.push(path.clone());
466 }
467 }
468
469 Some(crate::runtime::NativeOverrides {
470 working_dir: c.working_dir.clone().or_else(|| p.working_dir.clone()),
471 shell: c.shell.clone().or_else(|| p.shell.clone()),
472 path_additions,
473 run_as: c.run_as.clone().or_else(|| p.run_as.clone()),
474 clear_env: c.clear_env || p.clear_env,
475 inherit_env: c.inherit_env && p.inherit_env,
476 })
477 }
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::environment::EnvValue;
    use crate::resources::{CpuConfig, MemoryConfig, NetworkConfig};
    use crate::secrets::SecretDefinition;

    /// Builds an id -> context lookup table from `(id, context)` pairs.
    fn context_table(entries: Vec<(&str, ExecutionContext)>) -> HashMap<String, ExecutionContext> {
        entries
            .into_iter()
            .map(|(id, ctx)| (id.to_string(), ctx))
            .collect()
    }

    /// Returns a loader that serves clones out of `contexts` and fails with
    /// `ContextError::NotFound` for unknown ids.
    fn table_loader(
        contexts: &HashMap<String, ExecutionContext>,
    ) -> impl Fn(&str) -> Result<ExecutionContext, ContextError> + '_ {
        move |id| {
            contexts
                .get(id)
                .cloned()
                .ok_or_else(|| ContextError::NotFound(id.to_string()))
        }
    }

    #[test]
    fn test_simple_inheritance() {
        let parent = ExecutionContext::new("parent", "Parent")
            .with_environment(EnvironmentConfig::new().with_var("PARENT_VAR", "parent_value"));
        let child = ExecutionContext::inheriting("child", "Child", "parent")
            .with_environment(EnvironmentConfig::new().with_var("CHILD_VAR", "child_value"));
        let contexts = context_table(vec![("parent", parent)]);

        let resolved = resolve_context(&child, table_loader(&contexts)).unwrap();

        let vars = &resolved.environment.variables;
        assert!(vars.contains_key("PARENT_VAR"));
        assert!(vars.contains_key("CHILD_VAR"));
    }

    #[test]
    fn test_child_overrides_parent() {
        let parent = ExecutionContext::new("parent", "Parent")
            .with_environment(EnvironmentConfig::new().with_var("SHARED_VAR", "parent_value"));
        let child = ExecutionContext::inheriting("child", "Child", "parent")
            .with_environment(EnvironmentConfig::new().with_var("SHARED_VAR", "child_value"));
        let contexts = context_table(vec![("parent", parent)]);

        let resolved = resolve_context(&child, table_loader(&contexts)).unwrap();

        // The child's value must shadow the parent's for the same key.
        match resolved.environment.variables.get("SHARED_VAR") {
            Some(EnvValue::Plain(v)) => assert_eq!(v, "child_value"),
            _ => panic!("Expected plain value"),
        }
    }

    #[test]
    fn test_multi_level_inheritance() {
        let base = ExecutionContext::new("base", "Base")
            .with_environment(EnvironmentConfig::new().with_var("BASE_VAR", "base"));
        let middle = ExecutionContext::inheriting("middle", "Middle", "base")
            .with_environment(EnvironmentConfig::new().with_var("MIDDLE_VAR", "middle"));
        let child = ExecutionContext::inheriting("child", "Child", "middle")
            .with_environment(EnvironmentConfig::new().with_var("CHILD_VAR", "child"));
        let contexts = context_table(vec![("base", base), ("middle", middle)]);

        let resolved = resolve_context(&child, table_loader(&contexts)).unwrap();

        // Variables from every level of the chain must survive the merge.
        let vars = &resolved.environment.variables;
        assert!(vars.contains_key("BASE_VAR"));
        assert!(vars.contains_key("MIDDLE_VAR"));
        assert!(vars.contains_key("CHILD_VAR"));
    }

    #[test]
    fn test_circular_inheritance_detection() {
        let ctx_a = ExecutionContext::inheriting("a", "Context A", "b");
        let ctx_b = ExecutionContext::inheriting("b", "Context B", "a");
        let contexts = context_table(vec![("a", ctx_a.clone()), ("b", ctx_b)]);

        let result = resolve_context(&ctx_a, table_loader(&contexts));

        assert!(matches!(result, Err(ContextError::CircularInheritance(_))));
    }

    #[test]
    fn test_missing_parent() {
        let child = ExecutionContext::inheriting("child", "Child", "nonexistent");

        let result = resolve_context(&child, |_| Err(ContextError::NotFound("not found".into())));

        assert!(matches!(result, Err(ContextError::ParentNotFound(_))));
    }

    #[test]
    fn test_mount_merge() {
        let parent_mounts = vec![
            Mount::directory("data", "/parent/data", "/data"),
            Mount::directory("config", "/parent/config", "/config"),
        ];
        let child_mounts = vec![
            // Same id as a parent mount: replaces it.
            Mount::directory("config", "/child/config", "/config"),
            // New id: added alongside the inherited mounts.
            Mount::directory("logs", "/child/logs", "/logs"),
        ];

        let merged = merge_mounts(&parent_mounts, &child_mounts);

        assert_eq!(merged.len(), 3);
        let config_mount = merged.iter().find(|m| m.id == "config").unwrap();
        assert_eq!(config_mount.source, "/child/config");
    }

    #[test]
    fn test_secrets_merge() {
        let parent_secrets = SecretsConfig::new()
            .with_secret("parent-key", SecretDefinition::required("parent-key"))
            .with_secret("shared-key", SecretDefinition::required("shared-key"));
        let child_secrets = SecretsConfig::new()
            // The child relaxes the shared secret from required to optional.
            .with_secret("shared-key", SecretDefinition::optional("shared-key"))
            .with_secret("child-key", SecretDefinition::required("child-key"));

        let merged = merge_secrets(&parent_secrets, &child_secrets);

        assert_eq!(merged.secrets.len(), 3);
        assert!(merged.secrets.get("parent-key").unwrap().required);
        assert!(!merged.secrets.get("shared-key").unwrap().required);
        assert!(merged.secrets.get("child-key").unwrap().required);
    }

    #[test]
    fn test_resources_merge() {
        let parent_resources = ResourceConfig::new()
            .with_cpu(CpuConfig::new("2"))
            .with_memory(MemoryConfig::new("1g"))
            .with_network_enabled();
        // Only memory is overridden; CPU and networking are inherited.
        let child_resources = ResourceConfig::new().with_memory(MemoryConfig::new("2g"));

        let merged = merge_resources(&parent_resources, &child_resources);

        assert_eq!(merged.cpu.as_ref().unwrap().limit, "2");
        assert_eq!(merged.memory.as_ref().unwrap().limit, "2g");
        assert!(merged.network.enabled);
    }

    #[test]
    fn test_network_hosts_merge() {
        let parent_network = NetworkConfig::enabled()
            .allow_host("parent.example.com")
            .block_host("blocked.example.com");
        // The child re-lists a parent host, which must not be duplicated.
        let child_network = NetworkConfig::enabled()
            .allow_host("child.example.com")
            .allow_host("parent.example.com");

        let merged = merge_network_config(&parent_network, &child_network);

        let allowed = merged.allowed_hosts.unwrap();
        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains(&"parent.example.com".to_string()));
        assert!(allowed.contains(&"child.example.com".to_string()));
    }

    #[test]
    fn test_no_inheritance() {
        let context = ExecutionContext::new("standalone", "Standalone")
            .with_environment(EnvironmentConfig::new().with_var("VAR", "value"));

        // The loader must never be consulted for a context without a parent.
        let resolved = resolve_context(&context, |_| {
            Err(ContextError::NotFound("should not be called".into()))
        })
        .unwrap();

        assert_eq!(resolved.id, "standalone");
        assert!(resolved.environment.variables.contains_key("VAR"));
    }

    #[test]
    fn test_resolver_cache() {
        let call_count = std::cell::RefCell::new(0);

        let parent = ExecutionContext::new("parent", "Parent");
        let child1 = ExecutionContext::inheriting("child1", "Child 1", "parent");
        let child2 = ExecutionContext::inheriting("child2", "Child 2", "parent");
        let contexts = context_table(vec![("parent", parent)]);

        let mut resolver = ContextResolver::new(|id| {
            *call_count.borrow_mut() += 1;
            contexts
                .get(id)
                .cloned()
                .ok_or_else(|| ContextError::NotFound(id.to_string()))
        });

        // The first child forces one load of the shared parent...
        resolver.resolve(&child1).unwrap();
        assert_eq!(*call_count.borrow(), 1);

        // ...and the second child reuses the cached parent.
        resolver.resolve(&child2).unwrap();
        assert_eq!(*call_count.borrow(), 1);
    }
}