Skip to main content

socketioxide/extract/
extensions.rs

1use std::convert::Infallible;
2use std::sync::Arc;
3
4use crate::adapter::Adapter;
5use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts};
6use crate::socket::{DisconnectReason, Socket};
7use socketioxide_core::Value;
8
9#[cfg(feature = "extensions")]
10#[cfg_attr(docsrs, doc(cfg(feature = "extensions")))]
11pub use extensions_extract::*;
12
/// Error returned when an extension of type `T` could not be found.
///
/// This is the rejection produced by the extension extractors of this module when no value
/// of the requested type was inserted into the corresponding extensions map. Its
/// [`Display`](std::fmt::Display) message names the missing type.
//
// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>`: the error never holds
// a `T`, so it should not inherit `T`'s auto traits (`Send`/`Sync`) nor claim drop-check
// ownership of a `T`, while `fn() -> T` keeps the same (covariant) variance in `T`.
pub struct ExtensionNotFound<T>(std::marker::PhantomData<fn() -> T>);

impl<T> std::fmt::Display for ExtensionNotFound<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = std::any::type_name::<T>();
        write!(
            f,
            "Extension of type {} not found, maybe you forgot to insert it in the extensions map?",
            type_name
        )
    }
}

// Hand-written instead of derived so that no `T: Debug` bound is required.
impl<T> std::fmt::Debug for ExtensionNotFound<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ExtensionNotFound {}", std::any::type_name::<T>())
    }
}

impl<T> std::error::Error for ExtensionNotFound<T> {}
31
32fn extract_http_extension<T: Clone + Send + Sync + 'static>(
33    s: &Arc<Socket<impl Adapter>>,
34) -> Result<T, ExtensionNotFound<T>> {
35    s.req_parts()
36        .extensions
37        .get::<T>()
38        .cloned()
39        .ok_or(ExtensionNotFound(std::marker::PhantomData))
40}
41
/// An extractor that returns a clone of an extension stored in the socket's HTTP request parts
/// (`Socket::req_parts`).
///
/// Extraction fails with [`ExtensionNotFound`] when no value of type `T` is present; use
/// [`MaybeHttpExtension`] when the extension may be missing.
pub struct HttpExtension<T>(pub T);

/// An extractor that returns a clone of an extension stored in the socket's HTTP request parts
/// if it exists, or [`None`] otherwise.
///
/// Unlike [`HttpExtension`], this extractor never fails.
pub struct MaybeHttpExtension<T>(pub Option<T>);
47
48impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for HttpExtension<T> {
49    type Error = ExtensionNotFound<T>;
50    fn from_connect_parts(
51        s: &Arc<Socket<A>>,
52        _: &Option<Value>,
53    ) -> Result<Self, ExtensionNotFound<T>> {
54        extract_http_extension(s).map(HttpExtension)
55    }
56}
57
58impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for MaybeHttpExtension<T> {
59    type Error = Infallible;
60    fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Infallible> {
61        Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
62    }
63}
64
65impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for HttpExtension<T> {
66    type Error = ExtensionNotFound<T>;
67    fn from_disconnect_parts(
68        s: &Arc<Socket<A>>,
69        _: DisconnectReason,
70    ) -> Result<Self, ExtensionNotFound<T>> {
71        extract_http_extension(s).map(HttpExtension)
72    }
73}
74impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A>
75    for MaybeHttpExtension<T>
76{
77    type Error = Infallible;
78    fn from_disconnect_parts(s: &Arc<Socket<A>>, _: DisconnectReason) -> Result<Self, Infallible> {
79        Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
80    }
81}
82
83impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for HttpExtension<T> {
84    type Error = ExtensionNotFound<T>;
85    fn from_message_parts(
86        s: &Arc<Socket<A>>,
87        _: &mut Value,
88        _: &Option<i64>,
89    ) -> Result<Self, ExtensionNotFound<T>> {
90        extract_http_extension(s).map(HttpExtension)
91    }
92}
93impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for MaybeHttpExtension<T> {
94    type Error = Infallible;
95    fn from_message_parts(
96        s: &Arc<Socket<A>>,
97        _: &mut Value,
98        _: &Option<i64>,
99    ) -> Result<Self, Infallible> {
100        Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
101    }
102}
103
// NOTE(review): presumably implements `Deref`/`DerefMut` from each wrapper to its inner value
// (`T` for `HttpExtension`, `Option<T>` for `MaybeHttpExtension`); the macro is reached via
// `super::`, so confirm its exact expansion in the parent module.
super::__impl_deref!(HttpExtension);
super::__impl_deref!(MaybeHttpExtension<T>: Option<T>);
106
#[cfg(feature = "extensions")]
mod extensions_extract {
    //! Extractors reading from the socket's own `extensions` field
    //! (compiled only with the `extensions` feature).

