spring_web/
extractor.rs

1pub use axum::extract::*;
2
3use crate::error::{Result, WebError};
4use crate::AppState;
5use anyhow::Context;
6use axum::http::request::Parts;
7use spring::config::{ConfigRegistry, Configurable};
8use spring::plugin::ComponentRegistry;
9use spring::App;
10use std::ops::{Deref, DerefMut};
11use std::result::Result as StdResult;
12use std::sync::Arc;
13
14/// Extending the functionality of RequestParts
15pub trait RequestPartsExt {
16    /// get AppState
17    fn get_app_state(&self) -> &AppState;
18
19    /// get Component
20    fn get_component<T: Clone + Send + Sync + 'static>(&self) -> Result<T>;
21
22    /// get Config
23    fn get_config<T: serde::de::DeserializeOwned + Configurable>(&self) -> Result<T>;
24}
25
26impl RequestPartsExt for Parts {
27    fn get_app_state(&self) -> &AppState {
28        self.extensions
29            .get::<AppState>()
30            .expect("extract app state from extension failed")
31    }
32
33    fn get_component<T: Clone + Send + Sync + 'static>(&self) -> Result<T> {
34        Ok(self
35            .get_app_state()
36            .app
37            .try_get_component()
38            .context("get_component failed")?)
39    }
40
41    fn get_config<T: serde::de::DeserializeOwned + Configurable>(&self) -> Result<T> {
42        self.get_app_state()
43            .app
44            .get_config::<T>()
45            .map_err(|e| WebError::ConfigDeserializeErr(std::any::type_name::<T>(), Box::new(e)))
46    }
47}
48
/// Wrapper used to extract a plugin-registered component of type `T` from the
/// `AppState`; the component itself is the public field `0`.
pub struct Component<T>(pub T)
where
    T: Clone;
51
52impl<T, S> FromRequestParts<S> for Component<T>
53where
54    T: Clone + Send + Sync + 'static,
55    S: Sync,
56{
57    type Rejection = WebError;
58
59    async fn from_request_parts(parts: &mut Parts, _s: &S) -> StdResult<Self, Self::Rejection> {
60        parts.get_component::<T>().map(|c| Component(c))
61    }
62}
63
/// Registers [`Component`] as an `aide` operation input when the `openapi`
/// feature is enabled. The body is empty, so only the trait's default items apply.
#[cfg(feature = "openapi")]
impl<T> aide::OperationInput for Component<T> where T: Clone {}
66
67impl<T: Clone> Deref for Component<T> {
68    type Target = T;
69
70    fn deref(&self) -> &Self::Target {
71        &self.0
72    }
73}
74
75impl<T: Clone> DerefMut for Component<T> {
76    fn deref_mut(&mut self) -> &mut Self::Target {
77        &mut self.0
78    }
79}
80
81pub struct Config<T>(pub T)
82where
83    T: serde::de::DeserializeOwned + Configurable;
84
85impl<T, S> FromRequestParts<S> for Config<T>
86where
87    T: serde::de::DeserializeOwned + Configurable,
88    S: Sync,
89{
90    type Rejection = WebError;
91
92    async fn from_request_parts(parts: &mut Parts, _s: &S) -> StdResult<Self, Self::Rejection> {
93        parts.get_config().map(|c| Config(c))
94    }
95}
96
/// Registers [`Config`] as an `aide` operation input when the `openapi`
/// feature is enabled. The body is empty, so only the trait's default items apply.
#[cfg(feature = "openapi")]
impl<T: serde::de::DeserializeOwned + Configurable> aide::OperationInput for Config<T> {}
99
100impl<T> Deref for Config<T>
101where
102    T: serde::de::DeserializeOwned + Configurable,
103{
104    type Target = T;
105
106    fn deref(&self) -> &Self::Target {
107        &self.0
108    }
109}
110
111impl<T> DerefMut for Config<T>
112where
113    T: serde::de::DeserializeOwned + Configurable,
114{
115    fn deref_mut(&mut self) -> &mut Self::Target {
116        &mut self.0
117    }
118}
119
120/// Extract Arc<App>
121pub struct AppRef(pub Arc<App>);
122
123impl<S> FromRequestParts<S> for AppRef
124where
125    S: Sync,
126{
127    type Rejection = WebError;
128
129    async fn from_request_parts(parts: &mut Parts, _s: &S) -> StdResult<Self, Self::Rejection> {
130        Ok(Self(parts.get_app_state().app.clone()))
131    }
132}
133
134#[cfg(feature = "openapi")]
135impl aide::OperationInput for AppRef {}
136
/// socket.io counterparts of the [`Component`] extractor for connect, message
/// and disconnect handlers. Compiled only when the `socket_io` feature is enabled.
#[cfg(feature = "socket_io")]
mod socketio_extractors {
    use super::*;
    use crate::socketioxide::adapter::LocalAdapter;
    use crate::socketioxide::extract::HttpExtension;
    use crate::socketioxide::handler::connect::FromConnectParts;
    use crate::socketioxide::handler::disconnect::FromDisconnectParts;
    use crate::socketioxide::handler::message::FromMessageParts;
    use crate::socketioxide::socket::{DisconnectReason, Socket};
    use socketioxide::handler::Value;
    use std::sync::Arc;

