Skip to main content

wesichain_server/
lib.rs

1//! HTTP server integration for wesichain.
2//!
3//! Provides a one-liner server builder that exposes LLMs and agents over HTTP.
4//!
5//! # Quick start
6//!
//! ```ignore
//! use wesichain_server::WesichainServer;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     WesichainServer::new()
//!         .with_chat(my_llm, "claude-3-5-sonnet-20241022")
//!         .with_auth_token("my-secret-token")
//!         .serve("0.0.0.0:3000".parse()?)
//!         .await?;
//!     Ok(())
//! }
//! ```
19
20pub mod agent_endpoint;
21pub mod chat_endpoint;
22pub mod sse;
23
24pub use agent_endpoint::{agent_router, AgentHandler};
25pub use chat_endpoint::chat_router;
26pub use sse::stream_to_sse;
27
28use std::net::SocketAddr;
29use std::sync::Arc;
30
31use axum::{
32    Router,
33    body::Body,
34    extract::State,
35    http::{Request, StatusCode},
36    middleware::{self, Next},
37    response::Response,
38};
39use tower_http::limit::RequestBodyLimitLayer;
40use wesichain_core::{LlmRequest, LlmResponse, Runnable, WesichainError};
41
/// Request-body size cap used when the caller does not override it (4 MiB).
const DEFAULT_BODY_LIMIT: usize = 4 << 20;
44
45// ---------------------------------------------------------------------------
46// Bearer auth middleware
47// ---------------------------------------------------------------------------
48
49/// Axum middleware that enforces a static Bearer token.
50///
51/// Responds with `401 Unauthorized` when the `Authorization` header is absent
52/// or does not match `Bearer <token>`.
53async fn bearer_auth_middleware(
54    State(expected): State<Arc<String>>,
55    req: Request<Body>,
56    next: Next,
57) -> Result<Response, StatusCode> {
58    let provided = req
59        .headers()
60        .get(axum::http::header::AUTHORIZATION)
61        .and_then(|v| v.to_str().ok())
62        .unwrap_or("");
63
64    if provided == format!("Bearer {expected}") {
65        Ok(next.run(req).await)
66    } else {
67        Err(StatusCode::UNAUTHORIZED)
68    }
69}
70
71// ---------------------------------------------------------------------------
72// Server builder
73// ---------------------------------------------------------------------------
74
75/// One-liner builder for a wesichain HTTP server.
76pub struct WesichainServer {
77    router: Router,
78    auth_token: Option<String>,
79    body_limit: usize,
80}
81
82impl Default for WesichainServer {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl WesichainServer {
89    pub fn new() -> Self {
90        Self {
91            router: Router::new(),
92            auth_token: None,
93            body_limit: DEFAULT_BODY_LIMIT,
94        }
95    }
96
97    /// Require a static Bearer token on every request.
98    ///
99    /// Requests without `Authorization: Bearer <token>` will receive `401`.
100    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
101        self.auth_token = Some(token.into());
102        self
103    }
104
105    /// Override the maximum request body size (default: 4 MiB).
106    pub fn with_body_limit(mut self, bytes: usize) -> Self {
107        self.body_limit = bytes;
108        self
109    }
110
111    /// Mount `POST /v1/chat/completions` backed by the given LLM.
112    pub fn with_chat<L>(mut self, llm: L, default_model: impl Into<String>) -> Self
113    where
114        L: Runnable<LlmRequest, LlmResponse> + Clone + Send + Sync + 'static,
115    {
116        self.router = self.router.merge(chat_router(llm, default_model));
117        self
118    }
119
120    /// Mount `POST /agent/chat` that streams agent events as SSE.
121    pub fn with_agent_stream(mut self, handler: AgentHandler) -> Self {
122        self.router = self.router.merge(agent_router(handler));
123        self
124    }
125
126    /// Merge any additional Axum router.
127    pub fn with_router(mut self, router: Router) -> Self {
128        self.router = self.router.merge(router);
129        self
130    }
131
132    /// Build the finalised Axum [`Router`] with all middleware applied.
133    ///
134    /// Middleware stack (outermost first):
135    /// 1. `RequestBodyLimitLayer` — rejects bodies larger than `body_limit`
136    /// 2. Bearer token auth layer (optional) — 401 if token is missing/wrong
137    pub fn build(self) -> Router {
138        let mut router = self.router;
139
140        // Bearer token auth (innermost — applied before body limit to reject
141        // unauthenticated requests before reading the body).
142        if let Some(token) = self.auth_token {
143            let token_state = Arc::new(token);
144            router = router.route_layer(
145                middleware::from_fn_with_state(token_state, bearer_auth_middleware),
146            );
147        }
148
149        // Body size cap (outermost).
150        router.layer(RequestBodyLimitLayer::new(self.body_limit))
151    }
152
153    /// Start the server and block until it is stopped.
154    pub async fn serve(self, addr: SocketAddr) -> Result<(), WesichainError> {
155        let app = self.build();
156        let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
157            WesichainError::InvalidConfig(format!("Failed to bind {addr}: {e}"))
158        })?;
159
160        axum::serve(listener, app)
161            .await
162            .map_err(|e| WesichainError::Custom(format!("Server error: {e}")))
163    }
164}