    use super::*;

    /// Returns the extension of type `T` stored in the socket's `extensions` field.
    ///
    /// # Errors
    /// Returns [`ExtensionNotFound`] when no value of type `T` is stored.
    fn extract_extension<T: Clone + Send + Sync + 'static>(
        s: &Arc<Socket<impl Adapter>>,
    ) -> Result<T, ExtensionNotFound<T>> {
        // NOTE(review): unlike `http::Extensions::get`, this `get` yields an owned `T`
        // (no `.cloned()` is needed to return `T`), presumably a clone of the stored value;
        // confirm against the socket `Extensions::get` implementation.
        s.extensions
            .get::<T>()
            .ok_or(ExtensionNotFound(std::marker::PhantomData))
    }

    /// An extractor that returns the extension of type `T`.
    ///
    /// If the extension is not found, the handler won't be called and an error log will be
    /// printed if the `tracing` feature is enabled.
    ///
    /// Use [`MaybeExtension`] if the extension you are requesting _may_ not exist.
    pub struct Extension<T>(pub T);

    /// An extractor that returns the extension of type `T` if it exists, or [`None`] otherwise.
    ///
    /// Unlike [`Extension`], this extractor never fails.
    pub struct MaybeExtension<T>(pub Option<T>);

    // Connect-time extraction: a missing extension rejects with `ExtensionNotFound`.
    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for Extension<T> {
        type Error = ExtensionNotFound<T>;
        fn from_connect_parts(
            s: &Arc<Socket<A>>,
            _: &Option<Value>,
        ) -> Result<Self, ExtensionNotFound<T>> {
            extract_extension(s).map(Extension)
        }
    }
    // Connect-time extraction: a missing extension becomes `None`.
    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for MaybeExtension<T> {
        type Error = Infallible;
        fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Infallible> {
            Ok(MaybeExtension(extract_extension(s).ok()))
        }
    }
    // Disconnect-time extraction: a missing extension rejects with `ExtensionNotFound`.
    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for Extension<T> {
        type Error = ExtensionNotFound<T>;
        fn from_disconnect_parts(
            s: &Arc<Socket<A>>,
            _: DisconnectReason,
        ) -> Result<Self, ExtensionNotFound<T>> {
            extract_extension(s).map(Extension)
        }
    }
    // Disconnect-time extraction: a missing extension becomes `None`.
    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for MaybeExtension<T> {
        type Error = Infallible;
        fn from_disconnect_parts(
            s: &Arc<Socket<A>>,
            _: DisconnectReason,
        ) -> Result<Self, Infallible> {
            Ok(MaybeExtension(extract_extension(s).ok()))
        }
    }
    // Message-time extraction: a missing extension rejects with `ExtensionNotFound`.
    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for Extension<T> {
        type Error = ExtensionNotFound<T>;
        fn from_message_parts(
            s: &Arc<Socket<A>>,
            _: &mut Value,
            _: &Option<i64>,
        ) -> Result<Self, ExtensionNotFound<T>> {
            extract_extension(s).map(Extension)
        }
    }
    // Message-time extraction: a missing extension becomes `None`.
    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for MaybeExtension<T> {
        type Error = Infallible;
        fn from_message_parts(
            s: &Arc<Socket<A>>,
            _: &mut Value,
            _: &Option<i64>,
        ) -> Result<Self, Infallible> {
            Ok(MaybeExtension(extract_extension(s).ok()))
        }
    }
    // NOTE(review): presumably implements `Deref`/`DerefMut` to the inner value; the macro
    // lives two modules up (`super::super`), confirm its expansion there.
    super::super::__impl_deref!(Extension);
    super::super::__impl_deref!(MaybeExtension<T>: Option<T>);
}