Skip to main content

systemprompt_extension/
builder.rs

1use std::marker::PhantomData;
2
3#[cfg(feature = "web")]
4use crate::any::ApiExtensionWrapper;
5use crate::any::{AnyExtension, ExtensionWrapper, SchemaExtensionWrapper};
6use crate::error::LoaderError;
7use crate::hlist::{Subset, TypeList};
8#[cfg(feature = "web")]
9use crate::typed::ApiExtensionTypedDyn;
10use crate::typed::SchemaExtensionTyped;
11use crate::typed_registry::TypedExtensionRegistry;
12use crate::types::{Dependencies, ExtensionType};
13
14pub struct ExtensionBuilder<Registered: TypeList = ()> {
15    extensions: Vec<Box<dyn AnyExtension>>,
16    _marker: PhantomData<Registered>,
17}
18
19impl<R: TypeList> std::fmt::Debug for ExtensionBuilder<R> {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("ExtensionBuilder")
22            .field("extension_count", &self.extensions.len())
23            .finish_non_exhaustive()
24    }
25}
26
27impl ExtensionBuilder<()> {
28    #[must_use]
29    pub fn new() -> Self {
30        Self {
31            extensions: Vec::new(),
32            _marker: PhantomData,
33        }
34    }
35}
36
37impl Default for ExtensionBuilder<()> {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl<R: TypeList> ExtensionBuilder<R> {
44    pub fn extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
45    where
46        E: ExtensionType + Dependencies + std::fmt::Debug + 'static,
47        E::Deps: Subset<R>,
48    {
49        self.extensions.push(Box::new(ExtensionWrapper::new(ext)));
50        ExtensionBuilder {
51            extensions: self.extensions,
52            _marker: PhantomData,
53        }
54    }
55
56    pub fn schema_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
57    where
58        E: ExtensionType + Dependencies + SchemaExtensionTyped + std::fmt::Debug + 'static,
59        E::Deps: Subset<R>,
60    {
61        self.extensions
62            .push(Box::new(SchemaExtensionWrapper::new(ext)));
63        ExtensionBuilder {
64            extensions: self.extensions,
65            _marker: PhantomData,
66        }
67    }
68
69    #[cfg(feature = "web")]
70    pub fn api_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
71    where
72        E: ExtensionType + Dependencies + ApiExtensionTypedDyn + std::fmt::Debug + 'static,
73        E::Deps: Subset<R>,
74    {
75        self.extensions
76            .push(Box::new(ApiExtensionWrapper::new(ext)));
77        ExtensionBuilder {
78            extensions: self.extensions,
79            _marker: PhantomData,
80        }
81    }
82
83    pub fn build(self) -> Result<TypedExtensionRegistry, LoaderError> {
84        let mut registry = TypedExtensionRegistry::new();
85        let mut sorted = self.extensions;
86        sorted.sort_by_key(|e| e.priority());
87
88        for ext in sorted {
89            if registry.has(ext.id()) {
90                return Err(LoaderError::DuplicateExtension(ext.id().to_string()));
91            }
92
93            #[cfg(feature = "web")]
94            if let Some(api) = ext.as_api() {
95                registry.validate_api_path(ext.id(), api.base_path())?;
96            }
97
98            registry.add_boxed(ext);
99        }
100
101        Ok(registry)
102    }
103}