1pub use axum::extract::*;
2
3use crate::error::{Result, WebError};
4use crate::AppState;
5use anyhow::Context;
6use axum::http::request::Parts;
7use spring::config::{ConfigRegistry, Configurable};
8use spring::plugin::ComponentRegistry;
9use spring::App;
10use std::ops::{Deref, DerefMut};
11use std::result::Result as StdResult;
12use std::sync::Arc;
13
14pub trait RequestPartsExt {
16 fn get_app_state(&self) -> &AppState;
18
19 fn get_component<T: Clone + Send + Sync + 'static>(&self) -> Result<T>;
21
22 fn get_config<T: serde::de::DeserializeOwned + Configurable>(&self) -> Result<T>;
24}
25
26impl RequestPartsExt for Parts {
27 fn get_app_state(&self) -> &AppState {
28 self.extensions
29 .get::<AppState>()
30 .expect("extract app state from extension failed")
31 }
32
33 fn get_component<T: Clone + Send + Sync + 'static>(&self) -> Result<T> {
34 Ok(self
35 .get_app_state()
36 .app
37 .try_get_component()
38 .context("get_component failed")?)
39 }
40
41 fn get_config<T: serde::de::DeserializeOwned + Configurable>(&self) -> Result<T> {
42 self.get_app_state()
43 .app
44 .get_config::<T>()
45 .map_err(|e| WebError::ConfigDeserializeErr(std::any::type_name::<T>(), Box::new(e)))
46 }
47}
48
/// Handler argument holding an owned value of an application component of type `T`.
pub struct Component<T>(pub T)
where
    T: Clone;
51
52impl<T, S> FromRequestParts<S> for Component<T>
53where
54 T: Clone + Send + Sync + 'static,
55 S: Sync,
56{
57 type Rejection = WebError;
58
59 async fn from_request_parts(parts: &mut Parts, _s: &S) -> StdResult<Self, Self::Rejection> {
60 parts.get_component::<T>().map(|c| Component(c))
61 }
62}
63
/// Allows [`Component`] to be used in `aide`-documented handlers; it relies entirely on
/// the trait's default methods.
#[cfg(feature = "openapi")]
impl<T> aide::OperationInput for Component<T> where T: Clone {}
66
67impl<T: Clone> Deref for Component<T> {
68 type Target = T;
69
70 fn deref(&self) -> &Self::Target {
71 &self.0
72 }
73}
74
75impl<T: Clone> DerefMut for Component<T> {
76 fn deref_mut(&mut self) -> &mut Self::Target {
77 &mut self.0
78 }
79}
80
81pub struct Config<T>(pub T)
82where
83 T: serde::de::DeserializeOwned + Configurable;
84
85impl<T, S> FromRequestParts<S> for Config<T>
86where
87 T: serde::de::DeserializeOwned + Configurable,
88 S: Sync,
89{
90 type Rejection = WebError;
91
92 async fn from_request_parts(parts: &mut Parts, _s: &S) -> StdResult<Self, Self::Rejection> {
93 parts.get_config().map(|c| Config(c))
94 }
95}
96
/// Allows [`Config`] to be used in `aide`-documented handlers; it relies entirely on
/// the trait's default methods.
#[cfg(feature = "openapi")]
impl<T: serde::de::DeserializeOwned + Configurable> aide::OperationInput for Config<T> {}
99
100impl<T> Deref for Config<T>
101where
102 T: serde::de::DeserializeOwned + Configurable,
103{
104 type Target = T;
105
106 fn deref(&self) -> &Self::Target {
107 &self.0
108 }
109}
110
111impl<T> DerefMut for Config<T>
112where
113 T: serde::de::DeserializeOwned + Configurable,
114{
115 fn deref_mut(&mut self) -> &mut Self::Target {
116 &mut self.0
117 }
118}
119
120pub struct AppRef(pub Arc<App>);
122
123impl<S> FromRequestParts<S> for AppRef
124where
125 S: Sync,
126{
127 type Rejection = WebError;
128
129 async fn from_request_parts(parts: &mut Parts, _s: &S) -> StdResult<Self, Self::Rejection> {
130 Ok(Self(parts.get_app_state().app.clone()))
131 }
132}
133
#[cfg(feature = "socket_io")]
mod socketio_extractors {
    //! Socket.IO counterparts of the `Component` extractor for connect, message and
    //! disconnect handlers.

    use super::*;
    use crate::socketioxide::adapter::LocalAdapter;
    use crate::socketioxide::extract::HttpExtension;
    use crate::socketioxide::handler::connect::FromConnectParts;
    use crate::socketioxide::handler::disconnect::FromDisconnectParts;
    use crate::socketioxide::handler::message::FromMessageParts;
    use crate::socketioxide::socket::{DisconnectReason, Socket};
    use socketioxide::handler::Value;
    use std::sync::Arc;

    /// Error returned when a Socket.IO handler cannot extract a [`Component`].
    ///
    /// Holds a human-readable description of the step that failed.
    #[derive(Debug)]
    pub struct ComponentExtractError(pub String);

    impl std::fmt::Display for ComponentExtractError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "Component extraction error: {}", self.0)
        }
    }

    impl std::error::Error for ComponentExtractError {}

    /// Shared logic of all three extractors below. It takes the result of extracting
    /// the `AppState` HTTP extension and resolves the requested component from it.
    fn component_from_state<T, E>(
        state: StdResult<HttpExtension<AppState>, E>,
    ) -> StdResult<Component<T>, ComponentExtractError>
    where
        T: Clone + Send + Sync + 'static,
        E: std::fmt::Display,
    {
        let state = state
            .map_err(|e| ComponentExtractError(format!("Failed to extract AppState: {}", e)))?;

        state
            .app
            .try_get_component()
            .map(Component)
            .map_err(|e| ComponentExtractError(format!("Failed to get component: {}", e)))
    }

    impl<T> FromConnectParts<LocalAdapter> for Component<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Error = ComponentExtractError;

        fn from_connect_parts(
            s: &Arc<Socket<LocalAdapter>>,
            auth: &Option<Value>,
        ) -> StdResult<Self, Self::Error> {
            component_from_state(HttpExtension::<AppState>::from_connect_parts(s, auth))
        }
    }

    impl<T> FromMessageParts<LocalAdapter> for Component<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Error = ComponentExtractError;

        fn from_message_parts(
            s: &Arc<Socket<LocalAdapter>>,
            data: &mut Value,
            ack_id: &Option<i64>,
        ) -> StdResult<Self, Self::Error> {
            component_from_state(HttpExtension::<AppState>::from_message_parts(
                s, data, ack_id,
            ))
        }
    }

    impl<T> FromDisconnectParts<LocalAdapter> for Component<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Error = ComponentExtractError;

        fn from_disconnect_parts(
            s: &Arc<Socket<LocalAdapter>>,
            reason: DisconnectReason,
        ) -> StdResult<Self, Self::Error> {
            component_from_state(HttpExtension::<AppState>::from_disconnect_parts(s, reason))
        }
    }
}