1pub mod adapter;
63pub mod builder;
64pub mod error;
65pub mod handler;
66pub mod prelude;
67pub mod server;
68
69#[cfg(feature = "cors")]
70pub mod cors;
71
72#[cfg(feature = "sse")]
73pub mod streaming;
74
75pub use builder::LambdaMcpServerBuilder;
78pub use error::{LambdaError, Result};
80pub use handler::LambdaMcpHandler;
82pub use server::LambdaMcpServer;
84
85#[cfg(feature = "cors")]
86pub use cors::CorsConfig;
87
/// Outcome of inspecting a raw Lambda runtime payload before anything is dispatched.
#[derive(Debug)]
enum RuntimeEventClassification {
    /// The payload deserialized as a `lambda_http` request and is forwarded to the
    /// HTTP dispatcher. Boxed — presumably to keep this enum small next to the unit
    /// variants (cf. clippy `large_enum_variant`); confirm if the size matters.
    ApiGatewayEvent(Box<lambda_http::request::LambdaRequest>),
    /// The payload has a top-level `invokeCompletionStatus` key; the caller
    /// acknowledges it with an empty response and does not dispatch it.
    StreamingCompletion,
    /// The payload matched neither shape; the caller acknowledges it with an empty
    /// response and does not dispatch it.
    UnrecognizedEvent,
}
105
106fn classify_runtime_event(payload: serde_json::Value) -> RuntimeEventClassification {
131 if let Ok(request) =
133 serde_json::from_value::<lambda_http::request::LambdaRequest>(payload.clone())
134 {
135 return RuntimeEventClassification::ApiGatewayEvent(Box::new(request));
136 }
137
138 if payload.get("invokeCompletionStatus").is_some() {
140 return RuntimeEventClassification::StreamingCompletion;
141 }
142
143 RuntimeEventClassification::UnrecognizedEvent
145}
146
/// Boxed HTTP body type returned by dispatchers and by `empty_streaming_response`.
type StreamBody = http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, hyper::Error>;
/// Streaming response handed back to `lambda_runtime`, carrying a [`StreamBody`]
/// wrapped as a data stream.
type StreamResult = lambda_runtime::StreamResponse<http_body_util::BodyDataStream<StreamBody>>;

