1pub mod adapter;
54pub mod builder;
55pub mod error;
56pub mod handler;
57pub mod prelude;
58pub mod server;
59
60#[cfg(feature = "cors")]
61pub mod cors;
62
63#[cfg(feature = "sse")]
64pub mod streaming;
65
66pub use builder::LambdaMcpServerBuilder;
69pub use error::{LambdaError, Result};
71pub use handler::LambdaMcpHandler;
73pub use server::LambdaMcpServer;
75
76#[cfg(feature = "cors")]
77pub use cors::CorsConfig;
78
79#[derive(Debug)]
84enum RuntimeEventClassification {
85 ApiGatewayEvent(Box<lambda_http::request::LambdaRequest>),
91 StreamingCompletion,
93 UnrecognizedEvent,
95}
96
97fn classify_runtime_event(payload: serde_json::Value) -> RuntimeEventClassification {
122 if let Ok(request) =
124 serde_json::from_value::<lambda_http::request::LambdaRequest>(payload.clone())
125 {
126 return RuntimeEventClassification::ApiGatewayEvent(Box::new(request));
127 }
128
129 if payload.get("invokeCompletionStatus").is_some() {
131 return RuntimeEventClassification::StreamingCompletion;
132 }
133
134 RuntimeEventClassification::UnrecognizedEvent
136}
137
138type StreamBody = http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, hyper::Error>;
139type StreamResult = lambda_runtime::StreamResponse<http_body_util::BodyDataStream<StreamBody>>;
140
141struct HandleResult {
144 response: StreamResult,
145 event_type: &'static str,
148}
149
150async fn handle_runtime_payload<F, Fut>(
157 payload: serde_json::Value,
158 context: lambda_runtime::Context,
159 dispatch: F,
160) -> std::result::Result<HandleResult, lambda_http::Error>
161where
162 F: FnOnce(lambda_http::Request) -> Fut,
163 Fut: std::future::Future<
164 Output = std::result::Result<http::Response<StreamBody>, lambda_http::Error>,
165 >,
166{
167 match classify_runtime_event(payload) {
168 RuntimeEventClassification::ApiGatewayEvent(lambda_request) => {
169 use lambda_http::RequestExt;
170 let request: lambda_http::Request = (*lambda_request).into();
171 let request = request.with_lambda_context(context);
172 let response = dispatch(request).await?;
173 Ok(HandleResult {
174 response: into_lambda_stream_response(response),
175 event_type: "api_gateway_event",
176 })
177 }
178 RuntimeEventClassification::StreamingCompletion => Ok(HandleResult {
179 response: into_lambda_stream_response(empty_streaming_response()),
180 event_type: "streaming_completion",
181 }),
182 RuntimeEventClassification::UnrecognizedEvent => Ok(HandleResult {
183 response: into_lambda_stream_response(empty_streaming_response()),
184 event_type: "unrecognized_lambda_payload",
185 }),
186 }
187}
188
189fn event_log_level(event_type: &str) -> Option<tracing::Level> {
195 match event_type {
196 "streaming_completion" => Some(tracing::Level::DEBUG),
197 "unrecognized_lambda_payload" => Some(tracing::Level::WARN),
198 _ => None,
199 }
200}
201
202pub async fn run_streaming(
233 handler: LambdaMcpHandler,
234) -> std::result::Result<(), lambda_http::Error> {
235 use lambda_runtime::{LambdaEvent, service_fn};
236
237 lambda_runtime::run(service_fn(move |event: LambdaEvent<serde_json::Value>| {
238 let handler = handler.clone();
239 async move {
240 let result = handle_runtime_payload(event.payload, event.context, |req| {
241 handler.handle_streaming(req)
242 })
243 .await?;
244
245 match event_log_level(result.event_type) {
246 Some(level) if level == tracing::Level::WARN => {
247 tracing::warn!(
248 event_type = result.event_type,
249 "Received unrecognized Lambda invocation payload"
250 );
251 }
252 Some(_) => {
253 tracing::debug!(
254 event_type = result.event_type,
255 "Acknowledging streaming completion"
256 );
257 }
258 None => {}
259 }
260
261 Ok::<_, lambda_http::Error>(result.response)
262 }
263 }))
264 .await
265}
266
267fn into_lambda_stream_response<B>(
273 response: http::Response<B>,
274) -> lambda_runtime::StreamResponse<http_body_util::BodyDataStream<B>>
275where
276 B: http_body::Body + Unpin + Send + 'static,
277{
278 let (parts, body) = response.into_parts();
279 let mut headers = parts.headers;
280
281 let cookies = headers
283 .get_all(http::header::SET_COOKIE)
284 .iter()
285 .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string())
286 .collect::<Vec<_>>();
287 headers.remove(http::header::SET_COOKIE);
288
289 lambda_runtime::StreamResponse {
290 metadata_prelude: lambda_runtime::MetadataPrelude {
291 headers,
292 status_code: parts.status,
293 cookies,
294 },
295 stream: http_body_util::BodyDataStream::new(body),
296 }
297}
298
299fn empty_streaming_response()
301-> http::Response<http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, hyper::Error>> {
302 use http_body_util::{BodyExt, Full};
303 let body = Full::new(bytes::Bytes::new())
304 .map_err(|e: std::convert::Infallible| match e {})
305 .boxed_unsync();
306 http::Response::builder().status(200).body(body).unwrap()
307}
308
#[cfg(test)]
mod streaming_completion_tests {
    //! Tests for payload classification, dispatch routing, log-level
    //! selection, and HTTP-to-Lambda streaming-response conversion.

