turul_mcp_aws_lambda/
adapter.rs

1//! HTTP type conversion utilities for Lambda MCP requests
2//!
3//! This module provides comprehensive conversion between lambda_http and hyper types,
4//! enabling seamless integration between Lambda's HTTP model and the SessionMcpHandler.
5
6use std::collections::HashMap;
7
8use bytes::Bytes;
9use http_body_util::{BodyExt, Full};
10use hyper::Response as HyperResponse;
11use lambda_http::{Body as LambdaBody, Request as LambdaRequest, Response as LambdaResponse};
12use tracing::{debug, trace};
13
14use crate::error::{LambdaError, Result};
15
16/// Type alias for the unified MCP response body used by SessionMcpHandler
17type UnifiedMcpBody = http_body_util::combinators::UnsyncBoxBody<Bytes, hyper::Error>;
18
19/// Error mapping function for Full<Bytes>
20fn infallible_to_hyper_error(never: std::convert::Infallible) -> hyper::Error {
21    match never {}
22}
23
24/// Type alias for Full<Bytes> with mapped error type compatible with SessionMcpHandler
25type MappedFullBody =
26    http_body_util::combinators::MapErr<Full<Bytes>, fn(std::convert::Infallible) -> hyper::Error>;
27
28/// Convert lambda_http::Request to hyper::Request<MappedFullBody>
29///
30/// This enables delegation to SessionMcpHandler by converting Lambda's request format
31/// to the hyper format expected by the framework. All headers are preserved.
32pub fn lambda_to_hyper_request(
33    lambda_req: LambdaRequest,
34) -> Result<hyper::Request<MappedFullBody>> {
35    let (parts, lambda_body) = lambda_req.into_parts();
36
37    // Convert LambdaBody to Bytes
38    let body_bytes = match lambda_body {
39        LambdaBody::Empty => Bytes::new(),
40        LambdaBody::Text(s) => Bytes::from(s),
41        LambdaBody::Binary(b) => Bytes::from(b),
42    };
43
44    // Create Full<Bytes> body and map error type to hyper::Error
45    let full_body = Full::new(body_bytes)
46        .map_err(infallible_to_hyper_error as fn(std::convert::Infallible) -> hyper::Error);
47
48    // Create hyper Request with preserved headers and new body type
49    let hyper_req = hyper::Request::from_parts(parts, full_body);
50
51    debug!(
52        "Converted Lambda request: {} {} -> hyper::Request<Full<Bytes>>",
53        hyper_req.method(),
54        hyper_req.uri()
55    );
56
57    Ok(hyper_req)
58}
59
60/// Convert hyper::Response<UnifiedMcpBody> to lambda_http::Response<LambdaBody>
61///
62/// This collects the streaming body into a LambdaBody for non-streaming responses.
63/// Used by the handle() method which returns snapshot responses.
64pub async fn hyper_to_lambda_response(
65    hyper_resp: HyperResponse<UnifiedMcpBody>,
66) -> Result<LambdaResponse<LambdaBody>> {
67    let (parts, body) = hyper_resp.into_parts();
68
69    // Collect the body into bytes
70    let body_bytes = match body.collect().await {
71        Ok(collected) => collected.to_bytes(),
72        Err(err) => {
73            return Err(LambdaError::Body(format!(
74                "Failed to collect response body: {}",
75                err
76            )));
77        }
78    };
79
80    // Convert to LambdaBody
81    let lambda_body = if body_bytes.is_empty() {
82        LambdaBody::Empty
83    } else {
84        // Try to convert to text if it's valid UTF-8, otherwise use binary
85        match String::from_utf8(body_bytes.to_vec()) {
86            Ok(text) => LambdaBody::Text(text),
87            Err(_) => LambdaBody::Binary(body_bytes.to_vec()),
88        }
89    };
90
91    // Create Lambda response with preserved headers
92    let lambda_resp = LambdaResponse::from_parts(parts, lambda_body);
93
94    debug!(
95        "Converted hyper response -> Lambda response (status: {})",
96        lambda_resp.status()
97    );
98
99    Ok(lambda_resp)
100}
101
102/// Convert hyper::Response<UnifiedMcpBody> to lambda_http streaming response
103///
104/// This preserves the streaming body for real-time SSE responses.
105/// Used by the handle_streaming() method for true streaming.
106pub fn hyper_to_lambda_streaming(
107    hyper_resp: HyperResponse<UnifiedMcpBody>,
108) -> lambda_http::Response<UnifiedMcpBody> {
109    let (parts, body) = hyper_resp.into_parts();
110
111    // Direct passthrough - no body collection, preserves streaming
112    let lambda_resp = lambda_http::Response::from_parts(parts, body);
113
114    debug!(
115        "Converted hyper response -> Lambda streaming response (status: {})",
116        lambda_resp.status()
117    );
118
119    lambda_resp
120}
121
122/// Extract MCP-specific headers from Lambda request context
123///
124/// Lambda requests may have additional context that needs to be preserved
125/// for proper MCP protocol handling.
126pub fn extract_mcp_headers(req: &LambdaRequest) -> HashMap<String, String> {
127    let mut mcp_headers = HashMap::new();
128
129    // Extract session ID from headers
130    if let Some(session_id) = req.headers().get("mcp-session-id")
131        && let Ok(session_id_str) = session_id.to_str()
132    {
133        mcp_headers.insert("mcp-session-id".to_string(), session_id_str.to_string());
134    }
135
136    // Extract protocol version
137    if let Some(protocol_version) = req.headers().get("mcp-protocol-version")
138        && let Ok(version_str) = protocol_version.to_str()
139    {
140        mcp_headers.insert("mcp-protocol-version".to_string(), version_str.to_string());
141    }
142
143    // Extract Last-Event-ID for SSE resumability
144    if let Some(last_event_id) = req.headers().get("last-event-id")
145        && let Ok(event_id_str) = last_event_id.to_str()
146    {
147        mcp_headers.insert("last-event-id".to_string(), event_id_str.to_string());
148    }
149
150    trace!("Extracted MCP headers: {:?}", mcp_headers);
151    mcp_headers
152}
153
154/// Add MCP-specific headers to Lambda response
155///
156/// Ensures proper MCP protocol headers are included in the response.
157pub fn inject_mcp_headers(resp: &mut LambdaResponse<LambdaBody>, headers: HashMap<String, String>) {
158    for (name, value) in headers {
159        if let (Ok(header_name), Ok(header_value)) = (
160            http::HeaderName::from_bytes(name.as_bytes()),
161            http::HeaderValue::from_str(&value),
162        ) {
163            resp.headers_mut().insert(header_name, header_value);
164            debug!("Injected MCP header: {} = {}", name, value);
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    //! Unit tests for the Lambda <-> hyper conversion helpers.
    //!
    //! Request tests check the converted body as well as the request head, and
    //! the response tests cover the empty, text and binary body variants.

    use super::*;
    use http::{HeaderValue, Method, Request, StatusCode};
    use http_body_util::Full;

    /// Wrap `bytes` in the boxed body type accepted by the response converters.
    fn boxed_body(bytes: Bytes) -> UnifiedMcpBody {
        Full::new(bytes)
            .map_err(infallible_to_hyper_error)
            .boxed_unsync()
    }

    /// Drain a converted request body so its payload can be asserted on.
    async fn read_body(body: MappedFullBody) -> Bytes {
        body.collect().await.unwrap().to_bytes()
    }

    #[tokio::test]
    async fn test_lambda_to_hyper_request_conversion() {
        // Create a test Lambda request with headers and body
        let payload = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let mut lambda_req = Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .body(LambdaBody::Text(payload.to_string()))
            .unwrap();

        // Add MCP headers
        let headers = lambda_req.headers_mut();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert(
            "mcp-session-id",
            HeaderValue::from_static("test-session-123"),
        );
        headers.insert(
            "mcp-protocol-version",
            HeaderValue::from_static("2025-06-18"),
        );

        // Test the conversion
        let hyper_req = lambda_to_hyper_request(lambda_req).unwrap();

        // Verify method and URI are preserved
        assert_eq!(hyper_req.method(), &Method::POST);
        assert_eq!(hyper_req.uri().path(), "/mcp");

        // Verify headers are preserved
        assert_eq!(
            hyper_req.headers().get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(
            hyper_req.headers().get("mcp-session-id").unwrap(),
            "test-session-123"
        );
        assert_eq!(
            hyper_req.headers().get("mcp-protocol-version").unwrap(),
            "2025-06-18"
        );

        // Verify the text payload arrives byte-for-byte
        let body = read_body(hyper_req.into_body()).await;
        assert_eq!(&body[..], payload.as_bytes());
    }

    #[tokio::test]
    async fn test_lambda_to_hyper_empty_body() {
        let lambda_req = Request::builder()
            .method(Method::GET)
            .uri("/sse")
            .body(LambdaBody::Empty)
            .unwrap();

        let hyper_req = lambda_to_hyper_request(lambda_req).unwrap();
        assert_eq!(hyper_req.method(), &Method::GET);
        assert_eq!(hyper_req.uri().path(), "/sse");

        // An empty Lambda body must stay empty
        let body = read_body(hyper_req.into_body()).await;
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn test_lambda_to_hyper_binary_body() {
        let test_data = vec![0x48, 0x65, 0x6c, 0x6c, 0x6f]; // "Hello" in bytes
        let lambda_req = Request::builder()
            .method(Method::POST)
            .uri("/binary")
            // Cloned because the original bytes are compared against below.
            .body(LambdaBody::Binary(test_data.clone()))
            .unwrap();

        let hyper_req = lambda_to_hyper_request(lambda_req).unwrap();
        assert_eq!(hyper_req.method(), &Method::POST);
        assert_eq!(hyper_req.uri().path(), "/binary");

        // Verify the binary payload arrives byte-for-byte
        let body = read_body(hyper_req.into_body()).await;
        assert_eq!(&body[..], test_data.as_slice());
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_response_conversion() {
        // Create a test hyper response
        let json_body = r#"{"jsonrpc":"2.0","id":1,"result":{"capabilities":{}}}"#;

        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/json")
            .header("mcp-session-id", "resp-session-456")
            .body(boxed_body(Bytes::from(json_body)))
            .unwrap();

        // Test the conversion
        let lambda_resp = hyper_to_lambda_response(hyper_resp).await.unwrap();

        // Verify status and headers are preserved
        assert_eq!(lambda_resp.status(), StatusCode::OK);
        assert_eq!(
            lambda_resp.headers().get("content-type").unwrap(),
            "application/json"
        );
        assert_eq!(
            lambda_resp.headers().get("mcp-session-id").unwrap(),
            "resp-session-456"
        );

        // Verify body is converted to text
        match lambda_resp.body() {
            LambdaBody::Text(text) => assert_eq!(text, json_body),
            _ => panic!("Expected text body"),
        }
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_empty_response() {
        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(boxed_body(Bytes::new()))
            .unwrap();

        let lambda_resp = hyper_to_lambda_response(hyper_resp).await.unwrap();

        assert_eq!(lambda_resp.status(), StatusCode::NO_CONTENT);
        match lambda_resp.body() {
            LambdaBody::Empty => {} // Expected
            _ => panic!("Expected empty body"),
        }
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_binary_response() {
        // 0xff can never appear in well-formed UTF-8, forcing the binary fallback
        let raw: &'static [u8] = &[0xff, 0xfe, 0xfd];

        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "application/octet-stream")
            .body(boxed_body(Bytes::from_static(raw)))
            .unwrap();

        let lambda_resp = hyper_to_lambda_response(hyper_resp).await.unwrap();

        assert_eq!(lambda_resp.status(), StatusCode::OK);
        match lambda_resp.body() {
            LambdaBody::Binary(bytes) => assert_eq!(bytes.as_slice(), raw),
            _ => panic!("Expected binary body"),
        }
    }

    #[tokio::test]
    async fn test_hyper_to_lambda_streaming() {
        // Create a streaming response
        let sse_frame = "data: test\n\n";

        let hyper_resp = hyper::Response::builder()
            .status(StatusCode::OK)
            .header("content-type", "text/event-stream")
            .header("cache-control", "no-cache")
            .body(boxed_body(Bytes::from(sse_frame)))
            .unwrap();

        // Test streaming conversion (should preserve body as-is)
        let lambda_resp = hyper_to_lambda_streaming(hyper_resp);

        assert_eq!(lambda_resp.status(), StatusCode::OK);
        assert_eq!(
            lambda_resp.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
        assert_eq!(
            lambda_resp.headers().get("cache-control").unwrap(),
            "no-cache"
        );

        // Body should come through untouched
        let body = lambda_resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], sse_frame.as_bytes());
    }

    #[test]
    fn test_mcp_headers_extraction() {
        // Create a test request with MCP headers
        let mut request = Request::builder()
            .method("POST")
            .uri("/mcp")
            .body(LambdaBody::Empty)
            .unwrap();

        let headers = request.headers_mut();
        headers.insert("mcp-session-id", HeaderValue::from_static("sess-123"));
        headers.insert(
            "mcp-protocol-version",
            HeaderValue::from_static("2025-06-18"),
        );
        headers.insert("last-event-id", HeaderValue::from_static("event-456"));

        let mcp_headers = extract_mcp_headers(&request);

        assert_eq!(
            mcp_headers.get("mcp-session-id"),
            Some(&"sess-123".to_string())
        );
        assert_eq!(
            mcp_headers.get("mcp-protocol-version"),
            Some(&"2025-06-18".to_string())
        );
        assert_eq!(
            mcp_headers.get("last-event-id"),
            Some(&"event-456".to_string())
        );
    }

    #[test]
    fn test_mcp_headers_injection() {
        let mut lambda_resp = LambdaResponse::builder()
            .status(200)
            .body(LambdaBody::Empty)
            .unwrap();

        let mut headers = HashMap::new();
        headers.insert("mcp-session-id".to_string(), "sess-789".to_string());
        headers.insert("mcp-protocol-version".to_string(), "2025-06-18".to_string());

        inject_mcp_headers(&mut lambda_resp, headers);

        assert_eq!(
            lambda_resp.headers().get("mcp-session-id").unwrap(),
            "sess-789"
        );
        assert_eq!(
            lambda_resp.headers().get("mcp-protocol-version").unwrap(),
            "2025-06-18"
        );
    }
}