1use axum::{
2 extract::{FromRef, Path, Request, State},
3 http::{self, StatusCode},
4 middleware::Next,
5 response::{IntoResponse, Response},
6 Extension, Json, Router,
7};
8use scalar_cms::{
9 db::{Credentials, DatabaseFactory, User},
10 validations::{Valid, ValidationError},
11 DatabaseConnection, Document, Item, Schema,
12};
13use serde::{de::DeserializeOwned, Serialize};
14
15#[cfg(feature = "img")]
16pub mod img;
17
18pub struct ValidationFailiure(pub ValidationError);
19
20impl IntoResponse for ValidationFailiure {
21 fn into_response(self) -> axum::response::Response {
22 let mut response = Json(self.0).into_response();
23 *response.status_mut() = StatusCode::NOT_ACCEPTABLE;
24 response
25 }
26}
27
/// Merges the image/file upload and listing routes from the `img` module into `router`.
///
/// `PUT /images/upload` accepts bodies up to 25 MB, `PUT /files/upload` has the body limit
/// disabled, and `GET /images/list` lists stored images.
#[cfg(feature = "img")]
#[doc(hidden)]
pub fn add_image_routes__<S: Clone + Send + Sync + 'static>(router: Router<S>) -> Router<S>
where
    scalar_img::WrappedBucket: FromRef<S>,
{
    use axum::{extract::DefaultBodyLimit, routing};

    let image_upload = routing::put(img::upload_image).layer(DefaultBodyLimit::max(25_000_000));
    let file_upload = routing::put(img::upload_file).layer(DefaultBodyLimit::disable());
    let listing = routing::get(img::list);

    router.merge(
        Router::new()
            .route("/images/upload", image_upload)
            .route("/files/upload", file_upload)
            .route("/images/list", listing),
    )
}
50
51#[cfg(not(feature = "img"))]
52#[doc(hidden)]
53pub fn add_image_routes__<S: Clone + Send + Sync + 'static>(router: Router<S>) -> Router<S> {
54 router
55}
56
/// Registers the read, draft and schema routes for one or more document types on `$router`.
///
/// For every `$doc` whose `identifier()` is `<identifier>` this adds:
/// - `GET /docs/<identifier>` → `get_all_docs`
/// - `GET /docs/<identifier>/{id}` → `get_doc_by_id`
/// - `PUT /docs/<identifier>/drafts/{id}` → `update_draft`
/// - `GET /docs/<identifier>/schema` → `get_schema`
///
/// `$router` must name a mutable local binding; it is reassigned in place.
/// Handlers are referenced through `$crate` so the expansion works no matter what name the
/// calling crate uses for this crate (and from inside this crate itself).
#[macro_export]
#[doc(hidden)]
macro_rules! crud_routes__ {
    ($router:ident, $db:ty, $doc:ty) => {
        let path = format!("/docs/{}", <$doc>::identifier());
        let drafts_path = format!("{path}/drafts/{{id}}");
        $router = $router
            .route(&path, ::axum::routing::get($crate::get_all_docs::<$doc, $db>))
            .route(
                &format!("{path}/{{id}}"),
                ::axum::routing::get($crate::get_doc_by_id::<$doc, $db>),
            )
            .route(
                &drafts_path,
                ::axum::routing::put($crate::update_draft::<$doc, $db>),
            )
            .route(
                &format!("{path}/schema"),
                ::axum::routing::get($crate::get_schema::<$doc>),
            );
    };

    ($router:ident, $db:ty, $($doc:ty),+) => {
        $($crate::crud_routes__!($router, $db, $doc);)*
    };
}
74
/// Registers `POST /docs/<identifier>/{id}/publish` (handled by `publish_doc`) on `$router`
/// for each `$doc`, where `<identifier>` is the document type's `identifier()`.
///
/// `$router` must name a mutable local binding; it is reassigned in place. Handlers are
/// referenced through `$crate` so the expansion does not depend on the caller's name for
/// this crate.
#[macro_export]
#[doc(hidden)]
macro_rules! publish_routes__ {
    ($router:ident, $db:ty, $doc:ty) => {
        let path = format!("/docs/{}", <$doc>::identifier());
        $router = $router.route(
            &format!("{path}/{{id}}/publish"),
            ::axum::routing::post($crate::publish_doc::<$doc, $db>),
        );
    };

    ($router:ident, $db:ty, $($doc:ty),+) => {
        $($crate::publish_routes__!($router, $db, $doc);)*
    };
}
88
/// Registers `POST /docs/<identifier>/validate` (handled by `validate`) on `$router` for
/// each `$doc`, where `<identifier>` is the document type's `identifier()`.
///
/// `$router` must name a mutable local binding; it is reassigned in place. Handlers are
/// referenced through `$crate` so the expansion does not depend on the caller's name for
/// this crate.
#[macro_export]
#[doc(hidden)]
macro_rules! validate_routes__ {
    ($router:ident, $doc:ty) => {
        $router = $router.route(
            &format!("/docs/{}/validate", <$doc>::identifier()),
            ::axum::routing::post($crate::validate::<$doc>),
        );
    };

    ($router:ident, $($doc:ty),+) => {
        $($crate::validate_routes__!($router, $doc);)*
    };
}
101
/// Builds an `axum::Router` exposing the CMS HTTP API for a set of document types.
///
/// ```ignore
/// let router = scalar_axum::generate_routes!({ AppState }, db: DbFactory, [Post, Page]);
/// ```
///
/// * `{$app_state}` — the state type of the returned `Router<$app_state>`.
/// * `$db_instance: $db` — a local binding holding the database factory, and its type. The
///   binding is cloned into the authentication middleware; `$db` is the factory type the
///   generic handlers are instantiated with.
/// * `[$doc, ...]` — one or more document types to serve.
///
/// Routes registered before the authentication layer require an `Authorization: Bearer`
/// token: the per-document CRUD, draft, schema and publish routes, `GET /docs`, `GET /me`,
/// and the image routes when this crate's `img` feature is enabled. `POST /signin` and the
/// per-document `POST /docs/<identifier>/validate` routes are registered after the layer and
/// are therefore reachable without authentication.
#[macro_export]
macro_rules! generate_routes {
    ({$app_state:ty}, $db_instance:ident: $db:ty, [$($doc:ty),+]) => {
        {
            let mut router = ::axum::Router::<$app_state>::new();
            $crate::crud_routes__!(router, $db, $($doc),+);
            $crate::publish_routes__!(router, $db, $($doc),+);

            async fn get_docs() -> ::axum::Json<Vec<::scalar_cms::DocInfo>> {
                ::axum::Json(vec![
                    $(::scalar_cms::DocInfo {
                        identifier: <$doc>::identifier(),
                        title: <$doc>::title()
                    }),+
                ])
            }
            router = router.route("/docs", ::axum::routing::get(get_docs));

            router = router.route("/me", ::axum::routing::get($crate::me::<$db>));
            router = $crate::add_image_routes__(router);

            // `layer` only wraps routes that already exist: everything above is
            // authenticated, everything registered below is public.
            router = router.layer(::axum::middleware::from_fn_with_state(
                $db_instance.clone(),
                $crate::authenticated_connection_middleware::<$db>,
            ));
            router = router.route("/signin", ::axum::routing::post($crate::signin::<$db>));

            $crate::validate_routes__!(router, $($doc),+);

            router
        }
    };
}
130
131pub async fn authenticated_connection_middleware<F: DatabaseFactory + Clone>(
132 State(db_factory): State<F>,
133 mut req: Request,
134 next: Next,
135) -> Result<Response, StatusCode>
136where
137 <F as DatabaseFactory>::Connection: 'static,
138{
139 let auth_header = req
140 .headers()
141 .get(http::header::AUTHORIZATION)
142 .map(|header| {
143 header
144 .to_str()
145 .map(str::trim)
146 .map_err(|_| StatusCode::BAD_REQUEST)
147 })
148 .ok_or(StatusCode::UNAUTHORIZED)??;
149
150 let connection = db_factory.init().await.map_err(|e| {
151 println!("{e}");
152 StatusCode::INTERNAL_SERVER_ERROR
153 })?;
154
155 let (_, token) = auth_header
156 .starts_with("Bearer ")
157 .then(|| {
158 auth_header
159 .split_at_checked(7)
160 .ok_or(StatusCode::UNAUTHORIZED)
161 })
162 .ok_or(StatusCode::UNAUTHORIZED)??;
163
164 connection.authenticate(token).await.map_err(|e| match e {
165 scalar_cms::db::AuthenticationError::BadToken => StatusCode::UNAUTHORIZED,
166 scalar_cms::db::AuthenticationError::BadCredentials => StatusCode::UNAUTHORIZED,
167 scalar_cms::db::AuthenticationError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
168 })?;
169
170 req.extensions_mut().insert(connection);
171
172 Ok(next.run(req).await)
173}
174
175pub async fn signin<F: DatabaseFactory + Clone>(
177 State(factory): State<F>,
178 Json(credentials): Json<Credentials>,
179) -> Result<String, StatusCode> {
180 let connection = factory.init().await.map_err(|e| {
181 println!("{e}");
182 StatusCode::INTERNAL_SERVER_ERROR
183 })?;
184
185 println!("connection");
186
187 let token = connection.signin(credentials).await.map_err(|e| match e {
188 scalar_cms::db::AuthenticationError::BadToken => StatusCode::UNAUTHORIZED,
189 scalar_cms::db::AuthenticationError::BadCredentials => StatusCode::UNAUTHORIZED,
190 scalar_cms::db::AuthenticationError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
191 })?;
192
193 Ok(token)
194}
195
196pub async fn get_schema<T: Document>() -> Json<Schema> {
197 Json(T::schema())
198}
199
200pub async fn validate<D: Document>(
201 Json(doc): Json<D>,
202) -> Result<(), (StatusCode, Json<ValidationError>)> {
203 doc.validate()
204 .map_err(|e| (StatusCode::UNPROCESSABLE_ENTITY, Json(e)))
205}
206
207pub async fn me<F: DatabaseFactory>(
208 state: Extension<<F as DatabaseFactory>::Connection>,
209) -> Result<Json<User>, StatusCode> {
210 Ok(Json(state.me().await.map_err(|e| {
211 println!("{e}");
212 StatusCode::INTERNAL_SERVER_ERROR
213 })?))
214}
215
216pub async fn update_draft<T: Document + Serialize + DeserializeOwned + Send, F: DatabaseFactory>(
217 state: Extension<<F as DatabaseFactory>::Connection>,
218 Path(id): Path<String>,
219 Json(data): Json<serde_json::Value>,
220) -> Result<Json<Item<serde_json::Value>>, StatusCode> {
221 Ok(Json(
222 state
223 .draft::<T>(&id, data)
224 .await
225 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
226 ))
227}
228
229pub async fn publish_doc<
230 D: Document + Serialize + DeserializeOwned + Send + 'static,
231 F: DatabaseFactory,
232>(
233 Path(id): Path<String>,
234 state: Extension<<F as DatabaseFactory>::Connection>,
235 doc: Json<D>,
236) -> Result<(), StatusCode> {
237 state
238 .publish(
239 &id,
240 None,
241 Valid::new(doc.0).map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?,
242 )
243 .await
244 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
245
246 Ok(())
247}
248
249pub async fn get_all_docs<T: Document + Serialize + DeserializeOwned + Send, F: DatabaseFactory>(
250 state: Extension<<F as DatabaseFactory>::Connection>,
251) -> Result<Json<Vec<Item<serde_json::Value>>>, StatusCode> {
252 let items = state
253 .get_all::<T>()
254 .await
255 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
256
257 Ok(Json(items))
258}
259
260pub async fn get_doc_by_id<
261 T: Document + Serialize + DeserializeOwned + Send,
262 F: DatabaseFactory,
263>(
264 state: Extension<<F as DatabaseFactory>::Connection>,
265 id: Path<String>,
266) -> Result<Json<Item<serde_json::Value>>, StatusCode> {
267 state
268 .get_by_id::<T>(id.as_str())
269 .await
270 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
271 .map(Json)
272 .ok_or(StatusCode::NOT_FOUND)
273}