    /// Error returned when a [`Component`] cannot be extracted inside a
    /// socket.io handler. The wrapped string says which step failed.
    #[derive(Debug)]
    pub struct ComponentExtractError(pub String);

    impl std::fmt::Display for ComponentExtractError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Component extraction error: {}", self.0)
        }
    }

    impl std::error::Error for ComponentExtractError {}

    /// Shared tail of the three extractors below.
    ///
    /// Takes the outcome of extracting `HttpExtension<AppState>` for the current
    /// handler kind and resolves the component of type `T` from the app it holds.
    ///
    /// # Errors
    ///
    /// * `"Failed to extract AppState: …"` when `state` is an `Err`.
    /// * `"Failed to get component: …"` when the component lookup fails.
    fn component_from_state<T, E>(
        state: StdResult<HttpExtension<AppState>, E>,
    ) -> StdResult<Component<T>, ComponentExtractError>
    where
        T: Clone + Send + Sync + 'static,
        E: std::fmt::Display,
    {
        let state = state
            .map_err(|e| ComponentExtractError(format!("Failed to extract AppState: {e}")))?;

        state
            .app
            .try_get_component()
            .map(Component)
            .map_err(|e| ComponentExtractError(format!("Failed to get component: {e}")))
    }

    /// Makes [`Component`] usable as an argument of socket.io connect handlers.
    impl<T> FromConnectParts<LocalAdapter> for Component<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Error = ComponentExtractError;

        fn from_connect_parts(
            s: &Arc<Socket<LocalAdapter>>,
            auth: &Option<Value>,
        ) -> StdResult<Self, Self::Error> {
            component_from_state(HttpExtension::<AppState>::from_connect_parts(s, auth))
        }
    }

    /// Makes [`Component`] usable as an argument of socket.io message handlers.
    impl<T> FromMessageParts<LocalAdapter> for Component<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Error = ComponentExtractError;

        fn from_message_parts(
            s: &Arc<Socket<LocalAdapter>>,
            data: &mut Value,
            ack_id: &Option<i64>,
        ) -> StdResult<Self, Self::Error> {
            component_from_state(HttpExtension::<AppState>::from_message_parts(
                s, data, ack_id,
            ))
        }
    }

    /// Makes [`Component`] usable as an argument of socket.io disconnect handlers.
    impl<T> FromDisconnectParts<LocalAdapter> for Component<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Error = ComponentExtractError;

        fn from_disconnect_parts(
            s: &Arc<Socket<LocalAdapter>>,
            reason: DisconnectReason,
        ) -> StdResult<Self, Self::Error> {
            component_from_state(HttpExtension::<AppState>::from_disconnect_parts(s, reason))
        }
    }
}