// ws_tool/extension.rs

/// poem websocket extension
#[cfg(feature = "poem")]
pub mod poem_ext {
    use crate::errors::WsError;
    use http;
    use poem::Body;
    use std::future::Future;

    /// Convert an `http::Response` produced by a handshake handler into a
    /// [`poem::Response`], carrying over status, version, headers and
    /// extensions unchanged.
    fn convert<T: Into<Body>>(resp: http::Response<T>) -> poem::Response {
        let (parts, body) = resp.into_parts();
        poem::Response::from_parts(
            poem::ResponseParts {
                status: parts.status,
                version: parts.version,
                headers: parts.headers,
                extensions: parts.extensions,
            },
            body.into(),
        )
    }

    /// Copy the method, URI, HTTP version and headers of a poem request into
    /// a body-less `http::Request` suitable for a handshake handler.
    ///
    /// The request is assembled field by field rather than through
    /// `http::request::Builder`. That leaves no fallible build step to
    /// `unwrap`, and the client's HTTP version is preserved instead of
    /// silently defaulting to HTTP/1.1. Cloning the `HeaderMap` keeps
    /// repeated header names, just as appending them one by one would.
    fn to_http_request(req: &poem::Request) -> http::Request<()> {
        let mut http_req = http::Request::new(());
        *http_req.method_mut() = req.method().clone();
        *http_req.uri_mut() = req.uri().clone();
        *http_req.version_mut() = req.version();
        *http_req.headers_mut() = req.headers().clone();
        http_req
    }

    /// accept poem raw request
    ///
    /// Takes the upgrade handle from `req` and passes a copy of the request
    /// head to `handshake_handler`. If the handshake is accepted, `callback`
    /// is scheduled on a spawned tokio task and receives the handler's
    /// request together with the upgraded connection.
    ///
    /// Returns:
    /// - an empty `400 Bad Request` (HTTP/1.1) if the request cannot be
    ///   upgraded;
    /// - the handler's error response if the handshake is rejected;
    /// - otherwise the handler's handshake response.
    ///
    /// If the upgrade later fails inside the spawned task, the error is
    /// logged and `callback` is never invoked.
    pub async fn adapt<T, F1, F2, Fut>(
        req: &poem::Request,
        mut handshake_handler: F1,
        callback: F2,
    ) -> poem::Response
    where
        F1: FnMut(
            http::Request<()>,
        )
            -> Result<(http::Request<()>, http::Response<T>), (http::Response<T>, WsError)>,
        F2: FnOnce(http::Request<()>, poem::Upgraded) -> Fut + Send + Sync + 'static,
        Fut: Future + Send + 'static,
        T: Into<Body> + std::fmt::Debug,
    {
        let on_upgrade = match req.take_upgrade() {
            Ok(on_upgrade) => on_upgrade,
            Err(e) => {
                tracing::error!("http upgrade failed {e}");
                return poem::Response::builder()
                    .version(http::Version::HTTP_11)
                    .status(http::StatusCode::BAD_REQUEST)
                    .body(());
            }
        };

        let (req, resp) = match handshake_handler(to_http_request(req)) {
            Ok(accepted) => accepted,
            Err((resp, e)) => {
                tracing::error!("handshake error {e}");
                return convert(resp);
            }
        };

        // The upgraded connection is presumably only available after the
        // response returned below has been written, so it is awaited on a
        // separate task rather than here.
        tokio::spawn(async move {
            match on_upgrade.await {
                Ok(upgraded) => {
                    callback(req, upgraded).await;
                }
                Err(e) => tracing::error!("http upgrade failed {e}"),
            }
        });
        convert(resp)
    }
}
74
/// axum websocket extension
#[cfg(feature = "axum")]
pub mod axum_ext {
    use http;
    use std::future::Future;

    use axum::{body::Body, response::Response};

    use crate::errors::WsError;

    /// Convert an `http::Response` produced by a handshake handler into an
    /// axum [`Response`]. Only the body is converted; status, version,
    /// headers and extensions are kept as-is.
    fn convert<T: Into<Body>>(resp: http::Response<T>) -> Response {
        resp.map(Into::into)
    }

    /// Empty `400 Bad Request` (HTTP/1.1) returned when the request carries
    /// no upgrade handle. It is built directly, so there is no fallible
    /// builder step to unwrap.
    fn bad_request() -> Response {
        let mut resp = Response::new(Body::empty());
        *resp.status_mut() = axum::http::StatusCode::BAD_REQUEST;
        *resp.version_mut() = axum::http::Version::HTTP_11;
        resp
    }

    /// accept axum raw request
    ///
    /// Removes the `hyper::upgrade::OnUpgrade` extension from `req` and
    /// discards the request body. The remaining request head (method, URI,
    /// version, headers, other extensions) is passed to `handshake_handler`.
    /// If the handshake is accepted, `callback` is scheduled on a spawned
    /// tokio task. It receives the handler's request and the upgraded
    /// connection wrapped in `hyper_util::rt::TokioIo`.
    ///
    /// Returns:
    /// - an empty `400 Bad Request` (HTTP/1.1) if the request has no upgrade
    ///   handle;
    /// - the handler's error response if the handshake is rejected;
    /// - otherwise the handler's handshake response.
    ///
    /// If the upgrade later fails inside the spawned task, the error is
    /// logged and `callback` is never invoked.
    pub async fn adapt<T, F1, F2, Fut>(
        req: axum::extract::Request,
        mut handshake_handler: F1,
        callback: F2,
    ) -> Response
    where
        F1: FnMut(
            http::Request<()>,
        )
            -> Result<(http::Request<()>, http::Response<T>), (http::Response<T>, WsError)>,
        F2: FnOnce(http::Request<()>, hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future + Send + 'static,
        T: std::fmt::Debug + Into<Body>,
    {
        let (mut parts, _) = req.into_parts();
        let on_upgrade = match parts.extensions.remove::<hyper::upgrade::OnUpgrade>() {
            Some(on_upgrade) => on_upgrade,
            None => {
                tracing::error!("upgraded failed");
                return bad_request();
            }
        };

        let req = axum::http::Request::from_parts(parts, ());
        let (req, resp) = match handshake_handler(req) {
            Ok(accepted) => accepted,
            Err((resp, e)) => {
                tracing::error!("handshake error {e}");
                return convert(resp);
            }
        };

        // The upgraded connection is presumably only available after the
        // response returned below has been written, so it is awaited on a
        // separate task rather than here.
        tokio::spawn(async move {
            match on_upgrade.await {
                Ok(upgraded) => {
                    callback(req, hyper_util::rt::TokioIo::new(upgraded)).await;
                }
                Err(e) => tracing::error!("http upgrade failed {e}"),
            }
        });
        convert(resp)
    }
}