systemprompt_api/services/gateway/
registry.rs1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3
4use super::protocol::outbound::anthropic::AnthropicOutbound;
5use super::protocol::outbound::gemini::GeminiOutbound;
6use super::protocol::outbound::openai_chat::OpenAiChatOutbound;
7use super::protocol::outbound::openai_responses::OpenAiResponsesOutbound;
8use super::protocol::outbound::{OutboundAdapter, OutboundAdapterRegistration};
9use systemprompt_ai::{HeuristicScanner, SafetyScanner, SafetyScannerRegistration};
10use systemprompt_models::profile::WireProtocol;
11
12pub struct GatewayUpstreamRegistry {
13 entries: HashMap<String, Arc<dyn OutboundAdapter>>,
14}
15
16impl std::fmt::Debug for GatewayUpstreamRegistry {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("GatewayUpstreamRegistry")
19 .field("tags", &self.tags())
20 .finish()
21 }
22}
23
24impl GatewayUpstreamRegistry {
25 pub fn global() -> &'static Self {
26 static REGISTRY: OnceLock<GatewayUpstreamRegistry> = OnceLock::new();
27 REGISTRY.get_or_init(Self::build)
28 }
29
30 pub fn get(&self, tag: &str) -> Option<&Arc<dyn OutboundAdapter>> {
31 self.entries.get(tag)
32 }
33
34 pub fn tags(&self) -> Vec<&str> {
35 self.entries.keys().map(String::as_str).collect()
36 }
37
38 pub(super) fn build() -> Self {
39 let mut entries: HashMap<String, Arc<dyn OutboundAdapter>> = HashMap::new();
40
41 entries.insert(
44 WireProtocol::Anthropic.as_tag().to_owned(),
45 Arc::new(AnthropicOutbound),
46 );
47 entries.insert(
48 WireProtocol::OpenAiChat.as_tag().to_owned(),
49 Arc::new(OpenAiChatOutbound),
50 );
51 entries.insert(
52 WireProtocol::OpenAiResponses.as_tag().to_owned(),
53 Arc::new(OpenAiResponsesOutbound),
54 );
55 entries.insert(
56 WireProtocol::Gemini.as_tag().to_owned(),
57 Arc::new(GeminiOutbound),
58 );
59
60 for registration in inventory::iter::<OutboundAdapterRegistration> {
61 let tag = registration.tag.to_owned();
62 if entries.contains_key(&tag) {
63 tracing::warn!(
64 tag = %registration.tag,
65 "Extension-registered gateway upstream shadows a built-in"
66 );
67 }
68 entries.insert(tag, (registration.factory)());
69 }
70
71 Self { entries }
72 }
73}
74
75pub struct SafetyScannerRegistry {
76 entries: HashMap<String, Arc<dyn SafetyScanner>>,
77}
78
79impl std::fmt::Debug for SafetyScannerRegistry {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("SafetyScannerRegistry")
82 .field("names", &self.names())
83 .finish()
84 }
85}
86
87impl SafetyScannerRegistry {
88 pub fn global() -> &'static Self {
89 static REGISTRY: OnceLock<SafetyScannerRegistry> = OnceLock::new();
90 REGISTRY.get_or_init(Self::build)
91 }
92
93 pub fn get(&self, name: &str) -> Option<&Arc<dyn SafetyScanner>> {
94 self.entries.get(name)
95 }
96
97 pub fn names(&self) -> Vec<&str> {
98 self.entries.keys().map(String::as_str).collect()
99 }
100
101 pub(super) fn build() -> Self {
102 let mut entries: HashMap<String, Arc<dyn SafetyScanner>> = HashMap::new();
103 entries.insert("heuristic".to_owned(), Arc::new(HeuristicScanner));
104
105 for registration in inventory::iter::<SafetyScannerRegistration> {
106 let name = registration.name.to_owned();
107 if entries.contains_key(&name) {
108 tracing::warn!(
109 name = %registration.name,
110 "Extension-registered safety scanner shadows a built-in"
111 );
112 }
113 entries.insert(name, (registration.factory)());
114 }
115
116 Self { entries }
117 }
118}