//! A2A server for a single agent (`systemprompt_agent/services/a2a_server/server.rs`).

1use axum::routing::{get, post};
2use axum::{Router, middleware};
3use std::pin::Pin;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_models::modules::ApiPaths;
7use systemprompt_models::{AgentConfig, AiProvider};
8use tokio::sync::RwLock;
9use tower_http::cors::CorsLayer;
10use tower_http::services::ServeDir;
11
12use super::auth::{AgentOAuthConfig, AgentOAuthState, agent_oauth_middleware_wrapper};
13use super::handlers::{AgentHandlerState, handle_agent_card, handle_agent_request};
14use crate::state::AgentState;
15
16pub struct Server {
17    db_pool: DbPool,
18    config: Arc<RwLock<AgentConfig>>,
19    oauth_state: Arc<AgentOAuthState>,
20    agent_state: Arc<AgentState>,
21    ai_service: Arc<dyn AiProvider>,
22    port: u16,
23}
24
25impl std::fmt::Debug for Server {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("Server")
28            .field("db_pool", &"<DbPool>")
29            .field("config", &"Arc<RwLock<AgentConfig>>")
30            .field("oauth_state", &"Arc<AgentOAuthState>")
31            .field("agent_state", &"Arc<AgentState>")
32            .field("ai_service", &"<Arc<dyn AiProvider>>")
33            .field("port", &self.port)
34            .finish()
35    }
36}
37
38impl Server {
39    pub async fn new(
40        db_pool: DbPool,
41        agent_state: Arc<AgentState>,
42        ai_service: Arc<dyn AiProvider>,
43        agent_name: Option<String>,
44        port: u16,
45    ) -> Result<Self, crate::error::AgentError> {
46        use crate::services::registry::AgentRegistry;
47
48        let mut config = if let Some(name) = agent_name {
49            let registry = AgentRegistry::new()
50                .map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
51            registry
52                .get_agent(&name)
53                .await
54                .map_err(|e| crate::error::AgentError::Server(e.to_string()))?
55        } else {
56            return Err(crate::error::AgentError::Validation(
57                "Agent name is required".to_string(),
58            ));
59        };
60
61        config.extract_oauth_scopes_from_card();
62
63        let oauth_config = AgentOAuthConfig::default();
64        let jwt_secret = systemprompt_config::SecretsBootstrap::jwt_secret()
65            .map_err(|e| crate::error::AgentError::Config(e.to_string()))?
66            .to_string();
67        let global_config = systemprompt_models::Config::get()
68            .map_err(|e| crate::error::AgentError::Config(e.to_string()))?;
69        let mut oauth_state = AgentOAuthState::new(
70            Arc::clone(&db_pool),
71            oauth_config,
72            jwt_secret,
73            global_config.jwt_issuer.clone(),
74            global_config.jwt_audiences.clone(),
75        );
76
77        oauth_state = oauth_state.with_jwt_provider(Arc::clone(agent_state.jwt_provider()));
78        if let Some(user_provider) = agent_state.user_provider().cloned() {
79            oauth_state = oauth_state.with_user_provider(user_provider);
80        }
81
82        Ok(Self {
83            db_pool,
84            config: Arc::new(RwLock::new(config)),
85            oauth_state: Arc::new(oauth_state),
86            agent_state,
87            ai_service,
88            port,
89        })
90    }
91
92    pub async fn reload_config(&self) -> Result<(), crate::error::AgentError> {
93        use crate::services::registry::AgentRegistry;
94
95        let agent_name = {
96            let config = self.config.read().await;
97            config.name.clone()
98        };
99
100        let registry =
101            AgentRegistry::new().map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
102        let mut new_config = registry
103            .get_agent(&agent_name)
104            .await
105            .map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
106        new_config.extract_oauth_scopes_from_card();
107        *self.config.write().await = new_config;
108
109        tracing::info!(agent_name = %agent_name, "Configuration reloaded");
110        Ok(())
111    }
112
113    pub fn create_router(&self) -> Router {
114        let state = Arc::new(AgentHandlerState {
115            db_pool: Arc::clone(&self.db_pool),
116            config: Arc::clone(&self.config),
117            oauth_state: Arc::clone(&self.oauth_state),
118            agent_state: Arc::clone(&self.agent_state),
119            ai_service: Arc::clone(&self.ai_service),
120        });
121
122        let post_router = Router::new()
123            .route("/", post(handle_agent_request))
124            .with_state(Arc::clone(&state))
125            .layer(middleware::from_fn_with_state(
126                Arc::clone(&state),
127                agent_oauth_middleware_wrapper,
128            ));
129
130        let get_router = Router::new()
131            .route(ApiPaths::WELLKNOWN_AGENT_CARD, get(handle_agent_card))
132            .route(ApiPaths::A2A_CARD, get(handle_agent_card))
133            .with_state(state);
134
135        let api_router = Router::new().merge(post_router).merge(get_router);
136
137        let web_dist_path = std::path::Path::new("web/dist");
138        let router = if web_dist_path.exists() {
139            api_router.fallback_service(ServeDir::new(web_dist_path))
140        } else {
141            api_router
142        };
143
144        router.layer(CorsLayer::permissive())
145    }
146
147    pub async fn run(self) -> Result<(), crate::error::AgentError> {
148        Self::log_server_configuration();
149        self.start_server(None).await
150    }
151
152    pub async fn run_with_shutdown(
153        self,
154        shutdown_signal: impl Future<Output = ()> + Send + 'static,
155    ) -> Result<(), crate::error::AgentError> {
156        Self::log_server_configuration();
157        self.start_server(Some(Box::pin(shutdown_signal))).await
158    }
159
160    const fn log_server_configuration() {}
161
162    async fn start_server(
163        self,
164        shutdown_signal: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
165    ) -> Result<(), crate::error::AgentError> {
166        let app = self.create_router();
167        let addr = format!("0.0.0.0:{}", self.port);
168        let listener = tokio::net::TcpListener::bind(&addr).await?;
169
170        match shutdown_signal {
171            Some(signal) => axum::serve(listener, app)
172                .with_graceful_shutdown(signal)
173                .await
174                .map_err(|e| crate::error::AgentError::Server(e.to_string())),
175            None => axum::serve(listener, app)
176                .await
177                .map_err(|e| crate::error::AgentError::Server(e.to_string())),
178        }
179    }
180}