Skip to main content

systemprompt_api/services/gateway/
registry.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use super::protocol::outbound::anthropic::AnthropicOutbound;
5use super::protocol::outbound::gemini::GeminiOutbound;
6use super::protocol::outbound::openai_chat::OpenAiChatOutbound;
7use super::protocol::outbound::openai_responses::OpenAiResponsesOutbound;
8use super::protocol::outbound::{OutboundAdapter, OutboundAdapterRegistration};
9use systemprompt_ai::{HeuristicScanner, SafetyScanner, SafetyScannerRegistration};
10use systemprompt_models::profile::WireProtocol;
11
12pub struct GatewayUpstreamRegistry {
13    entries: HashMap<String, Arc<dyn OutboundAdapter>>,
14}
15
16impl std::fmt::Debug for GatewayUpstreamRegistry {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("GatewayUpstreamRegistry")
19            .field("tags", &self.tags())
20            .finish()
21    }
22}
23
24impl GatewayUpstreamRegistry {
25    pub fn global() -> &'static Self {
26        static REGISTRY: OnceLock<GatewayUpstreamRegistry> = OnceLock::new();
27        REGISTRY.get_or_init(Self::build)
28    }
29
30    pub fn get(&self, tag: &str) -> Option<&Arc<dyn OutboundAdapter>> {
31        self.entries.get(tag)
32    }
33
34    pub fn tags(&self) -> Vec<&str> {
35        self.entries.keys().map(String::as_str).collect()
36    }
37
38    pub(super) fn build() -> Self {
39        let mut entries: HashMap<String, Arc<dyn OutboundAdapter>> = HashMap::new();
40
41        // Outbound adapters are keyed on the WireProtocol tag, not the provider
42        // name: a ProviderEntry's `protocol` selects the wire codec.
43        entries.insert(
44            WireProtocol::Anthropic.as_tag().to_owned(),
45            Arc::new(AnthropicOutbound),
46        );
47        entries.insert(
48            WireProtocol::OpenAiChat.as_tag().to_owned(),
49            Arc::new(OpenAiChatOutbound),
50        );
51        entries.insert(
52            WireProtocol::OpenAiResponses.as_tag().to_owned(),
53            Arc::new(OpenAiResponsesOutbound),
54        );
55        entries.insert(
56            WireProtocol::Gemini.as_tag().to_owned(),
57            Arc::new(GeminiOutbound),
58        );
59
60        for registration in inventory::iter::<OutboundAdapterRegistration> {
61            let tag = registration.tag.to_owned();
62            if entries.contains_key(&tag) {
63                tracing::warn!(
64                    tag = %registration.tag,
65                    "Extension-registered gateway upstream shadows a built-in"
66                );
67            }
68            entries.insert(tag, (registration.factory)());
69        }
70
71        Self { entries }
72    }
73}
74
75pub struct SafetyScannerRegistry {
76    entries: HashMap<String, Arc<dyn SafetyScanner>>,
77}
78
79impl std::fmt::Debug for SafetyScannerRegistry {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        f.debug_struct("SafetyScannerRegistry")
82            .field("names", &self.names())
83            .finish()
84    }
85}
86
87impl SafetyScannerRegistry {
88    pub fn global() -> &'static Self {
89        static REGISTRY: OnceLock<SafetyScannerRegistry> = OnceLock::new();
90        REGISTRY.get_or_init(Self::build)
91    }
92
93    pub fn get(&self, name: &str) -> Option<&Arc<dyn SafetyScanner>> {
94        self.entries.get(name)
95    }
96
97    pub fn names(&self) -> Vec<&str> {
98        self.entries.keys().map(String::as_str).collect()
99    }
100
101    pub(super) fn build() -> Self {
102        let mut entries: HashMap<String, Arc<dyn SafetyScanner>> = HashMap::new();
103        entries.insert("heuristic".to_owned(), Arc::new(HeuristicScanner));
104
105        for registration in inventory::iter::<SafetyScannerRegistration> {
106            let name = registration.name.to_owned();
107            if entries.contains_key(&name) {
108                tracing::warn!(
109                    name = %registration.name,
110                    "Extension-registered safety scanner shadows a built-in"
111                );
112            }
113            entries.insert(name, (registration.factory)());
114        }
115
116        Self { entries }
117    }
118}