1use std::sync::Arc;
18
19use std::convert::Infallible;
20
21use axum::{
22 Json, Router,
23 extract::{FromRequest, Path, Query, Request, State, rejection::JsonRejection},
24 http::StatusCode,
25 middleware::{self, Next},
26 response::{
27 IntoResponse, Response,
28 sse::{Event, KeepAlive, Sse},
29 },
30 routing::{delete, get, post},
31};
32use serde::{Deserialize, de::DeserializeOwned};
33use solti_model::{OutputEvent, TaskId, TaskPhase, TaskQuery};
34use tokio_stream::StreamExt;
35use tower_http::limit::RequestBodyLimitLayer;
36use tracing::debug;
37
38use crate::{
39 MAX_REQUEST_BYTES,
40 convert::{self, tasks_page_to_proto},
41 error::ApiError,
42 handler::ApiHandler,
43 proto_api,
44 validate::{clamp_list_limit, non_empty_id},
45};
/// JSON request-body extractor used by this crate's handlers in place of `axum::Json`.
///
/// The wrapped value is public so handlers can destructure it directly
/// (`ApiJson(req): ApiJson<T>`). The derives mirror those of `axum::Json`, so the
/// wrapper can be logged, copied, or defaulted wherever the inner type allows it.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct ApiJson<T>(pub T);
52
53impl<T, S> FromRequest<S> for ApiJson<T>
54where
55 T: DeserializeOwned,
56 S: Send + Sync,
57{
58 type Rejection = ApiError;
59
60 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
61 let Json(value) = axum::Json::<T>::from_request(req, state)
62 .await
63 .map_err(map_json_rejection)?;
64 Ok(ApiJson(value))
65 }
66}
67
68fn map_json_rejection(rej: JsonRejection) -> ApiError {
69 if rej.status() == StatusCode::PAYLOAD_TOO_LARGE {
70 return ApiError::PayloadTooLarge(format!(
71 "request body exceeds the maximum of {} bytes",
72 MAX_REQUEST_BYTES
73 ));
74 }
75
76 let msg = rej.body_text();
77 let trimmed = msg
78 .strip_prefix("Failed to deserialize the JSON body into the target type: ")
79 .or_else(|| msg.strip_prefix("Failed to parse the request body as JSON: "))
80 .unwrap_or(&msg)
81 .to_string();
82 ApiError::InvalidRequest(trimmed)
83}
84
85async fn map_413_envelope(req: Request, next: Next) -> Response {
86 let resp = next.run(req).await;
87 if resp.status() == StatusCode::PAYLOAD_TOO_LARGE {
88 let body = serde_json::json!({
89 "error": "PayloadTooLarge",
90 "message": format!(
91 "request body exceeds the maximum of {} bytes",
92 MAX_REQUEST_BYTES
93 ),
94 });
95 return (StatusCode::PAYLOAD_TOO_LARGE, Json(body)).into_response();
96 }
97 resp
98}
99
100pub struct HttpApi<H> {
107 handler: Arc<H>,
108}
109
110impl<H> HttpApi<H>
111where
112 H: ApiHandler,
113{
114 pub fn new(handler: Arc<H>) -> Self {
116 Self { handler }
117 }
118
119 pub fn router(self) -> Router {
123 Router::new()
124 .route(api_url!("/tasks"), post(submit_task::<H>))
125 .route(api_url!("/tasks"), get(list_tasks::<H>))
126 .route(api_url!("/tasks/{id}"), get(get_task_status::<H>))
127 .route(api_url!("/tasks/{id}"), delete(delete_task::<H>))
128 .route(api_url!("/tasks/{id}/runs"), get(list_task_runs::<H>))
129 .route(api_url!("/tasks/{id}/logs"), get(stream_task_logs::<H>))
130 .layer(RequestBodyLimitLayer::new(MAX_REQUEST_BYTES))
131 .layer(middleware::from_fn(map_413_envelope))
132 .with_state(self.handler)
133 }
134}
135
/// Query-string parameters accepted by `list_tasks`; every field is optional.
#[derive(Debug, Deserialize)]
struct ListTasksParams {
    /// Slot filter. When present it must pass `non_empty_id` and is applied with
    /// `TaskQuery::with_slot`.
    slot: Option<String>,
    /// Status filter. When present it is parsed into a `TaskPhase`; a value that does
    /// not parse is rejected as an invalid request.
    status: Option<String>,
    /// Requested limit. `list_tasks` substitutes 0 when it is absent and always passes
    /// the value through `clamp_list_limit`.
    limit: Option<u32>,
    /// Offset forwarded to `TaskQuery::with_offset`, only when present.
    offset: Option<u32>,
}
143
144async fn submit_task<H>(
145 State(handler): State<Arc<H>>,
146 ApiJson(req): ApiJson<proto_api::SubmitTaskRequest>,
147) -> Result<impl IntoResponse, ApiError>
148where
149 H: ApiHandler,
150{
151 let spec = req
152 .spec
153 .ok_or_else(|| ApiError::InvalidRequest("missing spec".into()))?;
154 let spec = convert::convert_create_spec(spec)?;
155
156 debug!(slot = %spec.slot(), kind = ?spec.kind(), "submitting task");
157 let task_id = handler.submit_task(spec).await?;
158
159 let response = proto_api::SubmitTaskResponse {
160 task_id: task_id.to_string(),
161 };
162 Ok((StatusCode::CREATED, Json(response)))
163}
164
165async fn get_task_status<H>(
166 State(handler): State<Arc<H>>,
167 Path(id): Path<String>,
168) -> Result<impl IntoResponse, ApiError>
169where
170 H: ApiHandler,
171{
172 non_empty_id("task_id", &id)?;
173
174 let task_id = TaskId::from(id);
175 debug!(%task_id, "getting task status");
176 let task = handler.get_task_status(&task_id).await?;
177
178 let task = task.map(proto_api::TaskData::try_from).transpose()?;
179 Ok(Json(proto_api::GetTaskStatusResponse { task }))
180}
181
182async fn list_tasks<H>(
183 State(handler): State<Arc<H>>,
184 Query(params): Query<ListTasksParams>,
185) -> Result<impl IntoResponse, ApiError>
186where
187 H: ApiHandler,
188{
189 let mut query = TaskQuery::new();
190
191 if let Some(slot) = params.slot {
192 non_empty_id("slot", &slot)?;
193 query = query.with_slot(slot);
194 }
195
196 if let Some(status_str) = params.status {
197 let status = status_str.parse::<TaskPhase>().map_err(|_| {
198 ApiError::InvalidRequest(format!(
199 "invalid status: '{status_str}' (valid: pending, running, succeeded, failed, timeout, canceled, exhausted)"
200 ))
201 })?;
202 query = query.with_status(status);
203 }
204
205 query = query.with_limit(clamp_list_limit(params.limit.unwrap_or(0)));
206 if let Some(offset) = params.offset {
207 query = query.with_offset(offset as usize);
208 }
209
210 let page = handler.query_tasks(query).await?;
211 debug!(count = page.items.len(), total = page.total, "tasks listed");
212
213 Ok(Json(tasks_page_to_proto(page)?))
214}
215
216async fn list_task_runs<H>(
217 State(handler): State<Arc<H>>,
218 Path(id): Path<String>,
219) -> Result<impl IntoResponse, ApiError>
220where
221 H: ApiHandler,
222{
223 non_empty_id("task_id", &id)?;
224
225 let task_id = TaskId::from(id);
226 debug!(%task_id, "listing task runs");
227 let runs = handler.list_task_runs(&task_id).await?;
228 let runs = runs.into_iter().map(proto_api::TaskRunInfo::from).collect();
229
230 Ok(Json(proto_api::ListTaskRunsResponse { runs }))
231}
232
233async fn delete_task<H>(
234 State(handler): State<Arc<H>>,
235 Path(id): Path<String>,
236) -> Result<impl IntoResponse, ApiError>
237where
238 H: ApiHandler,
239{
240 non_empty_id("task_id", &id)?;
241
242 let task_id = TaskId::from(id);
243 handler.delete_task(&task_id).await?;
244 debug!(%task_id, "task deleted");
245
246 Ok(StatusCode::NO_CONTENT)
247}
248
249async fn stream_task_logs<H>(
251 State(handler): State<Arc<H>>,
252 Path(id): Path<String>,
253) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, ApiError>
254where
255 H: ApiHandler,
256{
257 non_empty_id("task_id", &id)?;
258
259 let task_id = TaskId::from(id);
260 debug!(%task_id, "subscribing to task log stream");
261 let stream = handler.stream_task_logs(&task_id).await?;
262
263 let sse_stream = stream.map(|ev| {
264 let name = match &ev {
265 OutputEvent::Chunk(_) => "chunk",
266 OutputEvent::RunStarted { .. } => "run-started",
267 OutputEvent::RunFinished { .. } => "run-finished",
268 OutputEvent::Lagged { .. } => "lagged",
269 };
270 let data = serde_json::to_string(&ev).unwrap_or_else(|_| "{}".into());
271 Ok(Event::default().event(name).data(data))
272 });
273 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
274}