1pub mod agent_endpoint;
21pub mod chat_endpoint;
22pub mod sse;
23
24pub use agent_endpoint::{agent_router, AgentHandler};
25pub use chat_endpoint::chat_router;
26pub use sse::stream_to_sse;
27
28use std::net::SocketAddr;
29use std::sync::Arc;
30
31use axum::{
32 Router,
33 body::Body,
34 extract::State,
35 http::{Request, StatusCode},
36 middleware::{self, Next},
37 response::Response,
38};
39use tower_http::limit::RequestBodyLimitLayer;
40use wesichain_core::{LlmRequest, LlmResponse, Runnable, WesichainError};
41
/// Request body size limit (in bytes) used until `with_body_limit` overrides it: 4 MiB.
const DEFAULT_BODY_LIMIT: usize = 4 << 20;
44
45async fn bearer_auth_middleware(
54 State(expected): State<Arc<String>>,
55 req: Request<Body>,
56 next: Next,
57) -> Result<Response, StatusCode> {
58 let provided = req
59 .headers()
60 .get(axum::http::header::AUTHORIZATION)
61 .and_then(|v| v.to_str().ok())
62 .unwrap_or("");
63
64 if provided == format!("Bearer {expected}") {
65 Ok(next.run(req).await)
66 } else {
67 Err(StatusCode::UNAUTHORIZED)
68 }
69}
70
71pub struct WesichainServer {
77 router: Router,
78 auth_token: Option<String>,
79 body_limit: usize,
80}
81
82impl Default for WesichainServer {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl WesichainServer {
89 pub fn new() -> Self {
90 Self {
91 router: Router::new(),
92 auth_token: None,
93 body_limit: DEFAULT_BODY_LIMIT,
94 }
95 }
96
97 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
101 self.auth_token = Some(token.into());
102 self
103 }
104
105 pub fn with_body_limit(mut self, bytes: usize) -> Self {
107 self.body_limit = bytes;
108 self
109 }
110
111 pub fn with_chat<L>(mut self, llm: L, default_model: impl Into<String>) -> Self
113 where
114 L: Runnable<LlmRequest, LlmResponse> + Clone + Send + Sync + 'static,
115 {
116 self.router = self.router.merge(chat_router(llm, default_model));
117 self
118 }
119
120 pub fn with_agent_stream(mut self, handler: AgentHandler) -> Self {
122 self.router = self.router.merge(agent_router(handler));
123 self
124 }
125
126 pub fn with_router(mut self, router: Router) -> Self {
128 self.router = self.router.merge(router);
129 self
130 }
131
132 pub fn build(self) -> Router {
138 let mut router = self.router;
139
140 if let Some(token) = self.auth_token {
143 let token_state = Arc::new(token);
144 router = router.route_layer(
145 middleware::from_fn_with_state(token_state, bearer_auth_middleware),
146 );
147 }
148
149 router.layer(RequestBodyLimitLayer::new(self.body_limit))
151 }
152
153 pub async fn serve(self, addr: SocketAddr) -> Result<(), WesichainError> {
155 let app = self.build();
156 let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
157 WesichainError::InvalidConfig(format!("Failed to bind {addr}: {e}"))
158 })?;
159
160 axum::serve(listener, app)
161 .await
162 .map_err(|e| WesichainError::Custom(format!("Server error: {e}")))
163 }
164}