1use std::collections::HashMap;
7
8use bytes::Bytes;
9use http_body_util::{BodyExt, Full};
10use hyper::Response as HyperResponse;
11use lambda_http::{Body as LambdaBody, Request as LambdaRequest, Response as LambdaResponse};
12use tracing::{debug, trace};
13
14use crate::error::{LambdaError, Result};
15
/// Type-erased response body accepted by the hyper → Lambda response converters in this
/// module: a boxed `UnsyncBoxBody` (no `Sync` bound) yielding `Bytes` chunks and reporting
/// failures as `hyper::Error`.
type UnifiedMcpBody = http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>;
18
19fn infallible_to_hyper_error(never: std::convert::Infallible) -> hyper::Error {
21 match never {}
22}
23
/// Request body type produced by `lambda_to_hyper_request`: the buffered Lambda payload in a
/// `Full<Bytes>`, with its `Infallible` error mapped to `hyper::Error`.
///
/// The mapper is a plain function pointer (not a closure) so that this type can be named.
type MappedFullBody =
    http_body_util::combinators::MapErr<Full<Bytes>, fn(std::convert::Infallible) -> hyper::Error>;
27
28pub fn lambda_to_hyper_request(
33 lambda_req: LambdaRequest,
34) -> Result<hyper::Request<MappedFullBody>> {
35 let (parts, lambda_body) = lambda_req.into_parts();
36
37 let body_bytes = match lambda_body {
39 LambdaBody::Empty => Bytes::new(),
40 LambdaBody::Text(s) => Bytes::from(s),
41 LambdaBody::Binary(b) => Bytes::from(b),
42 };
43
44 let full_body = Full::new(body_bytes)
46 .map_err(infallible_to_hyper_error as fn(std::convert::Infallible) -> hyper::Error);
47
48 let hyper_req = hyper::Request::from_parts(parts, full_body);
50
51 debug!(
52 "Converted Lambda request: {} {} -> hyper::Request<Full<Bytes>>",
53 hyper_req.method(),
54 hyper_req.uri()
55 );
56
57 Ok(hyper_req)
58}
59
60pub async fn hyper_to_lambda_response(
65 hyper_resp: HyperResponse<UnifiedMcpBody>,
66) -> Result<LambdaResponse<LambdaBody>> {
67 let (parts, body) = hyper_resp.into_parts();
68
69 let body_bytes = match body.collect().await {
71 Ok(collected) => collected.to_bytes(),
72 Err(err) => {
73 return Err(LambdaError::Body(format!(
74 "Failed to collect response body: {}",
75 err
76 )));
77 }
78 };
79
80 let lambda_body = if body_bytes.is_empty() {
82 LambdaBody::Empty
83 } else {
84 match String::from_utf8(body_bytes.to_vec()) {
86 Ok(text) => LambdaBody::Text(text),
87 Err(_) => LambdaBody::Binary(body_bytes.to_vec()),
88 }
89 };
90
91 let lambda_resp = LambdaResponse::from_parts(parts, lambda_body);
93
94 debug!(
95 "Converted hyper response -> Lambda response (status: {})",
96 lambda_resp.status()
97 );
98
99 Ok(lambda_resp)
100}
101
102pub fn hyper_to_lambda_streaming(
107 hyper_resp: HyperResponse<UnifiedMcpBody>,
108) -> lambda_http::Response<UnifiedMcpBody> {
109 let (parts, body) = hyper_resp.into_parts();
110
111 let lambda_resp = lambda_http::Response::from_parts(parts, body);
113
114 debug!(
115 "Converted hyper response -> Lambda streaming response (status: {})",
116 lambda_resp.status()
117 );
118
119 lambda_resp
120}
121
122pub fn extract_mcp_headers(req: &LambdaRequest) -> HashMap<String, String> {
127 let mut mcp_headers = HashMap::new();
128
129 if let Some(session_id) = req.headers().get("mcp-session-id")
131 && let Ok(session_id_str) = session_id.to_str()
132 {
133 mcp_headers.insert("mcp-session-id".to_string(), session_id_str.to_string());
134 }
135
136 if let Some(protocol_version) = req.headers().get("mcp-protocol-version")
138 && let Ok(version_str) = protocol_version.to_str()
139 {
140 mcp_headers.insert("mcp-protocol-version".to_string(), version_str.to_string());
141 }
142
143 if let Some(last_event_id) = req.headers().get("last-event-id")
145 && let Ok(event_id_str) = last_event_id.to_str()
146 {
147 mcp_headers.insert("last-event-id".to_string(), event_id_str.to_string());
148 }
149
150 trace!("Extracted MCP headers: {:?}", mcp_headers);
151 mcp_headers
152}
153
154pub fn inject_mcp_headers(resp: &mut LambdaResponse<LambdaBody>, headers: HashMap<String, String>) {
158 for (name, value) in headers {
159 if let (Ok(header_name), Ok(header_value)) = (
160 http::HeaderName::from_bytes(name.as_bytes()),
161 http::HeaderValue::from_str(&value),
162 ) {
163 resp.headers_mut().insert(header_name, header_value);
164 debug!("Injected MCP header: {} = {}", name, value);
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderValue, Method, Request, StatusCode};
    use http_body_util::Full;

    /// Drains a converted request body into a single `Bytes` buffer.
    async fn request_body_bytes(req: hyper::Request<MappedFullBody>) -> Bytes {
        req.into_body().collect().await.unwrap().to_bytes()
    }

    /// Wraps a payload in the boxed body type the response converters accept.
    fn unified_body(data: impl Into<Bytes>) -> UnifiedMcpBody {
        Full::new(data.into())
            .map_err(|never| match never {})
            .boxed_unsync()
    }

    #[tokio::test]
    async fn test_lambda_to_hyper_request_conversion() {
        let json = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let mut lambda_req = Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .body(LambdaBody::Text(json.to_string()))
            .unwrap();

        let headers = lambda_req.headers_mut();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert(
            "mcp-session-id",
            HeaderValue::from_static("test-session-123"),
        );
        headers.insert(
            "mcp-protocol-version",
            HeaderValue::from_static("2025-06-18"),
        );

        let hyper_req = lambda_to_hyper_request(lambda_req).unwrap();

        assert_eq!(hyper_req.method(), &Method::POST);
        assert_eq!(hyper_req.uri().path(), "/mcp");
        assert_eq!(
            hyper_req.headers().get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(
            hyper_req.headers().get("mcp-session-id").unwrap(),
            "test-session-123"
        );
        assert_eq!(
            hyper_req.headers().get("mcp-protocol-version").unwrap(),
            "2025-06-18"
        );

        let body = request_body_bytes(hyper_req).await;
        assert_eq!(&body[..], json.as_bytes());
    }

    #[tokio::test]
    async fn test_lambda_to_hyper_empty_body() {
        let lambda_req = Request::builder()
            .method(Method::GET)
            .uri("/sse")
            .body(LambdaBody::Empty)
            .unwrap();

        let hyper_req = lambda_to_hyper_request(lambda_req).unwrap();
        assert_eq!(hyper_req.method(), &Method::GET);
        assert_eq!(hyper_req.uri().path(), "/sse");

        let body = request_body_bytes(hyper_req).await;
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn test_lambda_to_hyper_binary_body() {
        let test_data: Vec<u8> = vec![0x48, 0x65, 0x6c, 0x6c, 0x6f];
        let lambda_req = Request::builder()
            .method(Method::POST)
            .uri("/binary")
            .body(LambdaBody::Binary(test_data.clone()))
            .unwrap();

        let hyper_req = lambda_to_hyper_request(lambda_req).unwrap();
        assert_eq!(hyper_req.method(), &Method::POST);
        assert_eq!(hyper_req.uri().path(), "/binary");

        let body = request_body_bytes(hyper_req).await;
        assert_eq!(&body[..], &test_data[..]);
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_response_conversion() {
        let json_body = r#"{"jsonrpc":"2.0","id":1,"result":{"capabilities":{}}}"#;

        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("mcp-session-id", "resp-session-456")
            .body(unified_body(json_body))
            .unwrap();

        let lambda_resp = hyper_to_lambda_response(hyper_resp).await.unwrap();

        assert_eq!(lambda_resp.status(), StatusCode::OK);
        assert_eq!(
            lambda_resp.headers().get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(
            lambda_resp.headers().get("mcp-session-id").unwrap(),
            "resp-session-456"
        );

        match lambda_resp.body() {
            LambdaBody::Text(text) => assert_eq!(text, json_body),
            _ => panic!("Expected text body"),
        }
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_binary_response() {
        // 0xff can never appear in UTF-8, so this must come back as a binary body.
        let raw: Vec<u8> = vec![0xff, 0xfe, 0x00, 0x80];

        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/octet-stream")
            .body(unified_body(raw.clone()))
            .unwrap();

        let lambda_resp = hyper_to_lambda_response(hyper_resp).await.unwrap();

        match lambda_resp.body() {
            LambdaBody::Binary(data) => assert_eq!(data, &raw),
            _ => panic!("Expected binary body"),
        }
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_empty_response() {
        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(unified_body(Bytes::new()))
            .unwrap();

        let lambda_resp = hyper_to_lambda_response(hyper_resp).await.unwrap();

        assert_eq!(lambda_resp.status(), StatusCode::NO_CONTENT);
        match lambda_resp.body() {
            LambdaBody::Empty => {}
            _ => panic!("Expected empty body"),
        }
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_streaming() {
        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "text/event-stream")
            .header("cache-control", "no-cache")
            .body(unified_body("data: test\n\n"))
            .unwrap();

        let lambda_resp = hyper_to_lambda_streaming(hyper_resp);

        assert_eq!(lambda_resp.status(), StatusCode::OK);
        assert_eq!(
            lambda_resp.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            lambda_resp.headers().get("cache-control").unwrap(),
            "no-cache"
        );

        // The streaming body must be passed through untouched.
        let body = lambda_resp
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(&body[..], b"data: test\n\n");
    }

    #[test]
    fn test_mcp_headers_extraction() {
        let mut request = Request::builder()
            .method("POST")
            .uri("/mcp")
            .body(LambdaBody::Empty)
            .unwrap();

        let headers = request.headers_mut();
        headers.insert("mcp-session-id", HeaderValue::from_static("sess-123"));
        headers.insert(
            "mcp-protocol-version",
            HeaderValue::from_static("2025-06-18"),
        );
        headers.insert("last-event-id", HeaderValue::from_static("event-456"));

        let mcp_headers = extract_mcp_headers(&request);

        assert_eq!(mcp_headers.len(), 3);
        assert_eq!(
            mcp_headers.get("mcp-session-id"),
            Some(&"sess-123".to_string())
        );
        assert_eq!(
            mcp_headers.get("mcp-protocol-version"),
            Some(&"2025-06-18".to_string())
        );
        assert_eq!(
            mcp_headers.get("last-event-id"),
            Some(&"event-456".to_string())
        );
    }

    #[test]
    fn test_mcp_headers_extraction_skips_missing_and_non_ascii() {
        let mut request = Request::builder()
            .method("GET")
            .uri("/mcp")
            .body(LambdaBody::Empty)
            .unwrap();

        // Non-ASCII values are valid header bytes but fail `HeaderValue::to_str`.
        request.headers_mut().insert(
            "mcp-session-id",
            HeaderValue::from_bytes(b"caf\xc3\xa9").unwrap(),
        );

        let mcp_headers = extract_mcp_headers(&request);
        assert!(mcp_headers.is_empty());
    }

    #[test]
    fn test_mcp_headers_injection() {
        let mut lambda_resp = LambdaResponse::builder()
            .status(200)
            .body(LambdaBody::Empty)
            .unwrap();

        let mut headers = HashMap::new();
        headers.insert("mcp-session-id".to_string(), "sess-789".to_string());
        headers.insert("mcp-protocol-version".to_string(), "2025-06-18".to_string());

        inject_mcp_headers(&mut lambda_resp, headers);

        assert_eq!(
            lambda_resp.headers().get("mcp-session-id").unwrap(),
            "sess-789"
        );
        assert_eq!(
            lambda_resp.headers().get("mcp-protocol-version").unwrap(),
            "2025-06-18"
        );
    }

    #[test]
    fn test_mcp_headers_injection_skips_invalid_names() {
        let mut lambda_resp = LambdaResponse::builder()
            .status(200)
            .body(LambdaBody::Empty)
            .unwrap();

        let mut headers = HashMap::new();
        headers.insert("not a header".to_string(), "value".to_string());
        headers.insert("mcp-session-id".to_string(), "sess-1".to_string());

        inject_mcp_headers(&mut lambda_resp, headers);

        assert_eq!(lambda_resp.headers().len(), 1);
        assert_eq!(
            lambda_resp.headers().get("mcp-session-id").unwrap(),
            "sess-1"
        );
    }
}