Skip to main content

solti_api/
http.rs

1//! # HTTP/JSON transport.
2//!
3//! Axum router exposing [`ApiHandler`] operations as REST-shaped JSON endpoints.
//! All paths share the `/api/v<MAJOR>` prefix, where `MAJOR` is [`crate::API_VERSION`].
//!
//! _The examples below show the current value (`v1`)._
7//!
8//! | Method | Endpoint                    | Handler             |
9//! |--------|-----------------------------|---------------------|
10//! | POST   | `/api/v1/tasks`             | submit              |
11//! | GET    | `/api/v1/tasks`             | list (query params) |
12//! | GET    | `/api/v1/tasks/{id}`        | get status          |
13//! | GET    | `/api/v1/tasks/{id}/runs`   | list runs           |
14//! | DELETE | `/api/v1/tasks/{id}`        | delete (stop+purge) |
15
16use std::sync::Arc;
17
18use axum::{
19    Json, Router,
20    extract::{FromRequest, Path, Query, Request, State, rejection::JsonRejection},
21    http::StatusCode,
22    middleware::{self, Next},
23    response::{IntoResponse, Response},
24    routing::{delete, get, post},
25};
26use serde::{Deserialize, de::DeserializeOwned};
27use solti_model::{TaskId, TaskPhase, TaskQuery};
28use tower_http::limit::RequestBodyLimitLayer;
29use tracing::debug;
30
31use crate::{
32    MAX_REQUEST_BYTES,
33    convert::{self, tasks_page_to_proto},
34    error::ApiError,
35    handler::ApiHandler,
36    proto_api,
37    validate::{clamp_list_limit, non_empty_id},
38};
39// `api_url!` is `#[macro_export]`, so it's already accessible in this
40// module by its bare name — `use crate::api_url` would be redundant
41// (and warnings about unused imports broke a `cargo publish` on us).
42
43/// Wrapper around `axum::Json<T>` that maps `JsonRejection` into [`ApiError::InvalidRequest`].
44pub(crate) struct ApiJson<T>(pub T);
45
46impl<T, S> FromRequest<S> for ApiJson<T>
47where
48    T: DeserializeOwned,
49    S: Send + Sync,
50{
51    type Rejection = ApiError;
52
53    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
54        let Json(value) = axum::Json::<T>::from_request(req, state)
55            .await
56            .map_err(map_json_rejection)?;
57        Ok(ApiJson(value))
58    }
59}
60
61fn map_json_rejection(rej: JsonRejection) -> ApiError {
62    if rej.status() == StatusCode::PAYLOAD_TOO_LARGE {
63        return ApiError::PayloadTooLarge(format!(
64            "request body exceeds the maximum of {} bytes",
65            MAX_REQUEST_BYTES
66        ));
67    }
68
69    let msg = rej.body_text();
70    let trimmed = msg
71        .strip_prefix("Failed to deserialize the JSON body into the target type: ")
72        .or_else(|| msg.strip_prefix("Failed to parse the request body as JSON: "))
73        .unwrap_or(&msg)
74        .to_string();
75    ApiError::InvalidRequest(trimmed)
76}
77
78async fn map_413_envelope(req: Request, next: Next) -> Response {
79    let resp = next.run(req).await;
80    if resp.status() == StatusCode::PAYLOAD_TOO_LARGE {
81        let body = serde_json::json!({
82            "error": "PayloadTooLarge",
83            "message": format!(
84                "request body exceeds the maximum of {} bytes",
85                MAX_REQUEST_BYTES
86            ),
87        });
88        return (StatusCode::PAYLOAD_TOO_LARGE, Json(body)).into_response();
89    }
90    resp
91}
92
93/// HTTP API service builder.
94///
95/// ## Also
96///
97/// - [`ApiHandler`](crate::ApiHandler) the trait backing all endpoints.
98/// - [`ApiError`](crate::ApiError) mapped to JSON + HTTP status codes.
99pub struct HttpApi<H> {
100    handler: Arc<H>,
101}
102
103impl<H> HttpApi<H>
104where
105    H: ApiHandler,
106{
107    /// Create new HTTP API with the given handler.
108    pub fn new(handler: Arc<H>) -> Self {
109        Self { handler }
110    }
111
112    /// Build axum router with mounted endpoints.
113    ///
114    /// Applies a [`RequestBodyLimitLayer`] capped at [`MAX_REQUEST_BYTES`] bytes to every request.
115    pub fn router(self) -> Router {
116        Router::new()
117            .route(api_url!("/tasks"), post(submit_task::<H>))
118            .route(api_url!("/tasks"), get(list_tasks::<H>))
119            .route(api_url!("/tasks/{id}"), get(get_task_status::<H>))
120            .route(api_url!("/tasks/{id}"), delete(delete_task::<H>))
121            .route(api_url!("/tasks/{id}/runs"), get(list_task_runs::<H>))
122            .layer(RequestBodyLimitLayer::new(MAX_REQUEST_BYTES))
123            .layer(middleware::from_fn(map_413_envelope))
124            .with_state(self.handler)
125    }
126}
127
128#[derive(Debug, Deserialize)]
129struct ListTasksParams {
130    slot: Option<String>,
131    status: Option<String>,
132    limit: Option<u32>,
133    offset: Option<u32>,
134}
135
136async fn submit_task<H>(
137    State(handler): State<Arc<H>>,
138    ApiJson(req): ApiJson<proto_api::SubmitTaskRequest>,
139) -> Result<impl IntoResponse, ApiError>
140where
141    H: ApiHandler,
142{
143    let spec = req
144        .spec
145        .ok_or_else(|| ApiError::InvalidRequest("missing spec".into()))?;
146    let spec = convert::convert_create_spec(spec)?;
147
148    debug!(slot = %spec.slot(), kind = ?spec.kind(), "submitting task");
149    let task_id = handler.submit_task(spec).await?;
150
151    let response = proto_api::SubmitTaskResponse {
152        task_id: task_id.to_string(),
153    };
154    Ok((StatusCode::CREATED, Json(response)))
155}
156
157async fn get_task_status<H>(
158    State(handler): State<Arc<H>>,
159    Path(id): Path<String>,
160) -> Result<impl IntoResponse, ApiError>
161where
162    H: ApiHandler,
163{
164    non_empty_id("task_id", &id)?;
165
166    let task_id = TaskId::from(id);
167    debug!(%task_id, "getting task status");
168    let task = handler.get_task_status(&task_id).await?;
169
170    let task = task.map(proto_api::TaskData::try_from).transpose()?;
171    Ok(Json(proto_api::GetTaskStatusResponse { task }))
172}
173
174async fn list_tasks<H>(
175    State(handler): State<Arc<H>>,
176    Query(params): Query<ListTasksParams>,
177) -> Result<impl IntoResponse, ApiError>
178where
179    H: ApiHandler,
180{
181    let mut query = TaskQuery::new();
182
183    if let Some(slot) = params.slot {
184        non_empty_id("slot", &slot)?;
185        query = query.with_slot(slot);
186    }
187
188    if let Some(status_str) = params.status {
189        let status = status_str.parse::<TaskPhase>().map_err(|_| {
190            ApiError::InvalidRequest(format!(
191                "invalid status: '{status_str}' (valid: pending, running, succeeded, failed, timeout, canceled, exhausted)"
192            ))
193        })?;
194        query = query.with_status(status);
195    }
196
197    query = query.with_limit(clamp_list_limit(params.limit.unwrap_or(0)));
198    if let Some(offset) = params.offset {
199        query = query.with_offset(offset as usize);
200    }
201
202    let page = handler.query_tasks(query).await?;
203    debug!(count = page.items.len(), total = page.total, "tasks listed");
204
205    Ok(Json(tasks_page_to_proto(page)?))
206}
207
208async fn list_task_runs<H>(
209    State(handler): State<Arc<H>>,
210    Path(id): Path<String>,
211) -> Result<impl IntoResponse, ApiError>
212where
213    H: ApiHandler,
214{
215    non_empty_id("task_id", &id)?;
216
217    let task_id = TaskId::from(id);
218    debug!(%task_id, "listing task runs");
219    let runs = handler.list_task_runs(&task_id).await?;
220    let runs = runs.into_iter().map(proto_api::TaskRunInfo::from).collect();
221
222    Ok(Json(proto_api::ListTaskRunsResponse { runs }))
223}
224
225async fn delete_task<H>(
226    State(handler): State<Arc<H>>,
227    Path(id): Path<String>,
228) -> Result<impl IntoResponse, ApiError>
229where
230    H: ApiHandler,
231{
232    non_empty_id("task_id", &id)?;
233
234    let task_id = TaskId::from(id);
235    handler.delete_task(&task_id).await?;
236    debug!(%task_id, "task deleted");
237
238    Ok(StatusCode::NO_CONTENT)
239}