socketioxide/extract/
extensions.rs1use std::convert::Infallible;
2use std::sync::Arc;
3
4use crate::adapter::Adapter;
5use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts};
6use crate::socket::{DisconnectReason, Socket};
7use socketioxide_core::Value;
8
9#[cfg(feature = "extensions")]
10#[cfg_attr(docsrs, doc(cfg(feature = "extensions")))]
11pub use extensions_extract::*;
12
/// Error produced by the extension extractors when no value of type `T` is stored.
///
/// Its `Display` message names the missing type (via [`std::any::type_name`]) to help
/// diagnose a forgotten insertion.
///
/// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>`: the error never
/// holds a `T`, so its auto traits (`Send`, `Sync`, `Unpin`, ...) should not depend on `T`.
pub struct ExtensionNotFound<T>(std::marker::PhantomData<fn() -> T>);

impl<T> std::fmt::Display for ExtensionNotFound<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Extension of type {} not found, maybe you forgot to insert it in the extensions map?",
            std::any::type_name::<T>()
        )
    }
}

// Hand-written instead of derived so that `T` is not required to implement `Debug`.
impl<T> std::fmt::Debug for ExtensionNotFound<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ExtensionNotFound {}", std::any::type_name::<T>())
    }
}

impl<T> std::error::Error for ExtensionNotFound<T> {}

32fn extract_http_extension<T: Clone + Send + Sync + 'static>(
33 s: &Arc<Socket<impl Adapter>>,
34) -> Result<T, ExtensionNotFound<T>> {
35 s.req_parts()
36 .extensions
37 .get::<T>()
38 .cloned()
39 .ok_or(ExtensionNotFound(std::marker::PhantomData))
40}
41
/// Extractor yielding a clone of the `T` stored in the socket's HTTP request extensions.
///
/// Extraction fails with [`ExtensionNotFound`] when no value of type `T` is present;
/// use [`MaybeHttpExtension`] when the value may legitimately be absent.
#[derive(Debug, Clone)]
pub struct HttpExtension<T>(pub T);

/// Extractor yielding `Some` clone of the `T` stored in the socket's HTTP request
/// extensions, or `None` when it is absent.
///
/// Extraction never fails (its error type is [`Infallible`]).
#[derive(Debug, Clone)]
pub struct MaybeHttpExtension<T>(pub Option<T>);

48impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for HttpExtension<T> {
49 type Error = ExtensionNotFound<T>;
50 fn from_connect_parts(
51 s: &Arc<Socket<A>>,
52 _: &Option<Value>,
53 ) -> Result<Self, ExtensionNotFound<T>> {
54 extract_http_extension(s).map(HttpExtension)
55 }
56}
57
58impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for MaybeHttpExtension<T> {
59 type Error = Infallible;
60 fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Infallible> {
61 Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
62 }
63}
64
65impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for HttpExtension<T> {
66 type Error = ExtensionNotFound<T>;
67 fn from_disconnect_parts(
68 s: &Arc<Socket<A>>,
69 _: DisconnectReason,
70 ) -> Result<Self, ExtensionNotFound<T>> {
71 extract_http_extension(s).map(HttpExtension)
72 }
73}
74impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A>
75 for MaybeHttpExtension<T>
76{
77 type Error = Infallible;
78 fn from_disconnect_parts(s: &Arc<Socket<A>>, _: DisconnectReason) -> Result<Self, Infallible> {
79 Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
80 }
81}
82
83impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for HttpExtension<T> {
84 type Error = ExtensionNotFound<T>;
85 fn from_message_parts(
86 s: &Arc<Socket<A>>,
87 _: &mut Value,
88 _: &Option<i64>,
89 ) -> Result<Self, ExtensionNotFound<T>> {
90 extract_http_extension(s).map(HttpExtension)
91 }
92}
93impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for MaybeHttpExtension<T> {
94 type Error = Infallible;
95 fn from_message_parts(
96 s: &Arc<Socket<A>>,
97 _: &mut Value,
98 _: &Option<i64>,
99 ) -> Result<Self, Infallible> {
100 Ok(MaybeHttpExtension(extract_http_extension(s).ok()))
101 }
102}
103
104super::__impl_deref!(HttpExtension);
105super::__impl_deref!(MaybeHttpExtension<T>: Option<T>);
106
#[cfg(feature = "extensions")]
mod extensions_extract {
    use super::*;

    /// Fetches the `T` value stored in the socket's own `extensions` map.
    ///
    /// # Errors
    /// Returns [`ExtensionNotFound`] when no value of type `T` is stored there.
    fn extract_extension<T: Clone + Send + Sync + 'static>(
        s: &Arc<Socket<impl Adapter>>,
    ) -> Result<T, ExtensionNotFound<T>> {
        // `get` already yields an owned `T` here (its `Option` is turned straight into
        // `Result<T, _>`), so no extra `.cloned()` step is needed.
        s.extensions
            .get::<T>()
            .ok_or(ExtensionNotFound(std::marker::PhantomData))
    }

    /// Extractor yielding the socket extension of type `T`.
    ///
    /// Extraction fails with [`ExtensionNotFound`] when no value of type `T` has been
    /// inserted; use [`MaybeExtension`] when the extension may legitimately be absent.
    #[derive(Debug, Clone)]
    pub struct Extension<T>(pub T);

    /// Extractor yielding `Some` socket extension of type `T`, or `None` when absent.
    ///
    /// Extraction never fails (its error type is [`Infallible`]).
    #[derive(Debug, Clone)]
    pub struct MaybeExtension<T>(pub Option<T>);

    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for Extension<T> {
        type Error = ExtensionNotFound<T>;
        fn from_connect_parts(
            s: &Arc<Socket<A>>,
            _: &Option<Value>,
        ) -> Result<Self, ExtensionNotFound<T>> {
            extract_extension(s).map(Extension)
        }
    }

    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for MaybeExtension<T> {
        type Error = Infallible;
        fn from_connect_parts(s: &Arc<Socket<A>>, _: &Option<Value>) -> Result<Self, Infallible> {
            Ok(MaybeExtension(extract_extension(s).ok()))
        }
    }

    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for Extension<T> {
        type Error = ExtensionNotFound<T>;
        fn from_disconnect_parts(
            s: &Arc<Socket<A>>,
            _: DisconnectReason,
        ) -> Result<Self, ExtensionNotFound<T>> {
            extract_extension(s).map(Extension)
        }
    }

    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for MaybeExtension<T> {
        type Error = Infallible;
        fn from_disconnect_parts(
            s: &Arc<Socket<A>>,
            _: DisconnectReason,
        ) -> Result<Self, Infallible> {
            Ok(MaybeExtension(extract_extension(s).ok()))
        }
    }

    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for Extension<T> {
        type Error = ExtensionNotFound<T>;
        fn from_message_parts(
            s: &Arc<Socket<A>>,
            _: &mut Value,
            _: &Option<i64>,
        ) -> Result<Self, ExtensionNotFound<T>> {
            extract_extension(s).map(Extension)
        }
    }

    impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for MaybeExtension<T> {
        type Error = Infallible;
        fn from_message_parts(
            s: &Arc<Socket<A>>,
            _: &mut Value,
            _: &Option<i64>,
        ) -> Result<Self, Infallible> {
            Ok(MaybeExtension(extract_extension(s).ok()))
        }
    }

    super::super::__impl_deref!(Extension);
    super::super::__impl_deref!(MaybeExtension<T>: Option<T>);
}