systemprompt_extension/
builder.rs1use std::marker::PhantomData;
2
3#[cfg(feature = "web")]
4use crate::any::ApiExtensionWrapper;
5use crate::any::{AnyExtension, ExtensionWrapper, SchemaExtensionWrapper};
6use crate::error::LoaderError;
7use crate::hlist::{Subset, TypeList};
8#[cfg(feature = "web")]
9use crate::typed::ApiExtensionTypedDyn;
10use crate::typed::SchemaExtensionTyped;
11use crate::typed_registry::TypedExtensionRegistry;
12use crate::types::{Dependencies, ExtensionType};
13
14pub struct ExtensionBuilder<Registered: TypeList = ()> {
15 extensions: Vec<Box<dyn AnyExtension>>,
16 _marker: PhantomData<Registered>,
17}
18
19impl<R: TypeList> std::fmt::Debug for ExtensionBuilder<R> {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("ExtensionBuilder")
22 .field("extension_count", &self.extensions.len())
23 .finish_non_exhaustive()
24 }
25}
26
27impl ExtensionBuilder<()> {
28 #[must_use]
29 pub fn new() -> Self {
30 Self {
31 extensions: Vec::new(),
32 _marker: PhantomData,
33 }
34 }
35}
36
37impl Default for ExtensionBuilder<()> {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl<R: TypeList> ExtensionBuilder<R> {
44 pub fn extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
45 where
46 E: ExtensionType + Dependencies + std::fmt::Debug + 'static,
47 E::Deps: Subset<R>,
48 {
49 self.extensions.push(Box::new(ExtensionWrapper::new(ext)));
50 ExtensionBuilder {
51 extensions: self.extensions,
52 _marker: PhantomData,
53 }
54 }
55
56 pub fn schema_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
57 where
58 E: ExtensionType + Dependencies + SchemaExtensionTyped + std::fmt::Debug + 'static,
59 E::Deps: Subset<R>,
60 {
61 self.extensions
62 .push(Box::new(SchemaExtensionWrapper::new(ext)));
63 ExtensionBuilder {
64 extensions: self.extensions,
65 _marker: PhantomData,
66 }
67 }
68
69 #[cfg(feature = "web")]
70 pub fn api_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
71 where
72 E: ExtensionType + Dependencies + ApiExtensionTypedDyn + std::fmt::Debug + 'static,
73 E::Deps: Subset<R>,
74 {
75 self.extensions
76 .push(Box::new(ApiExtensionWrapper::new(ext)));
77 ExtensionBuilder {
78 extensions: self.extensions,
79 _marker: PhantomData,
80 }
81 }
82
83 pub fn build(self) -> Result<TypedExtensionRegistry, LoaderError> {
84 let mut registry = TypedExtensionRegistry::new();
85 let mut sorted = self.extensions;
86 sorted.sort_by_key(|e| e.priority());
87
88 for ext in sorted {
89 if registry.has(ext.id()) {
90 return Err(LoaderError::DuplicateExtension(ext.id().to_string()));
91 }
92
93 #[cfg(feature = "web")]
94 if let Some(api) = ext.as_api() {
95 registry.validate_api_path(ext.id(), api.base_path())?;
96 }
97
98 registry.add_boxed(ext);
99 }
100
101 Ok(registry)
102 }
103}