Skip to main content

systemprompt_agent/services/a2a_server/
server.rs

1use axum::routing::{get, post};
2use axum::{Router, middleware};
3use std::pin::Pin;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_models::modules::ApiPaths;
7use systemprompt_models::{AgentConfig, AiProvider};
8use tokio::sync::{RwLock, Semaphore};
9use tower_http::cors::CorsLayer;
10use tower_http::services::ServeDir;
11
12use super::auth::{AgentOAuthConfig, AgentOAuthState, agent_oauth_middleware_wrapper};
13use super::handlers::{AgentHandlerState, handle_agent_card, handle_agent_request};
14use crate::state::AgentState;
15
16pub struct Server {
17    db_pool: DbPool,
18    config: Arc<RwLock<AgentConfig>>,
19    oauth_state: Arc<AgentOAuthState>,
20    agent_state: Arc<AgentState>,
21    ai_service: Arc<dyn AiProvider>,
22    stream_semaphore: Arc<Semaphore>,
23    port: u16,
24}
25
26impl std::fmt::Debug for Server {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("Server")
29            .field("db_pool", &"<DbPool>")
30            .field("config", &"Arc<RwLock<AgentConfig>>")
31            .field("oauth_state", &"Arc<AgentOAuthState>")
32            .field("agent_state", &"Arc<AgentState>")
33            .field("ai_service", &"<Arc<dyn AiProvider>>")
34            .field(
35                "stream_semaphore",
36                &self.stream_semaphore.available_permits(),
37            )
38            .field("port", &self.port)
39            .finish()
40    }
41}
42
43impl Server {
44    pub async fn new(
45        db_pool: DbPool,
46        agent_state: Arc<AgentState>,
47        ai_service: Arc<dyn AiProvider>,
48        agent_name: Option<String>,
49        port: u16,
50    ) -> Result<Self, crate::error::AgentError> {
51        use crate::services::registry::AgentRegistry;
52
53        let mut config = if let Some(name) = agent_name {
54            let registry = AgentRegistry::new()
55                .map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
56            registry
57                .get_agent(&name)
58                .await
59                .map_err(|e| crate::error::AgentError::Server(e.to_string()))?
60        } else {
61            return Err(crate::error::AgentError::Validation(
62                "Agent name is required".to_owned(),
63            ));
64        };
65
66        config.extract_oauth_scopes_from_card();
67
68        let oauth_config = AgentOAuthConfig::default();
69        let global_config = systemprompt_models::Config::get()
70            .map_err(|e| crate::error::AgentError::Config(e.to_string()))?;
71        let mut oauth_state = AgentOAuthState::new(
72            Arc::clone(&db_pool),
73            oauth_config,
74            global_config.jwt_issuer.clone(),
75            global_config.jwt_audiences.clone(),
76        );
77
78        oauth_state = oauth_state.with_jwt_provider(Arc::clone(agent_state.jwt_provider()));
79        if let Some(user_provider) = agent_state.user_provider().cloned() {
80            oauth_state = oauth_state.with_user_provider(user_provider);
81        }
82
83        Ok(Self {
84            db_pool,
85            config: Arc::new(RwLock::new(config)),
86            oauth_state: Arc::new(oauth_state),
87            agent_state,
88            ai_service,
89            stream_semaphore: Arc::new(Semaphore::new(global_config.max_concurrent_streams)),
90            port,
91        })
92    }
93
94    pub async fn reload_config(&self) -> Result<(), crate::error::AgentError> {
95        use crate::services::registry::AgentRegistry;
96
97        let agent_name = {
98            let config = self.config.read().await;
99            config.name.clone()
100        };
101
102        let registry =
103            AgentRegistry::new().map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
104        let mut new_config = registry
105            .get_agent(&agent_name)
106            .await
107            .map_err(|e| crate::error::AgentError::Server(e.to_string()))?;
108        new_config.extract_oauth_scopes_from_card();
109        *self.config.write().await = new_config;
110
111        tracing::info!(agent_name = %agent_name, "Configuration reloaded");
112        Ok(())
113    }
114
115    pub fn create_router(&self) -> Router {
116        let state = Arc::new(AgentHandlerState {
117            db_pool: Arc::clone(&self.db_pool),
118            config: Arc::clone(&self.config),
119            oauth_state: Arc::clone(&self.oauth_state),
120            agent_state: Arc::clone(&self.agent_state),
121            ai_service: Arc::clone(&self.ai_service),
122            stream_semaphore: Arc::clone(&self.stream_semaphore),
123        });
124
125        let post_router = Router::new()
126            .route("/", post(handle_agent_request))
127            .with_state(Arc::clone(&state))
128            .layer(middleware::from_fn_with_state(
129                Arc::clone(&state),
130                agent_oauth_middleware_wrapper,
131            ));
132
133        let get_router = Router::new()
134            .route(ApiPaths::WELLKNOWN_AGENT_CARD, get(handle_agent_card))
135            .route(ApiPaths::A2A_CARD, get(handle_agent_card))
136            .with_state(state);
137
138        let api_router = Router::new().merge(post_router).merge(get_router);
139
140        let web_dist_path = std::path::Path::new("web/dist");
141        let router = if web_dist_path.exists() {
142            api_router.fallback_service(ServeDir::new(web_dist_path))
143        } else {
144            api_router
145        };
146
147        router.layer(CorsLayer::permissive())
148    }
149
150    pub async fn run(self) -> Result<(), crate::error::AgentError> {
151        Self::log_server_configuration();
152        self.start_server(None).await
153    }
154
155    pub async fn run_with_shutdown(
156        self,
157        shutdown_signal: impl Future<Output = ()> + Send + 'static,
158    ) -> Result<(), crate::error::AgentError> {
159        Self::log_server_configuration();
160        self.start_server(Some(Box::pin(shutdown_signal))).await
161    }
162
163    const fn log_server_configuration() {}
164
165    async fn start_server(
166        self,
167        shutdown_signal: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
168    ) -> Result<(), crate::error::AgentError> {
169        let app = self.create_router();
170        let addr = format!("0.0.0.0:{}", self.port);
171        let listener = tokio::net::TcpListener::bind(&addr).await?;
172
173        match shutdown_signal {
174            Some(signal) => axum::serve(listener, app)
175                .with_graceful_shutdown(signal)
176                .await
177                .map_err(|e| crate::error::AgentError::Server(e.to_string())),
178            None => axum::serve(listener, app)
179                .await
180                .map_err(|e| crate::error::AgentError::Server(e.to_string())),
181        }
182    }
183}