    use super::RuntimeEventClassification as Kind;
    use super::*;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Parses one of the JSON fixtures stored under `fixtures/`.
    ///
    /// Panics on an unknown fixture name or unparsable fixture content.
    fn load_fixture(name: &str) -> serde_json::Value {
        let raw = match name {
            "apigw_v1" => include_str!("fixtures/apigw_v1_proxy_event.json"),
            "apigw_v2" => include_str!("fixtures/apigw_v2_http_api_event.json"),
            "completion_success" => include_str!("fixtures/streaming_completion_success.json"),
            "completion_failure" => include_str!("fixtures/streaming_completion_failure.json"),
            "completion_extra" => include_str!("fixtures/streaming_completion_extra_fields.json"),
            "completion_api_like" => {
                include_str!("fixtures/completion_with_api_like_fields.json")
            }
            other => panic!("Unknown fixture: {other}"),
        };
        serde_json::from_str(raw).unwrap_or_else(|e| panic!("Bad fixture {name}: {e}"))
    }

    /// Classifies the named fixture.
    fn classify_fixture(name: &str) -> Kind {
        classify_runtime_event(load_fixture(name))
    }

    /// Runs `handle_runtime_payload` with a dispatcher that only records that
    /// it was called; returns the handle result and whether dispatch happened.
    async fn handle_with_probe(payload: serde_json::Value) -> (HandleResult, bool) {
        let called = Arc::new(AtomicBool::new(false));
        let probe = Arc::clone(&called);

        let result = handle_runtime_payload(
            payload,
            lambda_runtime::Context::default(),
            move |_req| async move {
                probe.store(true, Ordering::SeqCst);
                Ok(empty_streaming_response())
            },
        )
        .await
        .expect("handle should succeed");

        (result, called.load(Ordering::SeqCst))
    }

    /// Wraps `data` in the boxed body type used by this module's responses.
    fn boxed_body(data: bytes::Bytes) -> StreamBody {
        use http_body_util::{BodyExt, Full};
        Full::new(data)
            .map_err(|e: std::convert::Infallible| match e {})
            .boxed_unsync()
    }

    // ---- classification -------------------------------------------------

    #[test]
    fn test_classify_api_gateway_v1_event() {
        let kind = classify_fixture("apigw_v1");
        assert!(
            matches!(kind, Kind::ApiGatewayEvent(_)),
            "API Gateway v1 proxy event must classify as ApiGatewayEvent"
        );
    }

    #[test]
    fn test_classify_api_gateway_v2_event() {
        let kind = classify_fixture("apigw_v2");
        assert!(
            matches!(kind, Kind::ApiGatewayEvent(_)),
            "API Gateway v2 HTTP API event must classify as ApiGatewayEvent"
        );
    }

