systemprompt_agent/services/a2a_server/
server.rs1use axum::routing::{get, post};
2use axum::{Router, middleware};
3use std::pin::Pin;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_models::modules::ApiPaths;
7use systemprompt_models::{AgentConfig, AiProvider};
8use tokio::sync::RwLock;
9use tower_http::cors::CorsLayer;
10use tower_http::services::ServeDir;
11
12use super::auth::{AgentOAuthConfig, AgentOAuthState, agent_oauth_middleware_wrapper};
13use super::handlers::{AgentHandlerState, handle_agent_card, handle_agent_request};
14use crate::state::AgentState;
15
16pub struct Server {
17 db_pool: DbPool,
18 config: Arc<RwLock<AgentConfig>>,
19 oauth_state: Arc<AgentOAuthState>,
20 agent_state: Arc<AgentState>,
21 ai_service: Arc<dyn AiProvider>,
22 port: u16,
23}
24
25impl std::fmt::Debug for Server {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("Server")
28 .field("db_pool", &"<DbPool>")
29 .field("config", &"Arc<RwLock<AgentConfig>>")
30 .field("oauth_state", &"Arc<AgentOAuthState>")
31 .field("agent_state", &"Arc<AgentState>")
32 .field("ai_service", &"<Arc<dyn AiProvider>>")
33 .field("port", &self.port)
34 .finish()
35 }
36}
37
38impl Server {
39 pub async fn new(
40 db_pool: DbPool,
41 agent_state: Arc<AgentState>,
42 ai_service: Arc<dyn AiProvider>,
43 agent_name: Option<String>,
44 port: u16,
45 ) -> anyhow::Result<Self> {
46 use crate::services::registry::AgentRegistry;
47
48 let mut config = if let Some(name) = agent_name {
49 let registry = AgentRegistry::new()?;
50 registry.get_agent(&name).await?
51 } else {
52 return Err(anyhow::anyhow!("Agent name is required"));
53 };
54
55 config.extract_oauth_scopes_from_card();
56
57 let oauth_config = AgentOAuthConfig::default();
58 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
59 let global_config = systemprompt_models::Config::get()?;
60 let mut oauth_state = AgentOAuthState::new(
61 Arc::clone(&db_pool),
62 oauth_config,
63 jwt_secret,
64 global_config.jwt_issuer.clone(),
65 global_config.jwt_audiences.clone(),
66 );
67
68 oauth_state = oauth_state.with_jwt_provider(Arc::clone(agent_state.jwt_provider()));
69 if let Some(user_provider) = agent_state.user_provider().cloned() {
70 oauth_state = oauth_state.with_user_provider(user_provider);
71 }
72
73 Ok(Self {
74 db_pool,
75 config: Arc::new(RwLock::new(config)),
76 oauth_state: Arc::new(oauth_state),
77 agent_state,
78 ai_service,
79 port,
80 })
81 }
82
83 pub async fn reload_config(&self) -> anyhow::Result<()> {
84 use crate::services::registry::AgentRegistry;
85
86 let agent_name = {
87 let config = self.config.read().await;
88 config.name.clone()
89 };
90
91 let registry = AgentRegistry::new()?;
92 let mut new_config = registry.get_agent(&agent_name).await?;
93 new_config.extract_oauth_scopes_from_card();
94 *self.config.write().await = new_config;
95
96 tracing::info!(agent_name = %agent_name, "Configuration reloaded");
97 Ok(())
98 }
99
100 pub fn create_router(&self) -> Router {
101 let state = Arc::new(AgentHandlerState {
102 db_pool: Arc::clone(&self.db_pool),
103 config: Arc::clone(&self.config),
104 oauth_state: Arc::clone(&self.oauth_state),
105 agent_state: Arc::clone(&self.agent_state),
106 ai_service: Arc::clone(&self.ai_service),
107 });
108
109 let post_router = Router::new()
110 .route("/", post(handle_agent_request))
111 .with_state(Arc::clone(&state))
112 .layer(middleware::from_fn_with_state(
113 Arc::clone(&state),
114 agent_oauth_middleware_wrapper,
115 ));
116
117 let get_router = Router::new()
118 .route(ApiPaths::WELLKNOWN_AGENT_CARD, get(handle_agent_card))
119 .route(ApiPaths::A2A_CARD, get(handle_agent_card))
120 .with_state(state);
121
122 let api_router = Router::new().merge(post_router).merge(get_router);
123
124 let web_dist_path = std::path::Path::new("web/dist");
125 let router = if web_dist_path.exists() {
126 api_router.fallback_service(ServeDir::new(web_dist_path))
127 } else {
128 api_router
129 };
130
131 router.layer(CorsLayer::permissive())
132 }
133
134 pub async fn run(self) -> anyhow::Result<()> {
135 Self::log_server_configuration();
136 self.start_server(None).await
137 }
138
139 pub async fn run_with_shutdown(
140 self,
141 shutdown_signal: impl Future<Output = ()> + Send + 'static,
142 ) -> anyhow::Result<()> {
143 Self::log_server_configuration();
144 self.start_server(Some(Box::pin(shutdown_signal))).await
145 }
146
147 const fn log_server_configuration() {}
148
149 async fn start_server(
150 self,
151 shutdown_signal: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
152 ) -> anyhow::Result<()> {
153 let app = self.create_router();
154 let addr = format!("0.0.0.0:{}", self.port);
155 let listener = tokio::net::TcpListener::bind(&addr).await?;
156
157 match shutdown_signal {
158 Some(signal) => axum::serve(listener, app)
159 .with_graceful_shutdown(signal)
160 .await
161 .map_err(|e| anyhow::anyhow!("Server error: {}", e)),
162 None => axum::serve(listener, app)
163 .await
164 .map_err(|e| anyhow::anyhow!("Server error: {}", e)),
165 }
166 }
167}