/// Result of routing one runtime payload through `handle_runtime_payload`.
struct HandleResult {
    /// Response to stream back to the Lambda runtime.
    response: StreamResult,
    /// Tag naming how the payload was handled: `"api_gateway_event"`,
    /// `"streaming_completion"`, or `"unrecognized_lambda_payload"`.
    event_type: &'static str,
}
158
159async fn handle_runtime_payload<F, Fut>(
166 payload: serde_json::Value,
167 context: lambda_runtime::Context,
168 dispatch: F,
169) -> std::result::Result<HandleResult, lambda_http::Error>
170where
171 F: FnOnce(lambda_http::Request) -> Fut,
172 Fut: std::future::Future<
173 Output = std::result::Result<http::Response<StreamBody>, lambda_http::Error>,
174 >,
175{
176 match classify_runtime_event(payload) {
177 RuntimeEventClassification::ApiGatewayEvent(lambda_request) => {
178 use lambda_http::RequestExt;
179 let request: lambda_http::Request = (*lambda_request).into();
180 let request = request.with_lambda_context(context);
181 let response = dispatch(request).await?;
182 Ok(HandleResult {
183 response: into_lambda_stream_response(response),
184 event_type: "api_gateway_event",
185 })
186 }
187 RuntimeEventClassification::StreamingCompletion => Ok(HandleResult {
188 response: into_lambda_stream_response(empty_streaming_response()),
189 event_type: "streaming_completion",
190 }),
191 RuntimeEventClassification::UnrecognizedEvent => Ok(HandleResult {
192 response: into_lambda_stream_response(empty_streaming_response()),
193 event_type: "unrecognized_lambda_payload",
194 }),
195 }
196}
197
198fn event_log_level(event_type: &str) -> Option<tracing::Level> {
204 match event_type {
205 "streaming_completion" => Some(tracing::Level::DEBUG),
206 "unrecognized_lambda_payload" => Some(tracing::Level::WARN),
207 _ => None,
208 }
209}
210
211pub async fn run_streaming(
242 handler: LambdaMcpHandler,
243) -> std::result::Result<(), lambda_http::Error> {
244 use lambda_runtime::{LambdaEvent, service_fn};
245
246 lambda_runtime::run(service_fn(move |event: LambdaEvent<serde_json::Value>| {
247 let handler = handler.clone();
248 async move {
249 let result = handle_runtime_payload(event.payload, event.context, |req| {
250 handler.handle_streaming(req)
251 })
252 .await?;
253
254 match event_log_level(result.event_type) {
255 Some(level) if level == tracing::Level::WARN => {
256 tracing::warn!(
257 event_type = result.event_type,
258 "Received unrecognized Lambda invocation payload"
259 );
260 }
261 Some(_) => {
262 tracing::debug!(
263 event_type = result.event_type,
264 "Acknowledging streaming completion"
265 );
266 }
267 None => {}
268 }
269
270 Ok::<_, lambda_http::Error>(result.response)
271 }
272 }))
273 .await
274}
275
276pub async fn run_streaming_with<F, Fut>(dispatch: F) -> std::result::Result<(), lambda_http::Error>
298where
299 F: Fn(lambda_http::Request) -> Fut + Clone + Send + 'static,
300 Fut: std::future::Future<
301 Output = std::result::Result<http::Response<StreamBody>, lambda_http::Error>,
302 > + Send,
303{
304 use lambda_runtime::{LambdaEvent, service_fn};
305
306 lambda_runtime::run(service_fn(move |event: LambdaEvent<serde_json::Value>| {
307 let dispatch = dispatch.clone();
308 async move {
309 let result = handle_runtime_payload(event.payload, event.context, dispatch).await?;
310
311 match event_log_level(result.event_type) {
312 Some(level) if level == tracing::Level::WARN => {
313 tracing::warn!(
314 event_type = result.event_type,
315 "Received unrecognized Lambda invocation payload"
316 );
317 }
318 Some(_) => {
319 tracing::debug!(
320 event_type = result.event_type,
321 "Acknowledging streaming completion"
322 );
323 }
324 None => {}
325 }
326
327 Ok::<_, lambda_http::Error>(result.response)
328 }
329 }))
330 .await
331}
332
333fn into_lambda_stream_response<B>(
339 response: http::Response<B>,
340) -> lambda_runtime::StreamResponse<http_body_util::BodyDataStream<B>>
341where
342 B: http_body::Body + Unpin + Send + 'static,
343{
344 let (parts, body) = response.into_parts();
345 let mut headers = parts.headers;
346
347 let cookies = headers
349 .get_all(http::header::SET_COOKIE)
350 .iter()
351 .map(|c| String::from_utf8_lossy(c.as_bytes()).to_string())
352 .collect::<Vec<_>>();
353 headers.remove(http::header::SET_COOKIE);
354
355 lambda_runtime::StreamResponse {
356 metadata_prelude: lambda_runtime::MetadataPrelude {
357 headers,
358 status_code: parts.status,
359 cookies,
360 },
361 stream: http_body_util::BodyDataStream::new(body),
362 }
363}
364
365fn empty_streaming_response()
367-> http::Response<http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, hyper::Error>> {
368 use http_body_util::{BodyExt, Full};
369 let body = Full::new(bytes::Bytes::new())
370 .map_err(|e: std::convert::Infallible| match e {})
371 .boxed_unsync();
372 http::Response::builder().status(200).body(body).unwrap()
373}
374
#[cfg(test)]
mod streaming_completion_tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Loads and parses one of the JSON event fixtures stored under `fixtures/`.
    fn load_fixture(name: &str) -> serde_json::Value {
        let json_str = match name {
            "apigw_v1" => include_str!("fixtures/apigw_v1_proxy_event.json"),
            "apigw_v2" => include_str!("fixtures/apigw_v2_http_api_event.json"),
            "completion_success" => include_str!("fixtures/streaming_completion_success.json"),
            "completion_failure" => include_str!("fixtures/streaming_completion_failure.json"),
            "completion_extra" => include_str!("fixtures/streaming_completion_extra_fields.json"),
            "completion_api_like" => {
                include_str!("fixtures/completion_with_api_like_fields.json")
            }
            other => panic!("Unknown fixture: {other}"),
        };
        serde_json::from_str(json_str).unwrap_or_else(|e| panic!("Bad fixture {name}: {e}"))
    }

    /// Runs `handle_runtime_payload` with a dispatcher that only records that it
    /// was called, and returns the handle result plus whether dispatch happened.
    ///
    /// The dispatcher clones its flag per call, so it is `Fn + Clone + Send` —
    /// the same shape `run_streaming_with` requires.
    async fn handle_and_track(payload: serde_json::Value) -> (HandleResult, bool) {
        let dispatched = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&dispatched);
        let dispatch = move |_req: lambda_http::Request| {
            let flag = Arc::clone(&flag);
            async move {
                flag.store(true, Ordering::SeqCst);
                Ok(empty_streaming_response())
            }
        };

        let result = handle_runtime_payload(payload, lambda_runtime::Context::default(), dispatch)
            .await
            .expect("handle should succeed");

        (result, dispatched.load(Ordering::SeqCst))
    }

    /// Wraps `bytes` in the boxed body type that dispatchers return.
    fn boxed_body(bytes: bytes::Bytes) -> StreamBody {
        use http_body_util::{BodyExt, Full};
        Full::new(bytes)
            .map_err(|e: std::convert::Infallible| match e {})
            .boxed_unsync()
    }

    // --- classify_runtime_event -------------------------------------------

    #[test]
    fn test_classify_api_gateway_v1_event() {
        let payload = load_fixture("apigw_v1");
        assert!(
            matches!(
                classify_runtime_event(payload),
                RuntimeEventClassification::ApiGatewayEvent(_)
            ),
            "API Gateway v1 proxy event must classify as ApiGatewayEvent"
        );
    }

    #[test]
    fn test_classify_api_gateway_v2_event() {
        let payload = load_fixture("apigw_v2");
        assert!(
            matches!(
                classify_runtime_event(payload),
                RuntimeEventClassification::ApiGatewayEvent(_)
            ),
            "API Gateway v2 HTTP API event must classify as ApiGatewayEvent"
        );
    }

    #[test]
    fn test_classify_streaming_completion() {
        let payload = load_fixture("completion_success");
        assert!(matches!(
            classify_runtime_event(payload),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_failure_status() {
        let payload = load_fixture("completion_failure");
        assert!(matches!(
            classify_runtime_event(payload),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_extra_fields() {
        let payload = load_fixture("completion_extra");
        assert!(matches!(
            classify_runtime_event(payload),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_completion_with_api_like_fields() {
        let payload = load_fixture("completion_api_like");
        assert!(matches!(
            classify_runtime_event(payload),
            RuntimeEventClassification::StreamingCompletion
        ));
    }

    #[test]
    fn test_classify_empty_object() {
        assert!(matches!(
            classify_runtime_event(json!({})),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_random_object() {
        assert!(matches!(
            classify_runtime_event(json!({"foo": "bar", "baz": 123})),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_null_payload() {
        assert!(matches!(
            classify_runtime_event(json!(null)),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_string_payload() {
        assert!(matches!(
            classify_runtime_event(json!("hello")),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_array_payload() {
        assert!(matches!(
            classify_runtime_event(json!([1, 2, 3])),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_nested_invoke_status() {
        // Only a *top-level* invokeCompletionStatus marks a completion.
        let payload = json!({
            "data": {"invokeCompletionStatus": "Success"}
        });
        assert!(matches!(
            classify_runtime_event(payload),
            RuntimeEventClassification::UnrecognizedEvent
        ));
    }

    #[test]
    fn test_classify_never_panics_on_arbitrary_json() {
        let payloads = [
            json!(null),
            json!(true),
            json!(false),
            json!(42),
            json!(-1.5),
            json!(""),
            json!("some string"),
            json!([]),
            json!([1, "two", null, false]),
            json!({}),
            json!({"a": 1}),
            json!({"requestContext": null}),
            json!({"requestContext": "not-an-object"}),
            json!({"httpMethod": "POST"}),
            json!({"version": "2.0"}),
            json!({"version": "2.0", "routeKey": "GET /"}),
            json!({"resource": "/", "httpMethod": "GET"}),
            json!({"deeply": {"nested": {"invokeCompletionStatus": "Success"}}}),
            serde_json::Value::Object((0..100).map(|i| (format!("key_{i}"), json!(i))).collect()),
        ];

        for payload in payloads {
            let _result = classify_runtime_event(payload);
        }
    }

    #[test]
    fn test_classify_invoke_completion_status_always_wins() {
        let payloads = [
            json!({"invokeCompletionStatus": "Success"}),
            json!({"invokeCompletionStatus": "Failure"}),
            json!({"invokeCompletionStatus": "Unknown"}),
            json!({"invokeCompletionStatus": 42}),
            json!({"invokeCompletionStatus": null}),
            json!({"invokeCompletionStatus": "Success", "requestId": "abc-123"}),
            json!({"invokeCompletionStatus": "Success", "extra": "field", "nested": {"a": 1}}),
        ];

        for payload in payloads {
            let result = classify_runtime_event(payload.clone());
            assert!(
                matches!(result, RuntimeEventClassification::StreamingCompletion),
                "Payload with top-level invokeCompletionStatus must be StreamingCompletion: {payload}"
            );
        }
    }

    // --- event_log_level ---------------------------------------------------

    #[test]
    fn test_unrecognized_logs_at_warn_level() {
        assert_eq!(
            event_log_level("unrecognized_lambda_payload"),
            Some(tracing::Level::WARN)
        );
    }

    #[test]
    fn test_completion_logs_at_debug_level() {
        assert_eq!(
            event_log_level("streaming_completion"),
            Some(tracing::Level::DEBUG)
        );
    }

    #[test]
    fn test_api_gateway_has_no_extra_logging() {
        assert_eq!(event_log_level("api_gateway_event"), None);
    }

    // --- handle_runtime_payload --------------------------------------------

    #[tokio::test]
    async fn test_handle_completion_does_not_dispatch() {
        let (result, dispatched) = handle_and_track(load_fixture("completion_success")).await;

        assert!(!dispatched, "Completion events must not dispatch to handler");
        assert_eq!(result.event_type, "streaming_completion");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }

    #[tokio::test]
    async fn test_handle_unrecognized_does_not_dispatch() {
        let (result, dispatched) = handle_and_track(json!({"foo": "bar"})).await;

        assert!(!dispatched, "Unrecognized events must not dispatch to handler");
        assert_eq!(result.event_type, "unrecognized_lambda_payload");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }

    #[tokio::test]
    async fn test_handle_apigw_v1_dispatches() {
        let (result, dispatched) = handle_and_track(load_fixture("apigw_v1")).await;

        assert!(dispatched, "API Gateway v1 events must dispatch to handler");
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_handle_apigw_v2_dispatches() {
        let (result, dispatched) = handle_and_track(load_fixture("apigw_v2")).await;

        assert!(dispatched, "API Gateway v2 events must dispatch to handler");
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_handle_unrecognized_surfaces_distinct_event_type() {
        let (result, _) = handle_and_track(json!({"unknown": true})).await;

        assert_eq!(result.event_type, "unrecognized_lambda_payload");
    }

    // --- response helpers --------------------------------------------------

    #[test]
    fn test_empty_streaming_response() {
        let resp = empty_streaming_response();
        assert_eq!(resp.status(), 200);
    }

    #[test]
    fn test_into_lambda_stream_response_preserves_metadata() {
        let response = http::Response::builder()
            .status(401)
            .header("WWW-Authenticate", "Bearer realm=\"mcp\"")
            .header("X-Custom", "test")
            .body(boxed_body(bytes::Bytes::from("Unauthorized")))
            .unwrap();

        let stream_resp = into_lambda_stream_response(response);
        let headers = &stream_resp.metadata_prelude.headers;

        assert_eq!(stream_resp.metadata_prelude.status_code, 401);
        assert_eq!(
            headers.get("WWW-Authenticate").unwrap(),
            "Bearer realm=\"mcp\""
        );
        assert_eq!(headers.get("X-Custom").unwrap(), "test");
    }

    #[test]
    fn test_into_lambda_stream_response_extracts_cookies() {
        let response = http::Response::builder()
            .status(200)
            .header("Set-Cookie", "session=abc; Path=/")
            .header("Set-Cookie", "theme=dark")
            .body(boxed_body(bytes::Bytes::new()))
            .unwrap();

        let stream_resp = into_lambda_stream_response(response);
        let cookies = &stream_resp.metadata_prelude.cookies;

        assert_eq!(cookies.len(), 2);
        assert!(cookies.contains(&"session=abc; Path=/".to_string()));
        assert!(cookies.contains(&"theme=dark".to_string()));
        assert!(
            stream_resp
                .metadata_prelude
                .headers
                .get("Set-Cookie")
                .is_none()
        );
    }

    // --- run_streaming_with routing ------------------------------------------
    //
    // These do not start the Lambda runtime loop; they drive
    // `handle_runtime_payload` (the routing `run_streaming_with` delegates to)
    // with a dispatcher shaped like `run_streaming_with`'s `Fn + Clone + Send`
    // bound.

    #[tokio::test]
    async fn test_run_streaming_with_dispatches_apigw_events() {
        let (result, dispatched) = handle_and_track(load_fixture("apigw_v1")).await;

        assert!(
            dispatched,
            "run_streaming_with dispatch must be called for API Gateway events"
        );
        assert_eq!(result.event_type, "api_gateway_event");
    }

    #[tokio::test]
    async fn test_run_streaming_with_acks_completion_without_dispatch() {
        let (result, dispatched) = handle_and_track(load_fixture("completion_success")).await;

        assert!(
            !dispatched,
            "run_streaming_with dispatch must NOT be called for completion events"
        );
        assert_eq!(result.event_type, "streaming_completion");
        assert_eq!(result.response.metadata_prelude.status_code, 200);
    }
}