Skip to main content

systemprompt_runtime/
context.rs

1use crate::registry::ModuleApiRegistry;
2use anyhow::Result;
3use std::sync::Arc;
4use systemprompt_analytics::{AnalyticsService, FingerprintRepository, GeoIpReader};
5use systemprompt_database::{Database, DbPool};
6use systemprompt_extension::{Extension, ExtensionContext, ExtensionRegistry};
7use systemprompt_logging::CliService;
8use systemprompt_models::{
9    AppPaths, Config, ContentConfigRaw, ContentRouting, ProfileBootstrap, RouteClassifier,
10};
11use systemprompt_traits::{
12    AnalyticsProvider, AppContext as AppContextTrait, ConfigProvider, DatabaseHandle,
13    FingerprintProvider, UserProvider,
14};
15use systemprompt_users::UserService;
16
17#[derive(Clone)]
18pub struct AppContext {
19    config: Arc<Config>,
20    database: DbPool,
21    api_registry: Arc<ModuleApiRegistry>,
22    extension_registry: Arc<ExtensionRegistry>,
23    geoip_reader: Option<GeoIpReader>,
24    content_config: Option<Arc<ContentConfigRaw>>,
25    route_classifier: Arc<RouteClassifier>,
26    analytics_service: Arc<AnalyticsService>,
27    fingerprint_repo: Option<Arc<FingerprintRepository>>,
28    user_service: Option<Arc<UserService>>,
29}
30
31impl std::fmt::Debug for AppContext {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("AppContext")
34            .field("config", &"Config")
35            .field("database", &"DbPool")
36            .field("api_registry", &"ModuleApiRegistry")
37            .field("extension_registry", &self.extension_registry)
38            .field("geoip_reader", &self.geoip_reader.is_some())
39            .field("content_config", &self.content_config.is_some())
40            .field("route_classifier", &"RouteClassifier")
41            .field("analytics_service", &"AnalyticsService")
42            .field("fingerprint_repo", &self.fingerprint_repo.is_some())
43            .field("user_service", &self.user_service.is_some())
44            .finish()
45    }
46}
47
48impl AppContext {
49    pub async fn new() -> Result<Self> {
50        Self::builder().build().await
51    }
52
53    #[must_use]
54    pub fn builder() -> AppContextBuilder {
55        AppContextBuilder::new()
56    }
57
58    async fn new_internal(
59        extension_registry: Option<ExtensionRegistry>,
60        show_startup_warnings: bool,
61    ) -> Result<Self> {
62        let profile = ProfileBootstrap::get()?;
63        AppPaths::init(&profile.paths)?;
64        systemprompt_files::FilesConfig::init()?;
65        let config = Arc::new(Config::get()?.clone());
66        let database = Arc::new(
67            Database::from_config_with_write(
68                &config.database_type,
69                &config.database_url,
70                config.database_write_url.as_deref(),
71            )
72            .await?,
73        );
74
75        if config.database_write_url.is_some() {
76            tracing::info!(
77                "Database read/write separation enabled: reads from replica, writes to primary"
78            );
79        }
80
81        let api_registry = Arc::new(ModuleApiRegistry::new());
82
83        let registry = extension_registry.unwrap_or_else(ExtensionRegistry::discover);
84        registry.validate()?;
85
86        let extension_registry = Arc::new(registry);
87
88        let geoip_reader = Self::load_geoip_database(&config, show_startup_warnings);
89        let content_config = Self::load_content_config(&config);
90
91        #[allow(trivial_casts)]
92        let content_routing: Option<Arc<dyn ContentRouting>> =
93            content_config.clone().map(|c| c as Arc<dyn ContentRouting>);
94
95        let route_classifier = Arc::new(RouteClassifier::new(content_routing.clone()));
96
97        let analytics_service = Arc::new(AnalyticsService::new(
98            &database,
99            geoip_reader.clone(),
100            content_routing,
101        )?);
102
103        let fingerprint_repo = match FingerprintRepository::new(&database) {
104            Ok(repo) => Some(Arc::new(repo)),
105            Err(e) => {
106                tracing::warn!(error = %e, "Failed to initialize fingerprint repository");
107                None
108            },
109        };
110
111        let user_service = match UserService::new(&database) {
112            Ok(svc) => Some(Arc::new(svc)),
113            Err(e) => {
114                tracing::warn!(error = %e, "Failed to initialize user service");
115                None
116            },
117        };
118
119        systemprompt_logging::init_logging(Arc::clone(&database));
120
121        Ok(Self {
122            config,
123            database,
124            api_registry,
125            extension_registry,
126            geoip_reader,
127            content_config,
128            route_classifier,
129            analytics_service,
130            fingerprint_repo,
131            user_service,
132        })
133    }
134
135    #[cfg(feature = "geolocation")]
136    fn load_geoip_database(config: &Config, show_warnings: bool) -> Option<GeoIpReader> {
137        let Some(geoip_path) = &config.geoip_database_path else {
138            if show_warnings {
139                CliService::warning(
140                    "GeoIP database not configured - geographic data will not be available",
141                );
142                CliService::info("  To enable geographic data:");
143                CliService::info("  1. Download MaxMind GeoLite2-City database from: https://dev.maxmind.com/geoip/geolite2-free-geolocation-data");
144                CliService::info(
145                    "  2. Add paths.geoip_database to your profile pointing to the .mmdb file",
146                );
147            }
148            return None;
149        };
150
151        match maxminddb::Reader::open_readfile(geoip_path) {
152            Ok(reader) => Some(Arc::new(reader)),
153            Err(e) => {
154                if show_warnings {
155                    CliService::warning(&format!(
156                        "Could not load GeoIP database from {geoip_path}: {e}"
157                    ));
158                    CliService::info(
159                        "  Geographic data (country/region/city) will not be available.",
160                    );
161                    CliService::info(
162                        "  To fix: Ensure the path is correct and the file is a valid MaxMind \
163                         .mmdb database",
164                    );
165                }
166                None
167            },
168        }
169    }
170
171    #[cfg(not(feature = "geolocation"))]
172    fn load_geoip_database(_config: &Config, _show_warnings: bool) -> Option<GeoIpReader> {
173        None
174    }
175
176    fn load_content_config(config: &Config) -> Option<Arc<ContentConfigRaw>> {
177        let content_config_path = AppPaths::get()
178            .ok()?
179            .system()
180            .content_config()
181            .to_path_buf();
182
183        if !content_config_path.exists() {
184            CliService::warning(&format!(
185                "Content config not found at: {}",
186                content_config_path.display()
187            ));
188            CliService::info("  Landing page detection will not be available.");
189            return None;
190        }
191
192        let yaml_content = match std::fs::read_to_string(&content_config_path) {
193            Ok(c) => c,
194            Err(e) => {
195                CliService::warning(&format!(
196                    "Could not read content config from {}: {}",
197                    content_config_path.display(),
198                    e
199                ));
200                CliService::info("  Landing page detection will not be available.");
201                return None;
202            },
203        };
204
205        match serde_yaml::from_str::<ContentConfigRaw>(&yaml_content) {
206            Ok(mut content_cfg) => {
207                let base_url = config.api_external_url.trim_end_matches('/');
208
209                content_cfg.metadata.structured_data.organization.url = base_url.to_string();
210
211                let logo = &content_cfg.metadata.structured_data.organization.logo;
212                if logo.starts_with('/') {
213                    content_cfg.metadata.structured_data.organization.logo =
214                        format!("{base_url}{logo}");
215                }
216
217                Some(Arc::new(content_cfg))
218            },
219            Err(e) => {
220                CliService::warning(&format!(
221                    "Could not parse content config from {}: {}",
222                    content_config_path.display(),
223                    e
224                ));
225                CliService::info("  Landing page detection will not be available.");
226                None
227            },
228        }
229    }
230
231    pub fn config(&self) -> &Config {
232        &self.config
233    }
234
235    pub fn content_config(&self) -> Option<&ContentConfigRaw> {
236        self.content_config.as_ref().map(AsRef::as_ref)
237    }
238
239    #[allow(trivial_casts)]
240    pub fn content_routing(&self) -> Option<Arc<dyn ContentRouting>> {
241        self.content_config
242            .clone()
243            .map(|c| c as Arc<dyn ContentRouting>)
244    }
245
246    pub const fn db_pool(&self) -> &DbPool {
247        &self.database
248    }
249
250    pub const fn database(&self) -> &DbPool {
251        &self.database
252    }
253
254    pub fn api_registry(&self) -> &ModuleApiRegistry {
255        &self.api_registry
256    }
257
258    pub fn extension_registry(&self) -> &ExtensionRegistry {
259        &self.extension_registry
260    }
261
262    pub fn server_address(&self) -> String {
263        format!("{}:{}", self.config.host, self.config.port)
264    }
265
266    pub fn get_provided_audiences() -> Vec<String> {
267        vec!["a2a".to_string(), "api".to_string(), "mcp".to_string()]
268    }
269
270    pub fn get_valid_audiences(_module_name: &str) -> Vec<String> {
271        Self::get_provided_audiences()
272    }
273
274    pub fn get_server_audiences(_server_name: &str, _port: u16) -> Vec<String> {
275        Self::get_provided_audiences()
276    }
277
278    pub const fn geoip_reader(&self) -> Option<&GeoIpReader> {
279        self.geoip_reader.as_ref()
280    }
281
282    pub const fn analytics_service(&self) -> &Arc<AnalyticsService> {
283        &self.analytics_service
284    }
285
286    pub const fn route_classifier(&self) -> &Arc<RouteClassifier> {
287        &self.route_classifier
288    }
289}
290
291#[allow(trivial_casts)]
292impl AppContextTrait for AppContext {
293    fn config(&self) -> Arc<dyn ConfigProvider> {
294        Arc::clone(&self.config) as _
295    }
296
297    fn database_handle(&self) -> Arc<dyn DatabaseHandle> {
298        Arc::clone(&self.database) as _
299    }
300
301    fn analytics_provider(&self) -> Option<Arc<dyn AnalyticsProvider>> {
302        Some(Arc::clone(&self.analytics_service) as _)
303    }
304
305    fn fingerprint_provider(&self) -> Option<Arc<dyn FingerprintProvider>> {
306        let repo = self.fingerprint_repo.as_ref()?;
307        Some(Arc::clone(repo) as _)
308    }
309
310    fn user_provider(&self) -> Option<Arc<dyn UserProvider>> {
311        let service = self.user_service.as_ref()?;
312        Some(Arc::clone(service) as _)
313    }
314}
315
316#[allow(trivial_casts)]
317impl ExtensionContext for AppContext {
318    fn config(&self) -> Arc<dyn ConfigProvider> {
319        Arc::clone(&self.config) as _
320    }
321
322    fn database(&self) -> Arc<dyn DatabaseHandle> {
323        Arc::clone(&self.database) as _
324    }
325
326    fn get_extension(&self, id: &str) -> Option<Arc<dyn Extension>> {
327        self.extension_registry.get(id).cloned()
328    }
329}
330
331#[derive(Debug, Default)]
332pub struct AppContextBuilder {
333    extension_registry: Option<ExtensionRegistry>,
334    show_startup_warnings: bool,
335}
336
337impl AppContextBuilder {
338    #[must_use]
339    pub fn new() -> Self {
340        Self::default()
341    }
342
343    #[must_use]
344    pub fn with_extensions(mut self, registry: ExtensionRegistry) -> Self {
345        self.extension_registry = Some(registry);
346        self
347    }
348
349    #[must_use]
350    pub const fn with_startup_warnings(mut self, show: bool) -> Self {
351        self.show_startup_warnings = show;
352        self
353    }
354
355    pub async fn build(self) -> Result<AppContext> {
356        AppContext::new_internal(self.extension_registry, self.show_startup_warnings).await
357    }
358}