Skip to main content

systemprompt_extension/registry/validation.rs

1use super::ExtensionRegistry;
2use crate::Extension;
3use crate::error::LoaderError;
4use std::collections::HashMap;
5use std::sync::Arc;
6
/// Base-path prefixes owned by the host application.
///
/// `ExtensionRegistry::validate_api_paths` rejects any extension router whose
/// base path falls under one of these entries.
pub const RESERVED_PATHS: &[&str] = &[
    // Versioned `/api/v1` surface.
    "/api/v1/oauth",
    "/api/v1/users",
    "/api/v1/agents",
    "/api/v1/mcp",
    "/api/v1/stream",
    "/api/v1/content",
    "/api/v1/files",
    "/api/v1/analytics",
    "/api/v1/scheduler",
    "/api/v1/core",
    "/api/v1/admin",
    // Root-level prefix (outside `/api/`).
    "/.well-known",
];
21
22impl ExtensionRegistry {
23    pub fn validate_dependencies(&self) -> Result<(), LoaderError> {
24        for ext in self.extensions.values() {
25            for dep_id in ext.dependencies() {
26                if !self.extensions.contains_key(dep_id) {
27                    return Err(LoaderError::MissingDependency {
28                        extension: ext.id().to_string(),
29                        dependency: dep_id.to_string(),
30                    });
31                }
32            }
33        }
34
35        detect_cycles(&self.extensions)
36    }
37
38    #[cfg(feature = "web")]
39    pub fn validate_api_paths(&self, ctx: &dyn crate::ExtensionContext) -> Result<(), LoaderError> {
40        for ext in self.extensions.values() {
41            if let Some(router_config) = ext.router(ctx) {
42                let base_path = router_config.base_path;
43
44                if !base_path.starts_with("/api/") {
45                    return Err(LoaderError::InvalidBasePath {
46                        extension: ext.id().to_string(),
47                        path: base_path.to_string(),
48                    });
49                }
50
51                for reserved in RESERVED_PATHS {
52                    if base_path.starts_with(reserved) {
53                        return Err(LoaderError::ReservedPathCollision {
54                            extension: ext.id().to_string(),
55                            path: base_path.to_string(),
56                        });
57                    }
58                }
59            }
60        }
61        Ok(())
62    }
63
64    #[cfg(not(feature = "web"))]
65    pub fn validate_api_paths(
66        &self,
67        _ctx: &dyn crate::ExtensionContext,
68    ) -> Result<(), LoaderError> {
69        Ok(())
70    }
71}
72
73fn detect_cycles(extensions: &HashMap<String, Arc<dyn Extension>>) -> Result<(), LoaderError> {
74    const WHITE: u8 = 0;
75    const GRAY: u8 = 1;
76    const BLACK: u8 = 2;
77
78    fn dfs<'a>(
79        node: &'a str,
80        extensions: &'a HashMap<String, Arc<dyn Extension>>,
81        color: &mut HashMap<&'a str, u8>,
82        path: &mut Vec<&'a str>,
83    ) -> Result<(), Vec<&'a str>> {
84        color.insert(node, GRAY);
85        path.push(node);
86
87        if let Some(ext) = extensions.get(node) {
88            for dep_id in ext.dependencies() {
89                match color.get(dep_id) {
90                    Some(&GRAY) => {
91                        path.push(dep_id);
92                        return Err(path.clone());
93                    },
94                    Some(&WHITE) | None => {
95                        dfs(dep_id, extensions, color, path)?;
96                    },
97                    _ => {},
98                }
99            }
100        }
101
102        path.pop();
103        color.insert(node, BLACK);
104        Ok(())
105    }
106
107    let mut color: HashMap<&str, u8> = extensions.keys().map(|id| (id.as_str(), WHITE)).collect();
108
109    let mut path = Vec::new();
110    for id in extensions.keys() {
111        if color.get(id.as_str()) == Some(&WHITE) {
112            if let Err(cycle_path) = dfs(id.as_str(), extensions, &mut color, &mut path) {
113                let Some(&cycle_start) = cycle_path.last() else {
114                    return Err(LoaderError::CircularDependency {
115                        chain: "unknown cycle".to_string(),
116                    });
117                };
118                let cycle_start_idx = cycle_path
119                    .iter()
120                    .position(|&x| x == cycle_start)
121                    .unwrap_or(0);
122                let cycle: Vec<_> = cycle_path[cycle_start_idx..].to_vec();
123
124                return Err(LoaderError::CircularDependency {
125                    chain: cycle.join(" -> "),
126                });
127            }
128        }
129    }
130
131    Ok(())
132}