1use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio_stream::StreamExt;
11use tonic::{Request, Response, Status};
12use tracing::debug;
13
14use solti_model::TaskQuery;
15
16use crate::convert::{output_event_to_proto, proto_to_domain_status, tasks_page_to_proto};
17use crate::error::ApiError;
18use crate::handler::ApiHandler;
19use crate::metrics::{ApiMetricsHandle, Transport, noop_api_metrics};
20use crate::proto_api::{self, solti_api_server::SoltiApi, solti_api_server::SoltiApiServer};
21use crate::validate::{clamp_list_limit, non_empty_id};
22
/// gRPC service adapter that implements the generated `SoltiApi` trait on top
/// of an [`ApiHandler`].
///
/// The handler is shared behind an [`Arc`]; per-request metrics are reported
/// through the stored [`ApiMetricsHandle`].
pub struct SoltiApiService<H> {
    /// Handler that every RPC delegates to.
    handler: Arc<H>,
    /// Sink for in-flight and per-request metrics.
    metrics: ApiMetricsHandle,
}
33
34impl<H> SoltiApiService<H>
35where
36 H: ApiHandler,
37{
38 pub fn new(handler: Arc<H>) -> Self {
40 Self::new_with_metrics(handler, noop_api_metrics())
41 }
42
43 pub fn new_with_metrics(handler: Arc<H>, metrics: ApiMetricsHandle) -> Self {
45 Self { handler, metrics }
46 }
47
48 async fn instrument<F, T>(&self, method: &'static str, fut: F) -> Result<Response<T>, Status>
49 where
50 F: Future<Output = Result<Response<T>, Status>>,
51 {
52 self.metrics.record_in_flight_delta(Transport::Grpc, 1);
53 let start = Instant::now();
54 let result = fut.await;
55 let duration_ms = start.elapsed().as_millis() as u64;
56 let status = match &result {
57 Ok(_) => 0u16,
58 Err(s) => s.code() as u16,
59 };
60 let path = format!("/solti.v1.SoltiApi/{}", method);
61 self.metrics
62 .record_request(Transport::Grpc, method, &path, status, duration_ms);
63 self.metrics.record_in_flight_delta(Transport::Grpc, -1);
64 result
65 }
66}
67
68pub fn build_grpc_server<H>(handler: Arc<H>) -> SoltiApiServer<SoltiApiService<H>>
84where
85 H: ApiHandler,
86{
87 build_grpc_server_with_metrics(handler, noop_api_metrics())
88}
89
90pub fn build_grpc_server_with_metrics<H>(
92 handler: Arc<H>,
93 metrics: ApiMetricsHandle,
94) -> SoltiApiServer<SoltiApiService<H>>
95where
96 H: ApiHandler,
97{
98 SoltiApiServer::new(SoltiApiService::new_with_metrics(handler, metrics))
99 .max_decoding_message_size(crate::MAX_REQUEST_BYTES)
100 .max_encoding_message_size(crate::MAX_REQUEST_BYTES)
101}
102
103#[tonic::async_trait]
104impl<H> SoltiApi for SoltiApiService<H>
105where
106 H: ApiHandler,
107{
108 async fn submit_task(
109 &self,
110 request: Request<proto_api::SubmitTaskRequest>,
111 ) -> Result<Response<proto_api::SubmitTaskResponse>, Status> {
112 self.instrument("SubmitTask", async move {
113 let req = request.into_inner();
114
115 let spec = req
116 .spec
117 .ok_or_else(|| Status::invalid_argument("missing spec"))?;
118
119 let spec =
120 crate::convert::convert_create_spec(spec).map_err(|e: ApiError| Status::from(e))?;
121
122 debug!(slot = %spec.slot(), kind = ?spec.kind(), "grpc: submitting task");
123 let task_id = self.handler.submit_task(spec).await.map_err(Status::from)?;
124
125 Ok(Response::new(proto_api::SubmitTaskResponse {
126 task_id: task_id.to_string(),
127 }))
128 })
129 .await
130 }
131
132 async fn get_task_status(
133 &self,
134 request: Request<proto_api::GetTaskStatusRequest>,
135 ) -> Result<Response<proto_api::GetTaskStatusResponse>, Status> {
136 self.instrument("GetTaskStatus", async move {
137 let req = request.into_inner();
138
139 non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
140
141 let task_id = solti_model::TaskId::from(req.task_id);
142 debug!(%task_id, "grpc: getting task status");
143
144 let info = self
145 .handler
146 .get_task_status(&task_id)
147 .await
148 .map_err(Status::from)?;
149
150 let task = info
151 .map(proto_api::TaskData::try_from)
152 .transpose()
153 .map_err(Status::from)?;
154
155 Ok(Response::new(proto_api::GetTaskStatusResponse { task }))
156 })
157 .await
158 }
159
160 async fn list_tasks(
161 &self,
162 request: Request<proto_api::ListTasksRequest>,
163 ) -> Result<Response<proto_api::ListTasksResponse>, Status> {
164 self.instrument("ListTasks", async move {
165 let req = request.into_inner();
166
167 let mut query = TaskQuery::new();
168
169 if let Some(slot) = req.slot {
170 non_empty_id("slot", &slot).map_err(Status::from)?;
171 query = query.with_slot(slot);
172 }
173
174 if let Some(status_raw) = req.status {
175 let status = proto_to_domain_status(status_raw).map_err(Status::from)?;
176 query = query.with_status(status);
177 }
178
179 query = query.with_limit(clamp_list_limit(req.limit));
180 if req.offset > 0 {
181 query = query.with_offset(req.offset as usize);
182 }
183
184 let page = self
185 .handler
186 .query_tasks(query)
187 .await
188 .map_err(Status::from)?;
189
190 debug!(
191 count = page.items.len(),
192 total = page.total,
193 "grpc: tasks listed"
194 );
195
196 let response = tasks_page_to_proto(page).map_err(Status::from)?;
197 Ok(Response::new(response))
198 })
199 .await
200 }
201
202 async fn list_task_runs(
203 &self,
204 request: Request<proto_api::ListTaskRunsRequest>,
205 ) -> Result<Response<proto_api::ListTaskRunsResponse>, Status> {
206 self.instrument("ListTaskRuns", async move {
207 let req = request.into_inner();
208
209 non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
210
211 let task_id = solti_model::TaskId::from(req.task_id);
212 debug!(%task_id, "grpc: listing task runs");
213
214 let runs = self
215 .handler
216 .list_task_runs(&task_id)
217 .await
218 .map_err(Status::from)?;
219
220 let runs = runs.into_iter().map(proto_api::TaskRunInfo::from).collect();
221
222 Ok(Response::new(proto_api::ListTaskRunsResponse { runs }))
223 })
224 .await
225 }
226
227 async fn delete_task(
228 &self,
229 request: Request<proto_api::DeleteTaskRequest>,
230 ) -> Result<Response<proto_api::DeleteTaskResponse>, Status> {
231 self.instrument("DeleteTask", async move {
232 let req = request.into_inner();
233
234 non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
235
236 let task_id = solti_model::TaskId::from(req.task_id);
237 debug!(%task_id, "grpc: deleting task");
238
239 self.handler
240 .delete_task(&task_id)
241 .await
242 .map_err(Status::from)?;
243
244 debug!(%task_id, "grpc: task deleted");
245 Ok(Response::new(proto_api::DeleteTaskResponse {}))
246 })
247 .await
248 }
249
250 type StreamTaskLogsStream = Pin<
252 Box<
253 dyn tokio_stream::Stream<Item = Result<proto_api::OutputEventProto, Status>>
254 + Send
255 + 'static,
256 >,
257 >;
258
259 async fn stream_task_logs(
260 &self,
261 request: Request<proto_api::StreamTaskLogsRequest>,
262 ) -> Result<Response<Self::StreamTaskLogsStream>, Status> {
263 let req = request.into_inner();
264 non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
265
266 let task_id = solti_model::TaskId::from(req.task_id);
267 debug!(%task_id, "grpc: subscribing to task log stream");
268
269 let domain_stream = self
270 .handler
271 .stream_task_logs(&task_id)
272 .await
273 .map_err(Status::from)?;
274
275 let proto_stream = domain_stream.map(|ev| Ok(output_event_to_proto(ev)));
276 Ok(Response::new(Box::pin(proto_stream)))
277 }
278}
279
280#[cfg(test)]
281mod tests {
282 use super::*;
283
284 use std::time::{Duration, UNIX_EPOCH};
285
286 use async_trait::async_trait;
287 use bytes::Bytes;
288 use solti_model::{
289 OutputChunk, OutputEvent, StreamKind as ModelStreamKind, Task, TaskId, TaskPage, TaskQuery,
290 TaskRun, TaskSpec,
291 };
292
293 use crate::error::ApiError;
294 use crate::handler::{ApiHandler, OutputEventStream};
295
296 struct StreamMock;
297
298 #[async_trait]
299 impl ApiHandler for StreamMock {
300 async fn submit_task(&self, _spec: TaskSpec) -> Result<TaskId, ApiError> {
301 unreachable!()
302 }
303 async fn get_task_status(&self, _id: &TaskId) -> Result<Option<Task>, ApiError> {
304 unreachable!()
305 }
306 async fn query_tasks(&self, _q: TaskQuery) -> Result<TaskPage<Task>, ApiError> {
307 unreachable!()
308 }
309 async fn list_task_runs(&self, _id: &TaskId) -> Result<Vec<TaskRun>, ApiError> {
310 unreachable!()
311 }
312 async fn delete_task(&self, _id: &TaskId) -> Result<(), ApiError> {
313 unreachable!()
314 }
315 async fn stream_task_logs(&self, id: &TaskId) -> Result<OutputEventStream, ApiError> {
316 if id.as_str() == "missing" {
317 return Err(ApiError::TaskNotFound(id.to_string()));
318 }
319 let events = vec![
320 OutputEvent::RunStarted {
321 attempt: 1,
322 started_at: UNIX_EPOCH + Duration::from_millis(1000),
323 },
324 OutputEvent::Chunk(OutputChunk {
325 attempt: 1,
326 stream: ModelStreamKind::Stdout,
327 seq: 0,
328 ts: UNIX_EPOCH + Duration::from_millis(1100),
329 line: Bytes::from_static(b"hello-grpc"),
330 }),
331 OutputEvent::RunFinished {
332 attempt: 1,
333 exit_code: Some(0),
334 finished_at: UNIX_EPOCH + Duration::from_millis(1500),
335 },
336 ];
337 Ok(Box::pin(tokio_stream::iter(events)))
338 }
339 }
340
341 fn service() -> SoltiApiService<StreamMock> {
342 SoltiApiService::new(Arc::new(StreamMock))
343 }
344
345 #[tokio::test]
346 async fn stream_task_logs_returns_three_proto_events_in_order() {
347 let svc = service();
348 let req = Request::new(proto_api::StreamTaskLogsRequest {
349 task_id: "tsk_1".into(),
350 });
351
352 let response = svc.stream_task_logs(req).await.expect("stream Ok");
353 let mut stream = response.into_inner();
354
355 match stream.next().await.unwrap().unwrap().kind.unwrap() {
356 proto_api::output_event_proto::Kind::RunStarted(r) => {
357 assert_eq!(r.attempt, 1);
358 assert_eq!(r.started_at, 1000);
359 }
360 other => panic!("expected RunStarted, got {other:?}"),
361 }
362
363 match stream.next().await.unwrap().unwrap().kind.unwrap() {
364 proto_api::output_event_proto::Kind::Chunk(c) => {
365 assert_eq!(c.attempt, 1);
366 assert_eq!(c.stream, proto_api::OutputStreamKind::Stdout as i32);
367 assert_eq!(c.seq, 0);
368 assert_eq!(&c.line[..], b"hello-grpc");
369 }
370 other => panic!("expected Chunk, got {other:?}"),
371 }
372
373 match stream.next().await.unwrap().unwrap().kind.unwrap() {
374 proto_api::output_event_proto::Kind::RunFinished(r) => {
375 assert_eq!(r.attempt, 1);
376 assert_eq!(r.exit_code, Some(0));
377 assert_eq!(r.finished_at, 1500);
378 }
379 other => panic!("expected RunFinished, got {other:?}"),
380 }
381 assert!(stream.next().await.is_none(), "stream must terminate");
382 }
383
384 #[tokio::test]
385 async fn stream_task_logs_rejects_empty_task_id() {
386 let svc = service();
387 let req = Request::new(proto_api::StreamTaskLogsRequest {
388 task_id: " ".into(),
389 });
390 let status = match svc.stream_task_logs(req).await {
391 Err(s) => s,
392 Ok(_) => panic!("expected error status"),
393 };
394 assert_eq!(status.code(), tonic::Code::InvalidArgument);
395 }
396
397 #[tokio::test]
398 async fn stream_task_logs_maps_task_not_found_to_not_found_status() {
399 let svc = service();
400 let req = Request::new(proto_api::StreamTaskLogsRequest {
401 task_id: "missing".into(),
402 });
403 let status = match svc.stream_task_logs(req).await {
404 Err(s) => s,
405 Ok(_) => panic!("expected error status"),
406 };
407 assert_eq!(status.code(), tonic::Code::NotFound);
408 }
409}