Skip to main content

systemprompt_agent/services/a2a_server/
server.rs

1use axum::routing::{get, post};
2use axum::{middleware, Router};
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_models::modules::ApiPaths;
6use systemprompt_models::{AgentConfig, AiProvider};
7use tokio::sync::RwLock;
8use tower_http::cors::CorsLayer;
9use tower_http::services::ServeDir;
10
11use super::auth::{agent_oauth_middleware_wrapper, AgentOAuthConfig, AgentOAuthState};
12use super::handlers::{handle_agent_card, handle_agent_request, AgentHandlerState};
13use crate::state::AgentState;
14
15pub struct Server {
16    db_pool: DbPool,
17    config: Arc<RwLock<AgentConfig>>,
18    oauth_state: Arc<AgentOAuthState>,
19    agent_state: Arc<AgentState>,
20    ai_service: Arc<dyn AiProvider>,
21    port: u16,
22}
23
24impl std::fmt::Debug for Server {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("Server")
27            .field("pool", &"SqlitePool")
28            .field("config", &"Arc<RwLock<AgentConfig>>")
29            .field("oauth_state", &"Arc<AgentOAuthState>")
30            .field("agent_state", &"Arc<AgentState>")
31            .field("ai_service", &"<Arc<dyn AiProvider>>")
32            .field("port", &self.port)
33            .finish()
34    }
35}
36
37impl Server {
38    pub async fn new(
39        db_pool: DbPool,
40        agent_state: Arc<AgentState>,
41        ai_service: Arc<dyn AiProvider>,
42        agent_name: Option<String>,
43        port: u16,
44    ) -> anyhow::Result<Self> {
45        use crate::services::registry::AgentRegistry;
46
47        let mut config = if let Some(name) = agent_name {
48            let registry = AgentRegistry::new().await?;
49            registry.get_agent(&name).await?
50        } else {
51            return Err(anyhow::anyhow!("Agent name is required"));
52        };
53
54        config.extract_oauth_scopes_from_card();
55
56        let oauth_config = AgentOAuthConfig::default();
57        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
58        let global_config = systemprompt_models::Config::get()?;
59        let mut oauth_state = AgentOAuthState::new(
60            db_pool.clone(),
61            oauth_config,
62            jwt_secret,
63            global_config.jwt_issuer.clone(),
64            global_config.jwt_audiences.clone(),
65        )
66        .await?;
67
68        oauth_state = oauth_state.with_jwt_provider(Arc::clone(agent_state.jwt_provider()));
69        if let Some(user_provider) = agent_state.user_provider().cloned() {
70            oauth_state = oauth_state.with_user_provider(user_provider);
71        }
72
73        Ok(Self {
74            db_pool,
75            config: Arc::new(RwLock::new(config)),
76            oauth_state: Arc::new(oauth_state),
77            agent_state,
78            ai_service,
79            port,
80        })
81    }
82
83    pub async fn reload_config(&self) -> anyhow::Result<()> {
84        use crate::services::registry::AgentRegistry;
85
86        let agent_name = {
87            let config = self.config.read().await;
88            config.name.clone()
89        };
90
91        let registry = AgentRegistry::new().await?;
92        let mut new_config = registry.get_agent(&agent_name).await?;
93        new_config.extract_oauth_scopes_from_card();
94        *self.config.write().await = new_config;
95
96        tracing::info!(agent_name = %agent_name, "Configuration reloaded");
97        Ok(())
98    }
99
100    pub fn create_router(&self) -> Router {
101        let state = Arc::new(AgentHandlerState {
102            db_pool: self.db_pool.clone(),
103            config: Arc::clone(&self.config),
104            oauth_state: Arc::clone(&self.oauth_state),
105            agent_state: Arc::clone(&self.agent_state),
106            ai_service: Arc::clone(&self.ai_service),
107        });
108
109        let post_router = Router::new()
110            .route("/", post(handle_agent_request))
111            .with_state(state.clone())
112            .layer(middleware::from_fn_with_state(
113                state.clone(),
114                agent_oauth_middleware_wrapper,
115            ));
116
117        let get_router = Router::new()
118            .route(ApiPaths::WELLKNOWN_AGENT_CARD, get(handle_agent_card))
119            .route(ApiPaths::A2A_CARD, get(handle_agent_card))
120            .with_state(state);
121
122        let api_router = Router::new().merge(post_router).merge(get_router);
123
124        let web_dist_path = std::path::Path::new("web/dist");
125        let router = if web_dist_path.exists() {
126            api_router.fallback_service(ServeDir::new(web_dist_path))
127        } else {
128            api_router
129        };
130
131        router.layer(CorsLayer::permissive())
132    }
133
134    pub async fn run(self) -> anyhow::Result<()> {
135        self.log_server_configuration().await;
136        self.start_server(None).await
137    }
138
139    pub async fn run_with_shutdown(
140        self,
141        shutdown_signal: impl std::future::Future<Output = ()> + Send + 'static,
142    ) -> anyhow::Result<()> {
143        self.log_server_configuration().await;
144        self.start_server(Some(Box::pin(shutdown_signal))).await
145    }
146
147    async fn log_server_configuration(&self) {}
148
149    async fn start_server(
150        self,
151        shutdown_signal: Option<std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>>,
152    ) -> anyhow::Result<()> {
153        let app = self.create_router();
154        let addr = format!("0.0.0.0:{}", self.port);
155        let listener = tokio::net::TcpListener::bind(&addr).await?;
156
157        match shutdown_signal {
158            Some(signal) => axum::serve(listener, app)
159                .with_graceful_shutdown(signal)
160                .await
161                .map_err(|e| anyhow::anyhow!("Server error: {}", e)),
162            None => axum::serve(listener, app)
163                .await
164                .map_err(|e| anyhow::anyhow!("Server error: {}", e)),
165        }
166    }
167}