/// Glue for running a websocket handshake from inside a [`poem`] handler.
#[cfg(feature = "poem")]
pub mod poem_ext {
    use crate::errors::WsError;
    use http;
    use poem::Body;
    use std::future::Future;

    /// Converts an `http::Response<T>` into a `poem::Response`.
    ///
    /// Status, version, headers and extensions are moved over unchanged; the
    /// body is converted with `Into<Body>`.
    fn convert<T: Into<Body>>(resp: http::Response<T>) -> poem::Response {
        let (parts, body) = resp.into_parts();
        poem::Response::from_parts(
            poem::ResponseParts {
                status: parts.status,
                version: parts.version,
                headers: parts.headers,
                extensions: parts.extensions,
            },
            body.into(),
        )
    }

    /// Copies the head of a poem request into a body-less `http::Request`.
    ///
    /// Method, URI, HTTP version and every header (including repeated header
    /// names) are copied. Poem request extensions are not copied.
    fn request_head(req: &poem::Request) -> http::Request<()> {
        // Assigning through the `*_mut` accessors avoids the fallible
        // `http::request::Builder`, so no `unwrap` is needed.
        let mut head = http::Request::new(());
        *head.method_mut() = req.method().clone();
        *head.uri_mut() = req.uri().clone();
        *head.version_mut() = req.version();
        *head.headers_mut() = req.headers().clone();
        head
    }

    /// Runs a websocket handshake for a poem request and, on success, hands the
    /// upgraded connection to `callback` on a spawned tokio task.
    ///
    /// Flow:
    /// 1. Takes the pending upgrade from `req`. If that fails, the error is
    ///    logged and an empty `400 Bad Request` (HTTP/1.1) response is returned.
    /// 2. Passes the request head to `handshake_handler` as an
    ///    `http::Request<()>`. The head contains method, URI, HTTP version and
    ///    headers.
    /// 3. If the handler returns `Err((resp, e))`, `e` is logged and `resp` is
    ///    returned as-is.
    /// 4. Otherwise, a task is spawned that awaits the upgrade and then awaits
    ///    `callback(req, upgraded)`. Here `req` is the request returned by the
    ///    handler, and the callback's output is discarded. If the upgrade future
    ///    fails, the error is logged and `callback` is not called. The handler's
    ///    response is returned immediately.
    pub async fn adapt<T, F1, F2, Fut>(
        req: &poem::Request,
        mut handshake_handler: F1,
        callback: F2,
    ) -> poem::Response
    where
        F1: FnMut(
            http::Request<()>,
        ) -> Result<(http::Request<()>, http::Response<T>), (http::Response<T>, WsError)>,
        F2: FnOnce(http::Request<()>, poem::Upgraded) -> Fut + Send + Sync + 'static,
        Fut: Future + Send + 'static,
        T: Into<Body> + std::fmt::Debug,
    {
        let on_upgrade = match req.take_upgrade() {
            Ok(on_upgrade) => on_upgrade,
            Err(e) => {
                tracing::error!("http upgrade failed {e}");
                return poem::Response::builder()
                    .version(http::Version::HTTP_11)
                    .status(http::StatusCode::BAD_REQUEST)
                    .body(());
            }
        };

        let (req, resp) = match handshake_handler(request_head(req)) {
            Ok(pair) => pair,
            Err((resp, e)) => {
                tracing::error!("handshake error {e}");
                return convert(resp);
            }
        };

        // The upgrade only completes after the handshake response has been
        // sent, so it must be awaited off the request path.
        tokio::spawn(async move {
            match on_upgrade.await {
                Ok(upgraded) => {
                    callback(req, upgraded).await;
                }
                Err(e) => {
                    tracing::error!("http upgrade failed {e}");
                }
            }
        });
        convert(resp)
    }
}
74
/// Glue for running a websocket handshake from inside an [`axum`] handler.
#[cfg(feature = "axum")]
pub mod axum_ext {
    use http;
    use std::future::Future;

    use axum::{body::Body, response::Response};

    use crate::errors::WsError;

    /// Converts the handshake handler's `http::Response<T>` into an axum
    /// [`Response`]. The response head is kept unchanged and the body is
    /// converted with `Into<Body>`.
    fn convert<T: Into<Body>>(resp: http::Response<T>) -> Response {
        resp.map(|body| body.into())
    }

    /// Builds the empty `400 Bad Request` (HTTP/1.1) response that is returned
    /// when the request carries no pending upgrade.
    fn bad_request() -> Response {
        let mut resp = Response::new(Body::empty());
        *resp.version_mut() = axum::http::Version::HTTP_11;
        *resp.status_mut() = axum::http::StatusCode::BAD_REQUEST;
        resp
    }

    /// Runs a websocket handshake for an axum request and, on success, hands the
    /// upgraded connection to `callback` on a spawned tokio task.
    ///
    /// Flow:
    /// 1. Splits the request and removes the `hyper::upgrade::OnUpgrade`
    ///    extension. The request body is dropped. If the extension is missing,
    ///    an error is logged and an empty `400 Bad Request` (HTTP/1.1) is
    ///    returned.
    /// 2. Passes the remaining request head (all other parts and extensions
    ///    intact) to `handshake_handler` as an `http::Request<()>`.
    /// 3. If the handler returns `Err((resp, e))`, `e` is logged and `resp` is
    ///    returned.
    /// 4. Otherwise, a task is spawned that awaits the upgrade, wraps it in
    ///    `TokioIo`, and awaits `callback(req, io)`. Here `req` is the request
    ///    returned by the handler, and the callback's output is discarded. If
    ///    the upgrade future fails, the error is logged and `callback` is not
    ///    called. The handler's response is returned immediately.
    pub async fn adapt<T, F1, F2, Fut>(
        req: axum::extract::Request,
        mut handshake_handler: F1,
        callback: F2,
    ) -> Response
    where
        F1: FnMut(
            http::Request<()>,
        ) -> Result<(http::Request<()>, http::Response<T>), (http::Response<T>, WsError)>,
        F2: FnOnce(http::Request<()>, hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future + Send + 'static,
        T: std::fmt::Debug + Into<Body>,
    {
        let (mut parts, _) = req.into_parts();
        let on_upgrade = match parts.extensions.remove::<hyper::upgrade::OnUpgrade>() {
            Some(on_upgrade) => on_upgrade,
            None => {
                tracing::error!("upgraded failed");
                return bad_request();
            }
        };

        let req = axum::http::Request::from_parts(parts, ());
        let (req, resp) = match handshake_handler(req) {
            Ok(pair) => pair,
            Err((resp, e)) => {
                tracing::error!("handshake error {e}");
                return convert(resp);
            }
        };

        // The upgrade only completes after the handshake response has been
        // sent, so it must be awaited off the request path.
        tokio::spawn(async move {
            match on_upgrade.await {
                Ok(upgraded) => {
                    callback(req, hyper_util::rt::TokioIo::new(upgraded)).await;
                }
                Err(e) => {
                    tracing::error!("http upgrade failed {e}");
                }
            }
        });
        convert(resp)
    }
}