Skip to main content

turul_mcp_aws_lambda/
lib.rs

1//! AWS Lambda integration for turul-mcp-framework
2//!
3//! This crate provides seamless integration between the turul-mcp-framework and AWS Lambda,
4//! enabling serverless deployment of MCP servers with proper session management, CORS handling,
5//! and SSE streaming support.
6//!
7//! ## Architecture
8//!
9//! The crate bridges the gap between Lambda's HTTP execution model and the framework's
10//! hyper-based architecture through:
11//!
12//! - **Type Conversion**: Clean conversion between `lambda_http` and `hyper` types
13//! - **Handler Registration**: Direct tool registration with `JsonRpcDispatcher`
14//! - **Session Management**: DynamoDB-backed session persistence across invocations
15//! - **CORS Support**: Proper CORS header injection for browser clients
16//! - **SSE Streaming**: Server-Sent Events adaptation through Lambda's streaming response
17//!
18//! ## Quick Start
19//!
20//! ```rust,no_run
21//! use turul_mcp_aws_lambda::LambdaMcpServerBuilder;
22//! use turul_mcp_derive::McpTool;
23//! use turul_mcp_server::{McpResult, SessionContext};
24//!
25//! #[derive(McpTool, Clone, Default)]
26//! #[tool(name = "example", description = "Example tool")]
27//! struct ExampleTool {
28//!     #[param(description = "Example parameter")]
29//!     value: String,
30//! }
31//!
32//! impl ExampleTool {
33//!     async fn execute(&self, _session: Option<SessionContext>) -> McpResult<String> {
34//!         Ok(format!("Got: {}", self.value))
35//!     }
36//! }
37//!
38//! #[tokio::main]
39//! async fn main() -> Result<(), lambda_http::Error> {
40//!     let server = LambdaMcpServerBuilder::new()
41//!         .tool(ExampleTool::default())
42//!         .cors_allow_all_origins()
43//!         .build()
44//!         .await?;
45//!
46//!     let handler = server.handler().await?;
47//!
48//!     // run_streaming handles API Gateway completion invocations gracefully
49//!     turul_mcp_aws_lambda::run_streaming(handler).await
50//! }
51//! ```
52//!
53//! ## Streaming Entry Points
54//!
55//! Two entry points replace `lambda_http::run_with_streaming_response()`:
56//!
57//! - [`run_streaming()`] — standard path: pass a [`LambdaMcpHandler`] directly
58//! - [`run_streaming_with()`] — custom dispatch: provide your own closure for
59//!   pre-dispatch logic (e.g., `.well-known` routing) while still getting
60//!   completion-invocation handling for free
61
62pub mod adapter;
63pub mod builder;
64pub mod error;
65pub mod handler;
66pub mod prelude;
67pub mod server;
68
69#[cfg(feature = "cors")]
70pub mod cors;
71
72#[cfg(feature = "sse")]
73pub mod streaming;
74
75// Re-exports for convenience
76/// Builder for creating Lambda MCP servers with fluent configuration API
77pub use builder::LambdaMcpServerBuilder;
78/// Lambda-specific error types and result aliases
79pub use error::{LambdaError, Result};
80/// Lambda request handler with session management and protocol conversion
81pub use handler::LambdaMcpHandler;
82/// Core Lambda MCP server implementation with DynamoDB integration
83pub use server::LambdaMcpServer;
84
85#[cfg(feature = "cors")]
86pub use cors::CorsConfig;
87
/// Classification of a raw Lambda runtime event payload.
///
/// Produced by [`classify_runtime_event()`] and consumed by
/// [`handle_runtime_payload()`], which both [`run_streaming()`] and
/// [`run_streaming_with()`] delegate to. It separates API Gateway requests
/// from streaming completion invocations and from unknown event shapes.
#[derive(Debug)]
enum RuntimeEventClassification {
    /// Payload that deserialized successfully as a
    /// `lambda_http::request::LambdaRequest`.
    ///
    /// Stores `Box<LambdaRequest>` to avoid a large enum variant (clippy
    /// `large_enum_variant`). Callers dereference with `(*lambda_request).into()`
    /// to move the inner `LambdaRequest` into an `http::Request`.
    ApiGatewayEvent(Box<lambda_http::request::LambdaRequest>),
    /// AWS streaming completion invocation: the payload did not parse as a
    /// `LambdaRequest` but has a top-level `invokeCompletionStatus` field.
    StreamingCompletion,
    /// Unrecognized payload — neither a parseable request nor a completion.
    UnrecognizedEvent,
}
105
106/// Classify a raw JSON payload into one of three categories.
107///
108/// Order matters:
109/// 1. Try API Gateway/ALB/WebSocket deserialization first (most common path)
110/// 2. Check for streaming completion signature (`invokeCompletionStatus` at top level)
111/// 3. Everything else is unrecognized
112///
113/// # Completion Detection Heuristic
114///
115/// Streaming completion payloads are identified by the presence of an
116/// `invokeCompletionStatus` field at the top level. This is a compatibility
117/// heuristic based on observed AWS behavior as of 2026-03 — AWS does not
118/// officially document this payload shape.
119///
120/// As of this writing, no API Gateway v1/v2, ALB, or WebSocket event
121/// produced by `lambda_http` contains this field at the top level.
122/// The fixture corpus in `src/fixtures/` guards against drift.
123///
124/// **Precedence**: API Gateway deserialization is attempted first.
125/// If a payload is both a valid API Gateway event AND contains
126/// `invokeCompletionStatus`, it will be classified as `ApiGatewayEvent`.
127/// The completion heuristic only applies to payloads that fail API
128/// Gateway parsing. This means false-positive completion detection is
129/// preferred over retry storms — an intentional design choice.
130fn classify_runtime_event(payload: serde_json::Value) -> RuntimeEventClassification {
131    // Fast path: try API Gateway event deserialization
132    if let Ok(request) =
133        serde_json::from_value::<lambda_http::request::LambdaRequest>(payload.clone())
134    {
135        return RuntimeEventClassification::ApiGatewayEvent(Box::new(request));
136    }
137
138    // Check for streaming completion signature
139    if payload.get("invokeCompletionStatus").is_some() {
140        return RuntimeEventClassification::StreamingCompletion;
141    }
142
143    // Unknown payload shape
144    RuntimeEventClassification::UnrecognizedEvent
145}
146
/// Body type of the HTTP responses that dispatch callbacks must produce
/// (see the bounds on [`handle_runtime_payload()`] and [`run_streaming_with()`]).
type StreamBody = http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, hyper::Error>;
/// Lambda streaming response whose stream is a [`StreamBody`] wrapped in
/// `BodyDataStream`; this is what [`into_lambda_stream_response()`] returns
/// when given an `http::Response<StreamBody>`.
type StreamResult = lambda_runtime::StreamResponse<http_body_util::BodyDataStream<StreamBody>>;
149
/// Result of [`handle_runtime_payload()`], carrying both the Lambda response
/// and a static string identifying the event type for logging/observability.
///
/// The entry points pass `event_type` to [`event_log_level()`] to decide
/// whether (and how loudly) to log, then return `response` to the runtime.
struct HandleResult {
    /// Streaming response handed back to the Lambda runtime.
    response: StreamResult,
    /// One of `"api_gateway_event"`, `"streaming_completion"`, or
    /// `"unrecognized_lambda_payload"`.
    event_type: &'static str,
}
158
159/// Process a raw Lambda runtime payload into a streaming response.
160///
161/// Classifies the payload via [`classify_runtime_event()`], dispatches API
162/// Gateway events through `dispatch`, and acknowledges non-API payloads with
163/// an empty 200 response. Returns a [`HandleResult`] so the caller can
164/// inspect `event_type` for logging decisions.
165async fn handle_runtime_payload<F, Fut>(
166    payload: serde_json::Value,
167    context: lambda_runtime::Context,
168    dispatch: F,
169) -> std::result::Result<HandleResult, lambda_http::Error>
170where
171    F: FnOnce(lambda_http::Request) -> Fut,
172    Fut: std::future::Future<
173            Output = std::result::Result<http::Response<StreamBody>, lambda_http::Error>,
174        >,
175{
176    match classify_runtime_event(payload) {
177        RuntimeEventClassification::ApiGatewayEvent(lambda_request) => {
178            use lambda_http::RequestExt;
179            let request: lambda_http::Request = (*lambda_request).into();
180            let request = request.with_lambda_context(context);
181            let response = dispatch(request).await?;
182            Ok(HandleResult {
183                response: into_lambda_stream_response(response),
184                event_type: "api_gateway_event",
185            })
186        }
187        RuntimeEventClassification::StreamingCompletion => Ok(HandleResult {
188            response: into_lambda_stream_response(empty_streaming_response()),
189            event_type: "streaming_completion",
190        }),
191        RuntimeEventClassification::UnrecognizedEvent => Ok(HandleResult {
192            response: into_lambda_stream_response(empty_streaming_response()),
193            event_type: "unrecognized_lambda_payload",
194        }),
195    }
196}
197
198/// Map an event type string to the appropriate tracing log level.
199///
200/// Returns `Some(Level::WARN)` for unrecognized payloads (surfaced in
201/// CloudWatch), `Some(Level::DEBUG)` for completion acks (normally silent),
202/// and `None` for API Gateway events (no extra logging needed).
203fn event_log_level(event_type: &str) -> Option<tracing::Level> {
204    match event_type {
205        "streaming_completion" => Some(tracing::Level::DEBUG),
206        "unrecognized_lambda_payload" => Some(tracing::Level::WARN),
207        _ => None,
208    }
209}
210
211/// Run the Lambda MCP handler with streaming response support.
212///
213/// This replaces `lambda_http::run_with_streaming_response(service_fn(...))` and
214/// gracefully handles API Gateway streaming completion invocations that would
215/// otherwise cause deserialization errors in the Lambda runtime.
216///
217/// ## Problem
218///
219/// When API Gateway uses `response-streaming-invocations`, it sends a secondary
220/// "completion" invocation after the streaming response finishes. This invocation
221/// is NOT an API Gateway proxy event — `lambda_http` cannot deserialize it, causing
222/// ERROR logs and CloudWatch Lambda Error metrics for every streaming response.
223///
224/// ## Solution
225///
226/// This function uses `lambda_runtime::run()` directly with `serde_json::Value`
227/// (which always deserializes), then classifies the payload three ways via
228/// [`classify_runtime_event()`]:
229///
230/// - **`ApiGatewayEvent`** — dispatched to the handler normally
231/// - **`StreamingCompletion`** — acknowledged silently (`debug` log)
232/// - **`UnrecognizedEvent`** — acknowledged with a `warn` log to surface
233///   unexpected payload shapes in CloudWatch without triggering Lambda retries
234///
235/// ## Usage
236///
237/// ```rust,ignore
238/// let handler = server.handler().await?;
239/// turul_mcp_aws_lambda::run_streaming(handler).await
240/// ```
241pub async fn run_streaming(
242    handler: LambdaMcpHandler,
243) -> std::result::Result<(), lambda_http::Error> {
244    use lambda_runtime::{LambdaEvent, service_fn};
245
246    lambda_runtime::run(service_fn(move |event: LambdaEvent<serde_json::Value>| {
247        let handler = handler.clone();
248        async move {
249            let result = handle_runtime_payload(event.payload, event.context, |req| {
250                handler.handle_streaming(req)
251            })
252            .await?;
253
254            match event_log_level(result.event_type) {
255                Some(level) if level == tracing::Level::WARN => {
256                    tracing::warn!(
257                        event_type = result.event_type,
258                        "Received unrecognized Lambda invocation payload"
259                    );
260                }
261                Some(_) => {
262                    tracing::debug!(
263                        event_type = result.event_type,
264                        "Acknowledging streaming completion"
265                    );
266                }
267                None => {}
268            }
269
270            Ok::<_, lambda_http::Error>(result.response)
271        }
272    }))
273    .await
274}
275
276/// Run a custom dispatch function with streaming response support.
277///
278/// Like [`run_streaming()`], but accepts a custom dispatch closure instead of
279/// a [`LambdaMcpHandler`]. Use this when you need pre-dispatch logic
280/// (e.g., `.well-known` routing) that runs before the MCP handler.
281///
282/// The dispatch closure is called once per API Gateway invocation. Streaming
283/// completion invocations and unrecognized payloads are acknowledged
284/// automatically without invoking the closure.
285///
286/// ## Usage
287///
288/// ```rust,ignore
289/// async fn lambda_handler(request: Request) -> Result<Response, Error> {
290///     // pre-dispatch logic here (e.g., .well-known short-circuit)
291///     let handler = HANDLER.get_or_try_init(|| async { ... }).await?;
292///     handler.handle_streaming(request).await
293/// }
294///
295/// turul_mcp_aws_lambda::run_streaming_with(lambda_handler).await
296/// ```
297pub async fn run_streaming_with<F, Fut>(dispatch: F) -> std::result::Result<(), lambda_http::Error>
298where
299    F: Fn(lambda_http::Request) -> Fut + Clone + Send + 'static,
300    Fut: std::future::Future<
301            Output = std::result::Result<http::Response<StreamBody>, lambda_http::Error>,
302        > + Send,
303{
304    use lambda_runtime::{LambdaEvent, service_fn};
305
306    lambda_runtime::run(service_fn(move |event: LambdaEvent<serde_json::Value>| {
307        let dispatch = dispatch.clone();
308        async move {
309            let result = handle_runtime_payload(event.payload, event.context, dispatch).await?;
310
311            match event_log_level(result.event_type) {
312                Some(level) if level == tracing::Level::WARN => {
313                    tracing::warn!(
314                        event_type = result.event_type,
315                        "Received unrecognized Lambda invocation payload"
316                    );
317                }
318                Some(_) => {
319                    tracing::debug!(
320                        event_type = result.event_type,
321                        "Acknowledging streaming completion"
322                    );
323                }
324                None => {}
325            }
326
327            Ok::<_, lambda_http::Error>(result.response)
328        }
329    }))
330    .await
331}
332
333/// Convert an HTTP response into a Lambda `StreamResponse`.
334///
335/// Replicates `lambda_http::streaming::into_stream_response` (which is private)
336/// by extracting status/headers/cookies into `MetadataPrelude` and converting
337/// the body into a `Stream`.
338fn into_lambda_stream_response<B>(
339    response: http::Response<B>,
340) -> lambda_runtime::StreamResponse<http_body_util::BodyDataStream<B>>
341where
342    B: http_body::Body + Unpin + Send + 'static,
343{
344    let (parts, body) = response.into_parts();
345    let mut headers = parts.headers;
346
347    // Extract Set-Cookie headers into the cookies vec (Lambda streaming protocol)
348    let cookies = headers
349        .get_all(http::header::SET_COOKIE)
350        .iter()
351        .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string())
352        .collect::<Vec<_>>();
353    headers.remove(http::header::SET_COOKIE);
354
355    lambda_runtime::StreamResponse {
356        metadata_prelude: lambda_runtime::MetadataPrelude {
357            headers,
358            status_code: parts.status,
359            cookies,
360        },
361        stream: http_body_util::BodyDataStream::new(body),
362    }
363}
364
365/// Build an empty 200 response for acknowledging completion invocations.
366fn empty_streaming_response()
367-> http::Response<http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, hyper::Error>> {
368    use http_body_util::{BodyExt, Full};
369    let body = Full::new(bytes::Bytes::new())
370        .map_err(|e: std::convert::Infallible| match e {})
371        .boxed_unsync();
372    http::Response::builder().status(200).body(body).unwrap()
373}
374
#[cfg(test)]
mod streaming_completion_tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Parse one of the JSON fixtures under `src/fixtures/`.
    ///
    /// Fixtures are embedded with `include_str!`, so a missing file is a
    /// build error rather than a runtime failure.
    fn load_fixture(name: &str) -> serde_json::Value {
        let raw = match name {
            "apigw_v1" => include_str!("fixtures/apigw_v1_proxy_event.json"),
            "apigw_v2" => include_str!("fixtures/apigw_v2_http_api_event.json"),
            "completion_success" => include_str!("fixtures/streaming_completion_success.json"),
            "completion_failure" => include_str!("fixtures/streaming_completion_failure.json"),
            "completion_extra" => include_str!("fixtures/streaming_completion_extra_fields.json"),
            "completion_api_like" => {
                include_str!("fixtures/completion_with_api_like_fields.json")
            }
            other => panic!("Unknown fixture: {other}"),
        };
        match serde_json::from_str(raw) {
            Ok(value) => value,
            Err(e) => panic!("Bad fixture {name}: {e}"),
        }
    }

    /// Classify the named fixture.
    fn classify_fixture(name: &str) -> RuntimeEventClassification {
        classify_runtime_event(load_fixture(name))
    }

    /// Run `payload` through `handle_runtime_payload` with a dispatcher that
    /// only records that it was called. Returns the handle result together
    /// with the "was dispatched" flag.
    async fn handle_with_probe(payload: serde_json::Value) -> (HandleResult, bool) {
        let dispatched = Arc::new(AtomicBool::new(false));
        let probe = Arc::clone(&dispatched);

        let result = handle_runtime_payload(
            payload,
            lambda_runtime::Context::default(),
            move |_req| async move {
                probe.store(true, Ordering::SeqCst);
                Ok(empty_streaming_response())
            },
        )
        .await
        .expect("handle should succeed");

        (result, dispatched.load(Ordering::SeqCst))
    }

    // ---- Fixtures: API Gateway events classify as ApiGatewayEvent ----

    #[test]
    fn test_classify_api_gateway_v1_event() {
        let classification = classify_fixture("apigw_v1");
        assert!(
            matches!(
                classification,
                RuntimeEventClassification::ApiGatewayEvent(_)
            ),
            "API Gateway v1 proxy event must classify as ApiGatewayEvent"
        );
    }

    #[test]
    fn test_classify_api_gateway_v2_event() {
        let classification = classify_fixture("apigw_v2");
        assert!(
            matches!(
                classification,
                RuntimeEventClassification::ApiGatewayEvent(_)
            ),
            "API Gateway v2 HTTP API event must classify as ApiGatewayEvent"
        );
    }

    // ---- Fixtures: streaming completions classify as StreamingCompletion ----

    #[test]
    fn test_classify_streaming_completion() {
        assert!(matches!(
            classify_fixture("completion_success"),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_failure_status() {
        assert!(matches!(
            classify_fixture("completion_failure"),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_extra_fields() {
        assert!(matches!(
            classify_fixture("completion_extra"),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    // ---- R5: precedence edge case ----

    #[test]
    fn test_classify_completion_with_api_like_fields() {
        // Intentional: prefer a false-positive ack over retries.
        // The fixture combines invokeCompletionStatus with partial API
        // Gateway fields; invokeCompletionStatus is the discriminator, so the
        // result is StreamingCompletion rather than UnrecognizedEvent.
        //
        // NOTE: The fixture is deliberately NOT a valid API Gateway event.
        // Do not "fix" it into one — that would change the expected classification.
        assert!(matches!(
            classify_fixture("completion_api_like"),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    // ---- Inline payloads: everything else is UnrecognizedEvent ----

    #[test]
    fn test_classify_empty_object() {
        let classification = classify_runtime_event(json!({}));
        assert!(matches!(
            classification,
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_random_object() {
        let classification = classify_runtime_event(json!({"foo": "bar", "baz": 123}));
        assert!(matches!(
            classification,
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_null_payload() {
        let classification = classify_runtime_event(json!(null));
        assert!(matches!(
            classification,
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_string_payload() {
        let classification = classify_runtime_event(json!("hello"));
        assert!(matches!(
            classification,
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_array_payload() {
        let classification = classify_runtime_event(json!([1, 2, 3]));
        assert!(matches!(
            classification,
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_nested_invoke_status() {
        // Only a top-level invokeCompletionStatus counts as the signature.
        let nested = json!({
            "data": {"invokeCompletionStatus": "Success"}
        });
        assert!(matches!(
            classify_runtime_event(nested),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    // ---- Property-style tests ----

    #[test]
    fn test_classify_never_panics_on_arbitrary_json() {
        // R4: the only requirement is "no panic"; the resulting
        // classification is deliberately not asserted.
        let payloads = vec![
            json!(null),
            json!(true),
            json!(false),
            json!(42),
            json!(-1.5),
            json!(""),
            json!("some string"),
            json!([]),
            json!([1, "two", null, false]),
            json!({}),
            json!({"a": 1}),
            json!({"requestContext": null}),
            json!({"requestContext": "not-an-object"}),
            json!({"httpMethod": "POST"}),
            json!({"version": "2.0"}),
            json!({"version": "2.0", "routeKey": "GET /"}),
            json!({"resource": "/", "httpMethod": "GET"}),
            json!({"deeply": {"nested": {"invokeCompletionStatus": "Success"}}}),
            // Large payload
            serde_json::Value::Object((0..100).map(|i| (format!("key_{i}"), json!(i))).collect()),
        ];

        payloads.into_iter().for_each(|payload| {
            let _classification = classify_runtime_event(payload);
        });
    }

    #[test]
    fn test_classify_invoke_completion_status_always_wins() {
        // Any object with a top-level invokeCompletionStatus that does not
        // parse as an API Gateway event must come out as StreamingCompletion.
        let payloads = vec![
            json!({"invokeCompletionStatus": "Success"}),
            json!({"invokeCompletionStatus": "Failure"}),
            json!({"invokeCompletionStatus": "Unknown"}),
            json!({"invokeCompletionStatus": 42}),
            json!({"invokeCompletionStatus": null}),
            json!({"invokeCompletionStatus": "Success", "requestId": "abc-123"}),
            json!({"invokeCompletionStatus": "Success", "extra": "field", "nested": {"a": 1}}),
        ];

        for payload in payloads {
            let classification = classify_runtime_event(payload.clone());
            assert!(
                matches!(classification, RuntimeEventClassification::StreamingCompletion),
                "Payload with top-level invokeCompletionStatus must be StreamingCompletion: {payload}"
            );
        }
    }

    // ---- event_log_level contract (R1) ----

    #[test]
    fn test_unrecognized_logs_at_warn_level() {
        let level = event_log_level("unrecognized_lambda_payload");
        assert_eq!(level, Some(tracing::Level::WARN));
    }

    #[test]
    fn test_completion_logs_at_debug_level() {
        let level = event_log_level("streaming_completion");
        assert_eq!(level, Some(tracing::Level::DEBUG));
    }

    #[test]
    fn test_api_gateway_has_no_extra_logging() {
        let level = event_log_level("api_gateway_event");
        assert_eq!(level, None);
    }

    // ---- handle_runtime_payload action paths (R3) ----

    #[tokio::test]
    async fn test_handle_completion_does_not_dispatch() {
        let (result, was_dispatched) = handle_with_probe(load_fixture("completion_success")).await;

        assert!(
            !was_dispatched,
            "Completion events must not dispatch to handler"
        );
        assert_eq!(result.event_type, "streaming_completion");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }

    #[tokio::test]
    async fn test_handle_unrecognized_does_not_dispatch() {
        let (result, was_dispatched) = handle_with_probe(json!({"foo": "bar"})).await;

        assert!(
            !was_dispatched,
            "Unrecognized events must not dispatch to handler"
        );
        assert_eq!(result.event_type, "unrecognized_lambda_payload");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }

    #[tokio::test]
    async fn test_handle_apigw_v1_dispatches() {
        let (result, was_dispatched) = handle_with_probe(load_fixture("apigw_v1")).await;

        assert!(
            was_dispatched,
            "API Gateway v1 events must dispatch to handler"
        );
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_handle_apigw_v2_dispatches() {
        let (result, was_dispatched) = handle_with_probe(load_fixture("apigw_v2")).await;

        assert!(
            was_dispatched,
            "API Gateway v2 events must dispatch to handler"
        );
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_handle_unrecognized_surfaces_distinct_event_type() {
        let outcome = handle_runtime_payload(
            json!({"unknown": true}),
            lambda_runtime::Context::default(),
            |_req| async { Ok(empty_streaming_response()) },
        )
        .await;

        let result = outcome.expect("handle should succeed");
        assert_eq!(result.event_type, "unrecognized_lambda_payload");
    }

    // ---- Response conversion ----

    #[test]
    fn test_empty_streaming_response() {
        assert_eq!(empty_streaming_response().status(), 200);
    }

    #[test]
    fn test_into_lambda_stream_response_preserves_metadata() {
        use http_body_util::{BodyExt, Full};

        let body = Full::new(bytes::Bytes::from("Unauthorized"))
            .map_err(|e: std::convert::Infallible| match e {})
            .boxed_unsync();
        let response = http::Response::builder()
            .status(401)
            .header("WWW-Authenticate", "Bearer realm=\"mcp\"")
            .header("X-Custom", "test")
            .body(body)
            .unwrap();

        let prelude = into_lambda_stream_response(response).metadata_prelude;
        assert_eq!(prelude.status_code, 401);
        assert_eq!(
            prelude.headers.get("WWW-Authenticate").unwrap(),
            "Bearer realm=\"mcp\""
        );
        assert_eq!(prelude.headers.get("X-Custom").unwrap(), "test");
    }

    #[test]
    fn test_into_lambda_stream_response_extracts_cookies() {
        use http_body_util::{BodyExt, Full};

        let body = Full::new(bytes::Bytes::new())
            .map_err(|e: std::convert::Infallible| match e {})
            .boxed_unsync();
        let response = http::Response::builder()
            .status(200)
            .header("Set-Cookie", "session=abc; Path=/")
            .header("Set-Cookie", "theme=dark")
            .body(body)
            .unwrap();

        let prelude = into_lambda_stream_response(response).metadata_prelude;
        assert_eq!(prelude.cookies.len(), 2);
        assert!(prelude.cookies.contains(&"session=abc; Path=/".to_string()));
        assert!(prelude.cookies.contains(&"theme=dark".to_string()));
        // The header must be gone once its values live in `cookies`.
        assert!(prelude.headers.get("Set-Cookie").is_none());
    }

    // ---- run_streaming_with-style dispatch closures ----

    #[tokio::test]
    async fn test_run_streaming_with_dispatches_apigw_events() {
        let dispatched = Arc::new(AtomicBool::new(false));
        let probe = Arc::clone(&dispatched);

        // Reusable closure with an explicit argument type, matching what
        // callers hand to `run_streaming_with`.
        let dispatch = move |_req: lambda_http::Request| {
            let flag = Arc::clone(&probe);
            async move {
                flag.store(true, Ordering::SeqCst);
                Ok(empty_streaming_response())
            }
        };

        let result = handle_runtime_payload(
            load_fixture("apigw_v1"),
            lambda_runtime::Context::default(),
            dispatch,
        )
        .await
        .expect("handle should succeed");

        assert!(
            dispatched.load(Ordering::SeqCst),
            "run_streaming_with dispatch must be called for API Gateway events"
        );
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_run_streaming_with_acks_completion_without_dispatch() {
        let dispatched = Arc::new(AtomicBool::new(false));
        let probe = Arc::clone(&dispatched);

        let dispatch = move |_req: lambda_http::Request| {
            let flag = Arc::clone(&probe);
            async move {
                flag.store(true, Ordering::SeqCst);
                Ok(empty_streaming_response())
            }
        };

        let result = handle_runtime_payload(
            load_fixture("completion_success"),
            lambda_runtime::Context::default(),
            dispatch,
        )
        .await
        .expect("handle should succeed");

        assert!(
            !dispatched.load(Ordering::SeqCst),
            "run_streaming_with dispatch must NOT be called for completion events"
        );
        assert_eq!(result.event_type, "streaming_completion");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }
}