Skip to main content

systemprompt_agent/services/a2a_server/
server.rs

1//! The per-agent A2A HTTP server.
2//!
3//! [`Server`] loads an agent's configuration, wires OAuth state and the AI
4//! provider, and builds the axum [`Router`] exposing the agent card and the A2A
5//! request endpoint. It runs the listener with optional graceful shutdown and
6//! supports live configuration reloads.
7
8use axum::routing::{get, post};
9use axum::{Router, middleware};
10use std::pin::Pin;
11use std::sync::Arc;
12use systemprompt_database::DbPool;
13use systemprompt_models::modules::ApiPaths;
14use systemprompt_models::{AgentConfig, AiProvider};
15use tokio::sync::{RwLock, Semaphore};
16use tower_http::cors::CorsLayer;
17use tower_http::services::ServeDir;
18
19use super::auth::{AgentOAuthConfig, AgentOAuthState, agent_oauth_middleware_wrapper};
20use super::handlers::{AgentHandlerState, handle_agent_card, handle_agent_request};
21use crate::state::AgentState;
22
23pub struct Server {
24    db_pool: DbPool,
25    config: Arc<RwLock<AgentConfig>>,
26    oauth_state: Arc<AgentOAuthState>,
27    agent_state: Arc<AgentState>,
28    ai_service: Arc<dyn AiProvider>,
29    stream_semaphore: Arc<Semaphore>,
30    port: u16,
31}
32
33impl std::fmt::Debug for Server {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("Server")
36            .field("db_pool", &"<DbPool>")
37            .field("config", &"Arc<RwLock<AgentConfig>>")
38            .field("oauth_state", &"Arc<AgentOAuthState>")
39            .field("agent_state", &"Arc<AgentState>")
40            .field("ai_service", &"<Arc<dyn AiProvider>>")
41            .field(
42                "stream_semaphore",
43                &self.stream_semaphore.available_permits(),
44            )
45            .field("port", &self.port)
46            .finish()
47    }
48}
49
50impl Server {
51    pub async fn new(
52        db_pool: DbPool,
53        agent_state: Arc<AgentState>,
54        ai_service: Arc<dyn AiProvider>,
55        agent_name: Option<String>,
56        port: u16,
57    ) -> Result<Self, crate::error::AgentError> {
58        use crate::services::registry::AgentRegistry;
59
60        let mut config = if let Some(name) = agent_name {
61            let registry = AgentRegistry::new()
62                .map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
63            registry
64                .get_agent(&name)
65                .await
66                .map_err(|e| crate::error::AgentError::Server(e.to_string()))?
67        } else {
68            return Err(crate::error::AgentError::Validation(
69                "Agent name is required".to_owned(),
70            ));
71        };
72
73        config.extract_oauth_scopes_from_card();
74
75        let oauth_config = AgentOAuthConfig::default();
76        let global_config = systemprompt_models::Config::get()
77            .map_err(|e| crate::error::AgentError::Config(e.to_string()))?;
78        let mut oauth_state = AgentOAuthState::new(
79            Arc::clone(&db_pool),
80            oauth_config,
81            global_config.jwt_issuer.clone(),
82            global_config.jwt_audiences.clone(),
83        );
84
85        oauth_state = oauth_state.with_jwt_provider(Arc::clone(agent_state.jwt_provider()));
86        if let Some(user_provider) = agent_state.user_provider().cloned() {
87            oauth_state = oauth_state.with_user_provider(user_provider);
88        }
89
90        Ok(Self {
91            db_pool,
92            config: Arc::new(RwLock::new(config)),
93            oauth_state: Arc::new(oauth_state),
94            agent_state,
95            ai_service,
96            stream_semaphore: Arc::new(Semaphore::new(global_config.max_concurrent_streams)),
97            port,
98        })
99    }
100
101    pub async fn reload_config(&self) -> Result<(), crate::error::AgentError> {
102        use crate::services::registry::AgentRegistry;
103
104        let agent_name = {
105            let config = self.config.read().await;
106            config.name.clone()
107        };
108
109        let registry =
110            AgentRegistry::new().map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
111        let mut new_config = registry
112            .get_agent(&agent_name)
113            .await
114            .map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
115        new_config.extract_oauth_scopes_from_card();
116        *self.config.write().await = new_config;
117
118        tracing::info!(agent_name = %agent_name, "Configuration reloaded");
119        Ok(())
120    }
121
122    pub fn create_router(&self) -> Router {
123        let state = Arc::new(AgentHandlerState {
124            db_pool: Arc::clone(&self.db_pool),
125            config: Arc::clone(&self.config),
126            oauth_state: Arc::clone(&self.oauth_state),
127            agent_state: Arc::clone(&self.agent_state),
128            ai_service: Arc::clone(&self.ai_service),
129            stream_semaphore: Arc::clone(&self.stream_semaphore),
130        });
131
132        let post_router = Router::new()
133            .route("/", post(handle_agent_request))
134            .with_state(Arc::clone(&state))
135            .layer(middleware::from_fn_with_state(
136                Arc::clone(&state),
137                agent_oauth_middleware_wrapper,
138            ));
139
140        let get_router = Router::new()
141            .route(ApiPaths::WELLKNOWN_AGENT_CARD, get(handle_agent_card))
142            .route(ApiPaths::A2A_CARD, get(handle_agent_card))
143            .with_state(state);
144
145        let api_router = Router::new().merge(post_router).merge(get_router);
146
147        let web_dist_path = std::path::Path::new("web/dist");
148        let router = if web_dist_path.exists() {
149            api_router.fallback_service(ServeDir::new(web_dist_path))
150        } else {
151            api_router
152        };
153
154        router.layer(CorsLayer::permissive())
155    }
156
157    pub async fn run(self) -> Result<(), crate::error::AgentError> {
158        Self::log_server_configuration();
159        self.start_server(None).await
160    }
161
162    pub async fn run_with_shutdown(
163        self,
164        shutdown_signal: impl Future<Output = ()> + Send + 'static,
165    ) -> Result<(), crate::error::AgentError> {
166        Self::log_server_configuration();
167        self.start_server(Some(Box::pin(shutdown_signal))).await
168    }
169
170    const fn log_server_configuration() {}
171
172    async fn start_server(
173        self,
174        shutdown_signal: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
175    ) -> Result<(), crate::error::AgentError> {
176        let app = self.create_router();
177        let addr = format!("0.0.0.0:{}", self.port);
178        let listener = tokio::net::TcpListener::bind(&addr).await?;
179
180        match shutdown_signal {
181            Some(signal) => axum::serve(listener, app)
182                .with_graceful_shutdown(signal)
183                .await
184                .map_err(|e| crate::error::AgentError::Server(e.to_string())),
185            None => axum::serve(listener, app)
186                .await
187                .map_err(|e| crate::error::AgentError::Server(e.to_string())),
188        }
189    }
190}