Skip to main content

systemprompt_extension/registry/mod.rs

1//! Dynamic extension registry that stores extensions as `Arc<dyn
2//! Extension>`.
3//!
4//! The dynamic registry is the lower-level counterpart of
5//! [`crate::TypedExtensionRegistry`]: it accepts `Arc<dyn Extension>`
6//! values supplied by either inventory discovery or runtime injection.
7
8mod discovery;
9mod queries;
10mod validation;
11
12use crate::Extension;
13use crate::error::LoaderError;
14use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16use tracing::warn;
17
18pub use validation::RESERVED_PATHS;
19
20#[derive(Default)]
21pub struct ExtensionRegistry {
22    pub(crate) extensions: HashMap<String, Arc<dyn Extension>>,
23    pub(crate) sorted_extensions: Vec<Arc<dyn Extension>>,
24}
25
26impl std::fmt::Debug for ExtensionRegistry {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("ExtensionRegistry")
29            .field("extension_count", &self.extensions.len())
30            .finish_non_exhaustive()
31    }
32}
33
34impl ExtensionRegistry {
35    #[must_use]
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Topologically order extensions by [`Extension::dependencies`], breaking
41    /// ties with [`Extension::priority`] (lower runs first).
42    ///
43    /// Missing dependencies are warned and ignored — an extension may
44    /// optionally depend on another that was not loaded in this build. A
45    /// dependency cycle returns [`LoaderError::DependencyCycle`] with a
46    /// human-readable chain (`"A -> B -> A"`).
47    pub(crate) fn sort_by_priority(&mut self) -> Result<(), LoaderError> {
48        let ids: Vec<String> = self
49            .sorted_extensions
50            .iter()
51            .map(|e| e.id().to_string())
52            .collect();
53        let id_set: HashSet<&str> = ids.iter().map(String::as_str).collect();
54
55        let mut by_id: HashMap<String, Arc<dyn Extension>> = HashMap::new();
56        for ext in self.sorted_extensions.drain(..) {
57            by_id.insert(ext.id().to_string(), ext);
58        }
59
60        for (owner, ext) in &by_id {
61            for dep in ext.dependencies() {
62                if !id_set.contains(dep) {
63                    warn!(
64                        extension = %owner,
65                        missing_dependency = %dep,
66                        "Extension declares dependency that is not loaded; treating as optional \
67                         and ignoring for ordering"
68                    );
69                }
70            }
71        }
72
73        let order = topo_sort(&ids, &by_id)?;
74
75        self.sorted_extensions = order
76            .into_iter()
77            .filter_map(|id| by_id.remove(&id))
78            .collect();
79        Ok(())
80    }
81
82    pub fn register(&mut self, ext: Arc<dyn Extension>) -> Result<(), LoaderError> {
83        let id = ext.id().to_string();
84        if self.extensions.contains_key(&id) {
85            return Err(LoaderError::DuplicateExtension(id));
86        }
87        self.extensions.insert(id, Arc::clone(&ext));
88        self.sorted_extensions.push(ext);
89        self.sort_by_priority()?;
90        Ok(())
91    }
92
93    pub fn merge(&mut self, extensions: Vec<Arc<dyn Extension>>) -> Result<(), LoaderError> {
94        for ext in extensions {
95            self.register(ext)?;
96        }
97        Ok(())
98    }
99
100    pub fn validate(&self) -> Result<(), LoaderError> {
101        self.validate_dependencies()?;
102        Ok(())
103    }
104
105    #[must_use]
106    pub fn len(&self) -> usize {
107        self.extensions.len()
108    }
109
110    #[must_use]
111    pub fn is_empty(&self) -> bool {
112        self.extensions.is_empty()
113    }
114}
115
116fn topo_sort(
117    ids: &[String],
118    by_id: &HashMap<String, Arc<dyn Extension>>,
119) -> Result<Vec<String>, LoaderError> {
120    const WHITE: u8 = 0;
121    const GRAY: u8 = 1;
122    const BLACK: u8 = 2;
123
124    fn visit(
125        node: &str,
126        by_id: &HashMap<String, Arc<dyn Extension>>,
127        color: &mut HashMap<String, u8>,
128        path: &mut Vec<String>,
129        out: &mut Vec<String>,
130    ) -> Result<(), LoaderError> {
131        let state = color.get(node).copied().unwrap_or(WHITE);
132        if state == BLACK {
133            return Ok(());
134        }
135        if state == GRAY {
136            let cycle_start = path.iter().position(|p| p == node).unwrap_or(0);
137            let mut chain: Vec<String> = path[cycle_start..].to_vec();
138            chain.push(node.to_string());
139            return Err(LoaderError::DependencyCycle {
140                chain: chain.join(" -> "),
141            });
142        }
143        color.insert(node.to_string(), GRAY);
144        path.push(node.to_string());
145
146        if let Some(ext) = by_id.get(node) {
147            let mut deps: Vec<&'static str> = ext
148                .dependencies()
149                .into_iter()
150                .filter(|d| by_id.contains_key(*d))
151                .collect();
152            deps.sort_by_key(|d| {
153                by_id.get(*d).map_or((u32::MAX, String::new()), |e| {
154                    (e.priority(), e.id().to_string())
155                })
156            });
157            for dep in deps {
158                visit(dep, by_id, color, path, out)?;
159            }
160        }
161
162        path.pop();
163        color.insert(node.to_string(), BLACK);
164        out.push(node.to_string());
165        Ok(())
166    }
167
168    let mut roots: Vec<&String> = ids.iter().collect();
169    roots.sort_by_key(|id| {
170        by_id.get(*id).map_or((u32::MAX, String::new()), |e| {
171            (e.priority(), e.id().to_string())
172        })
173    });
174
175    let mut color: HashMap<String, u8> = HashMap::with_capacity(ids.len());
176    let mut path: Vec<String> = Vec::new();
177    let mut out: Vec<String> = Vec::with_capacity(ids.len());
178    for id in roots {
179        visit(id, by_id, &mut color, &mut path, &mut out)?;
180    }
181    Ok(out)
182}
183
184#[derive(Debug, Clone, Copy)]
185pub struct ExtensionRegistration {
186    pub factory: fn() -> Arc<dyn Extension>,
187}
188
189inventory::collect!(ExtensionRegistration);