winx_code_agent/
http_server.rs1#![allow(clippy::doc_markdown)]
19
20use std::sync::Arc;
21
22use axum::{
23 extract::{Request, State},
24 http::StatusCode,
25 middleware::{self, Next},
26 response::{IntoResponse, Response},
27 Router,
28};
29use rmcp::transport::streamable_http_server::{
30 session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
31};
32
33use crate::server::WinxService;
34
35type BoxError = Box<dyn std::error::Error + Send + Sync>;
36
37pub async fn start_http_server(
45 bind: &str,
46 token: String,
47 extra_hosts: Vec<String>,
48) -> Result<(), BoxError> {
49 if token.trim().is_empty() {
50 return Err("refusing to start HTTP transport without a token (RCE exposure)".into());
51 }
52
53 let mut config = StreamableHttpServerConfig::default();
55 config.stateful_mode = true;
56 config.allowed_hosts.extend(extra_hosts);
57
58 let shared = WinxService::new();
65 let mcp_service = StreamableHttpService::new(
66 move || Ok(shared.clone()),
67 Arc::new(LocalSessionManager::default()),
68 config,
69 );
70
71 let app = Router::new()
72 .nest_service("/mcp", mcp_service)
73 .layer(middleware::from_fn_with_state(Arc::new(token), require_token));
74
75 let listener = tokio::net::TcpListener::bind(bind).await?;
76 tracing::warn!(
77 "winx remote MCP transport on http://{bind}/mcp — shell/file access is now \
78 network-reachable. Keep it behind an HTTPS tunnel and shut it down when done."
79 );
80 axum::serve(listener, app).await?;
81 Ok(())
82}
83
84async fn require_token(State(token): State<Arc<String>>, request: Request, next: Next) -> Response {
86 if request_has_token(&request, &token) {
87 next.run(request).await
88 } else {
89 (StatusCode::UNAUTHORIZED, "missing or invalid token\n").into_response()
90 }
91}
92
93fn request_has_token(request: &Request, expected: &str) -> bool {
95 let header_match = request
96 .headers()
97 .get(axum::http::header::AUTHORIZATION)
98 .and_then(|value| value.to_str().ok())
99 .and_then(|value| value.strip_prefix("Bearer "))
100 .is_some_and(|presented| constant_time_eq(presented.trim(), expected));
101
102 let query_match = request.uri().query().is_some_and(|query| {
103 query
104 .split('&')
105 .filter_map(|pair| pair.split_once('='))
106 .any(|(key, value)| key == "token" && constant_time_eq(value, expected))
107 });
108
109 header_match || query_match
110}
111
/// Compares two strings byte-for-byte without stopping at the first difference.
///
/// Returns `true` only when both strings have the same length and identical bytes.
/// The number of loop iterations depends only on `b.len()` (the expected secret at
/// this file's call sites). A length mismatch is folded into the result instead of
/// triggering an early return, so a wrong-length guess costs the same loop as a
/// right-length one and the secret's length is not exposed by a fast rejection.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // Non-zero if the lengths differ; accumulated rather than returned immediately.
    let mut diff = usize::from(a.len() != b.len());
    for (index, &expected) in b.iter().enumerate() {
        // Positions past the end of a shorter `a` are compared as 0. The length flag
        // above already forces `false` in that case.
        let presented = a.get(index).copied().unwrap_or(0);
        diff |= usize::from(presented ^ expected);
    }
    diff == 0
}
120
#[cfg(test)]
mod tests {
    use super::constant_time_eq;

    /// Identical strings match; a single differing byte or an extra suffix does not.
    #[test]
    fn token_comparison() {
        assert!(constant_time_eq("s3cret", "s3cret"));
        assert!(!constant_time_eq("s3cret", "s3creT"));
        assert!(!constant_time_eq("s3cret", "s3cret-longer"));
        assert!(!constant_time_eq("", "x"));
    }

    /// Length mismatches are rejected whichever argument is longer, and empty equals empty.
    #[test]
    fn token_comparison_lengths() {
        assert!(constant_time_eq("", ""));
        assert!(!constant_time_eq("x", ""));
        assert!(!constant_time_eq("s3cret-longer", "s3cret"));
        assert!(!constant_time_eq("s3cre", "s3cret"));
        // Trailing NUL bytes are significant in either position.
        assert!(!constant_time_eq("ab", "ab\0"));
        assert!(!constant_time_eq("ab\0", "ab"));
    }
}