    #[test]
    fn test_classify_streaming_completion() {
        assert!(matches!(
            classify_fixture("completion_success"),
            Kind::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_failure_status() {
        assert!(matches!(
            classify_fixture("completion_failure"),
            Kind::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_extra_fields() {
        assert!(matches!(
            classify_fixture("completion_extra"),
            Kind::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_with_api_like_fields() {
        // A completion payload that also carries request-like fields must
        // still be recognised as a completion.
        assert!(matches!(
            classify_fixture("completion_api_like"),
            Kind::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_empty_object() {
        assert!(matches!(
            classify_runtime_event(json!({})),
            Kind::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_random_object() {
        assert!(matches!(
            classify_runtime_event(json!({"foo": "bar", "baz": 123})),
            Kind::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_null_payload() {
        assert!(matches!(
            classify_runtime_event(json!(null)),
            Kind::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_string_payload() {
        assert!(matches!(
            classify_runtime_event(json!("hello")),
            Kind::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_array_payload() {
        assert!(matches!(
            classify_runtime_event(json!([1, 2, 3])),
            Kind::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_nested_invoke_status() {
        // Only a top-level `invokeCompletionStatus` counts.
        let nested = json!({
            "data": {"invokeCompletionStatus": "Success"}
        });
        assert!(matches!(
            classify_runtime_event(nested),
            Kind::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_never_panics_on_arbitrary_json() {
        let wide_object =
            serde_json::Value::Object((0..100).map(|i| (format!("key_{i}"), json!(i))).collect());
        let payloads = [
            json!(null),
            json!(true),
            json!(false),
            json!(42),
            json!(-1.5),
            json!(""),
            json!("some string"),
            json!([]),
            json!([1, "two", null, false]),
            json!({}),
            json!({"a": 1}),
            json!({"requestContext": null}),
            json!({"requestContext": "not-an-object"}),
            json!({"httpMethod": "POST"}),
            json!({"version": "2.0"}),
            json!({"version": "2.0", "routeKey": "GET /"}),
            json!({"resource": "/", "httpMethod": "GET"}),
            json!({"deeply": {"nested": {"invokeCompletionStatus": "Success"}}}),
            wide_object,
        ];

        for payload in payloads {
            let _ = classify_runtime_event(payload);
        }
    }

    #[test]
    fn test_classify_invoke_completion_status_always_wins() {
        let payloads = [
            json!({"invokeCompletionStatus": "Success"}),
            json!({"invokeCompletionStatus": "Failure"}),
            json!({"invokeCompletionStatus": "Unknown"}),
            json!({"invokeCompletionStatus": 42}),
            json!({"invokeCompletionStatus": null}),
            json!({"invokeCompletionStatus": "Success", "requestId": "abc-123"}),
            json!({"invokeCompletionStatus": "Success", "extra": "field", "nested": {"a": 1}}),
        ];

        for payload in payloads {
            let kind = classify_runtime_event(payload.clone());
            assert!(
                matches!(kind, Kind::StreamingCompletion),
                "Payload with top-level invokeCompletionStatus must be StreamingCompletion: {payload}"
            );
        }
    }

    // ---- log levels -----------------------------------------------------

    #[test]
    fn test_unrecognized_logs_at_warn_level() {
        assert_eq!(
            event_log_level("unrecognized_lambda_payload"),
            Some(tracing::Level::WARN)
        );
    }

    #[test]
    fn test_completion_logs_at_debug_level() {
        assert_eq!(
            event_log_level("streaming_completion"),
            Some(tracing::Level::DEBUG)
        );
    }

    #[test]
    fn test_api_gateway_has_no_extra_logging() {
        assert_eq!(event_log_level("api_gateway_event"), None);
    }

    // ---- dispatch routing -----------------------------------------------

    #[tokio::test]
    async fn test_handle_completion_does_not_dispatch() {
        let (result, dispatched) = handle_with_probe(load_fixture("completion_success")).await;

        assert!(!dispatched, "Completion events must not dispatch to handler");
        assert_eq!(result.event_type, "streaming_completion");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }

    #[tokio::test]
    async fn test_handle_unrecognized_does_not_dispatch() {
        let (result, dispatched) = handle_with_probe(json!({"foo": "bar"})).await;

        assert!(!dispatched, "Unrecognized events must not dispatch to handler");
        assert_eq!(result.event_type, "unrecognized_lambda_payload");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }

    #[tokio::test]
    async fn test_handle_apigw_v1_dispatches() {
        let (result, dispatched) = handle_with_probe(load_fixture("apigw_v1")).await;

        assert!(dispatched, "API Gateway v1 events must dispatch to handler");
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_handle_apigw_v2_dispatches() {
        let (result, dispatched) = handle_with_probe(load_fixture("apigw_v2")).await;

        assert!(dispatched, "API Gateway v2 events must dispatch to handler");
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_handle_unrecognized_surfaces_distinct_event_type() {
        let (result, _) = handle_with_probe(json!({"unknown": true})).await;

        assert_eq!(result.event_type, "unrecognized_lambda_payload");
    }

    // ---- response conversion --------------------------------------------

    #[test]
    fn test_empty_streaming_response() {
        assert_eq!(empty_streaming_response().status(), 200);
    }

    #[test]
    fn test_into_lambda_stream_response_preserves_metadata() {
        let response = http::Response::builder()
            .status(401)
            .header("WWW-Authenticate", "Bearer realm=\"mcp\"")
            .header("X-Custom", "test")
            .body(boxed_body(bytes::Bytes::from("Unauthorized")))
            .unwrap();

        let stream_resp = into_lambda_stream_response(response);
        let prelude = &stream_resp.metadata_prelude;

        assert_eq!(prelude.status_code, 401);
        assert_eq!(
            prelude.headers.get("WWW-Authenticate").unwrap(),
            "Bearer realm=\"mcp\""
        );
        assert_eq!(prelude.headers.get("X-Custom").unwrap(), "test");
    }

    #[test]
    fn test_into_lambda_stream_response_extracts_cookies() {
        let response = http::Response::builder()
            .status(200)
            .header("Set-Cookie", "session=abc; Path=/")
            .header("Set-Cookie", "theme=dark")
            .body(boxed_body(bytes::Bytes::new()))
            .unwrap();

        let stream_resp = into_lambda_stream_response(response);
        let prelude = &stream_resp.metadata_prelude;

        assert_eq!(prelude.cookies.len(), 2);
        assert!(prelude.cookies.contains(&"session=abc; Path=/".to_string()));
        assert!(prelude.cookies.contains(&"theme=dark".to_string()));
        assert!(prelude.headers.get("Set-Cookie").is_none());
    }
}