Skip to main content

stormchaser_api/routes/step.rs

1use crate::db;
2use crate::{AppState, AuthClaims};
3use axum::response::sse::Event;
4use axum::{
5    extract::{Path, Query, State},
6    http::StatusCode,
7    Json,
8};
9use futures::StreamExt;
10use serde::Deserialize;
11use std::collections::HashMap;
12use std::time::Duration;
13use stormchaser_model::step::StepStatus;
14use stormchaser_model::RunId;
15use stormchaser_model::RunStatus;
16use stormchaser_model::StepInstanceId;
17use tokio::sync::mpsc;
18use tokio::time::sleep;
19use utoipa::ToSchema;
20
21#[derive(Deserialize, ToSchema)]
22pub struct LogsQuery {
23    #[schema(example = 100)]
24    pub limit: Option<usize>,
25}
26
27/// Format log event.
28pub fn format_log_event(line: &str) -> Event {
29    Event::default().event("log").data(line)
30}
31
32/// Stream step logs api.
33#[utoipa::path(
34    get,
35    path = "/api/v1/runs/{run_id}/steps/{step_id}/logs/stream",
36    params(("run_id" = RunId, Path, description="Run ID"), ("step_id" = StepInstanceId, Path, description="Step instance ID")),
37    responses(
38        (status = 200, description = "Success"),
39        (status = 400, description = "Bad Request"),
40        (status = 404, description = "Not Found"),
41        (status = 500, description = "Internal Server Error")
42    ),
43    tag = "step"
44)]
45pub async fn stream_step_logs_api(
46    AuthClaims(_claims): AuthClaims,
47    State(state): State<AppState>,
48    Path((run_id, step_id)): Path<(RunId, StepInstanceId)>,
49) -> Result<
50    axum::response::sse::Sse<
51        impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
52    >,
53    StatusCode,
54> {
55    let log_backend = match &state.log_backend {
56        Some(backend) => backend,
57        None => return Err(StatusCode::NOT_IMPLEMENTED),
58    };
59
60    let instance = db::get_step_instance_by_id(&state.pool, run_id, step_id)
61        .await
62        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
63        .ok_or(StatusCode::NOT_FOUND)?;
64
65    let rx = log_backend
66        .stream_step_logs(&instance.step_name, step_id)
67        .await
68        .map_err(|e| {
69            tracing::error!("Failed to stream logs: {}", e);
70            StatusCode::INTERNAL_SERVER_ERROR
71        })?;
72
73    let stream = tokio_stream::wrappers::ReceiverStream::new(rx).flat_map(|res| match res {
74        Ok(log_line) => {
75            let events: Vec<_> = log_line
76                .lines()
77                .map(|line| Ok(format_log_event(line)))
78                .collect();
79            tokio_stream::iter(events)
80        }
81        Err(e) => tokio_stream::iter(vec![Ok(Event::default()
82            .event("error")
83            .data(e.to_string()))]),
84    });
85
86    Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
87}
88
89/// Get step logs api.
90#[utoipa::path(
91    get,
92    path = "/api/v1/runs/{run_id}/steps/{step_id}/logs",
93    params(
94        ("run_id" = RunId, Path, description="Run ID"),
95        ("step_id" = StepInstanceId, Path, description="Step instance ID"),
96        ("limit" = Option<usize>, Query, description="Limit log lines")
97    ),
98    responses(
99        (status = 200, description = "Success", body = Vec<String>),
100        (status = 400, description = "Bad Request"),
101        (status = 404, description = "Not Found"),
102        (status = 500, description = "Internal Server Error")
103    ),
104    tag = "step"
105)]
106pub async fn get_step_logs_api(
107    AuthClaims(_claims): AuthClaims,
108    State(state): State<AppState>,
109    Path((run_id, step_id)): Path<(RunId, StepInstanceId)>,
110    Query(query): Query<LogsQuery>,
111) -> Result<Json<Vec<String>>, StatusCode> {
112    let log_backend = match &state.log_backend {
113        Some(backend) => backend,
114        None => return Err(StatusCode::NOT_IMPLEMENTED),
115    };
116
117    let instance = db::get_step_instance_by_id(&state.pool, run_id, step_id)
118        .await
119        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
120        .ok_or(StatusCode::NOT_FOUND)?;
121
122    let logs = log_backend
123        .fetch_step_logs(
124            &instance.step_name,
125            step_id,
126            instance.started_at,
127            instance.finished_at,
128            query.limit,
129        )
130        .await
131        .map_err(|e| {
132            tracing::error!("Failed to fetch logs: {}", e);
133            StatusCode::INTERNAL_SERVER_ERROR
134        })?;
135
136    Ok(Json(logs))
137}
138
139/// Streams run logs.
140#[utoipa::path(
141    get,
142    path = "/api/v1/runs/{run_id}/logs/stream",
143    params(("run_id" = RunId, Path, description="Run ID")),
144    responses(
145        (status = 200, description = "Success"),
146        (status = 400, description = "Bad Request"),
147        (status = 404, description = "Not Found"),
148        (status = 500, description = "Internal Server Error")
149    ),
150    tag = "step"
151)]
152pub async fn stream_run_logs_api(
153    AuthClaims(_claims): AuthClaims,
154    State(state): State<AppState>,
155    Path(run_id): Path<RunId>,
156) -> Result<
157    axum::response::sse::Sse<
158        impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
159    >,
160    StatusCode,
161> {
162    let log_backend = match &state.log_backend {
163        Some(backend) => backend.clone(),
164        None => return Err(StatusCode::NOT_IMPLEMENTED),
165    };
166
167    let (tx, rx) = mpsc::channel(100);
168
169    let pool = state.pool.clone();
170    tokio::spawn(async move {
171        let mut seen_steps = std::collections::HashSet::new();
172        tracing::debug!("Started run log stream task for run {}", run_id);
173
174        loop {
175            let status = db::get_workflow_run_status(&pool, run_id)
176                .await
177                .unwrap_or(None);
178
179            let is_terminal = matches!(
180                status,
181                Some(RunStatus::Succeeded) | Some(RunStatus::Failed) | Some(RunStatus::Aborted)
182            );
183
184            let steps = db::get_step_names(&pool, run_id).await.unwrap_or_default();
185
186            if !steps.is_empty() {
187                tracing::trace!("Found {} steps for run {}", steps.len(), run_id);
188            }
189
190            for (step_id, step_name) in steps {
191                if !seen_steps.contains(&step_id) {
192                    seen_steps.insert(step_id);
193                    tracing::debug!(
194                        "Discovered new step {} ({}) for run log stream",
195                        step_name,
196                        step_id
197                    );
198                    let tx_clone = tx.clone();
199                    let step_name_clone = step_name.clone();
200                    let log_backend = log_backend.clone();
201
202                    tokio::spawn(async move {
203                        tracing::debug!(
204                            "Starting sub-task to stream logs for step {}",
205                            step_name_clone
206                        );
207                        if let Ok(mut step_rx) = log_backend
208                            .stream_step_logs(&step_name_clone, step_id)
209                            .await
210                        {
211                            tracing::debug!(
212                                "Successfully connected to log stream for step {}",
213                                step_name_clone
214                            );
215                            while let Some(log_res) = step_rx.recv().await {
216                                match log_res {
217                                    Ok(line) => {
218                                        // SSE data cannot contain newlines. Split and send multiple events.
219                                        for fragment in line.lines() {
220                                            let prefixed =
221                                                format!("[{}] {}", step_name_clone, fragment);
222                                            if tx_clone.send(Ok(prefixed)).await.is_err() {
223                                                return; // Receiver dropped
224                                            }
225                                        }
226                                    }
227                                    Err(e) => {
228                                        tracing::warn!(
229                                            "Error in log stream for step {}: {:?}",
230                                            step_name_clone,
231                                            e
232                                        );
233                                        let _ = tx_clone.send(Err(e)).await;
234                                        break;
235                                    }
236                                }
237                            }
238                            tracing::debug!("Log stream for step {} finished", step_name_clone);
239                        } else {
240                            tracing::error!(
241                                "Failed to connect to log stream for step {}",
242                                step_name_clone
243                            );
244                        }
245                    });
246                }
247            }
248
249            if is_terminal {
250                // Allow some time for final logs to flush
251                sleep(Duration::from_secs(5)).await;
252                break;
253            }
254
255            sleep(Duration::from_secs(2)).await;
256        }
257    });
258
259    let stream = tokio_stream::wrappers::ReceiverStream::new(rx).flat_map(|res| match res {
260        Ok(log_line) => {
261            let events: Vec<_> = log_line
262                .lines()
263                .map(|line| Ok(format_log_event(line)))
264                .collect();
265            tokio_stream::iter(events)
266        }
267        Err(e) => tokio_stream::iter(vec![Ok(Event::default()
268            .event("error")
269            .data(e.to_string()))]),
270    });
271
272    Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
273}
274
275#[utoipa::path(
276    get,
277    path = "/api/v1/runs/{run_id}/status/stream",
278    params(
279        ("run_id" = RunId, Path, description = "Run ID")
280    ),
281    responses(
282        (status = 200, description = "Status stream (SSE)")
283    ),
284    security(
285        ("bearer_auth" = [])
286    ),
287    tag = "step"
288)]
289/// Stream run status api.
290pub async fn stream_run_status_api(
291    AuthClaims(_claims): AuthClaims,
292    State(state): State<AppState>,
293    Path(run_id): Path<RunId>,
294) -> Result<
295    axum::response::sse::Sse<
296        impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
297    >,
298    StatusCode,
299> {
300    let (tx, rx) = mpsc::channel::<Result<Event, std::convert::Infallible>>(100);
301    let pool = state.pool.clone();
302
303    tokio::spawn(async move {
304        let mut last_run_status: Option<RunStatus> = None;
305        let mut last_step_statuses: std::collections::HashMap<StepInstanceId, StepStatus> =
306            HashMap::new();
307
308        loop {
309            // Check workflow run status
310            let current_run_status = db::get_combined_run_status(&pool, run_id)
311                .await
312                .unwrap_or(None);
313
314            let is_terminal = matches!(
315                current_run_status,
316                Some(RunStatus::Succeeded) | Some(RunStatus::Failed) | Some(RunStatus::Aborted)
317            );
318
319            if current_run_status != last_run_status {
320                if let Some(ref status) = current_run_status {
321                    let data = serde_json::json!({ "status": status }).to_string();
322                    let event = Event::default().event("run_status").data(data);
323                    if tx.send(Ok(event)).await.is_err() {
324                        break;
325                    }
326                }
327                last_run_status = current_run_status.clone();
328            }
329
330            // Check step instances statuses
331            let steps = db::get_combined_step_statuses(&pool, run_id)
332                .await
333                .unwrap_or_default();
334
335            for (step_id, step_name, status) in steps {
336                let should_emit = match last_step_statuses.get(&step_id) {
337                    Some(last_status) => last_status != &status,
338                    None => true,
339                };
340
341                if should_emit {
342                    let data = serde_json::json!({
343                        "step_id": step_id,
344                        "step_name": step_name,
345                        "status": status,
346                    })
347                    .to_string();
348                    let event = Event::default().event("step_status").data(data);
349                    if tx.send(Ok(event)).await.is_err() {
350                        return; // Receiver dropped, break out of spawn
351                    }
352                    last_step_statuses.insert(step_id, status);
353                }
354            }
355
356            if is_terminal {
357                // Allow some time for final steps to be flushed and updated
358                sleep(Duration::from_secs(2)).await;
359                break;
360            }
361
362            sleep(Duration::from_secs(1)).await;
363        }
364    });
365
366    let stream = tokio_stream::wrappers::ReceiverStream::new(rx).map(|res| match res {
367        Ok(event) => Ok(event),
368        Err(_) => unreachable!(),
369    });
370
371    Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a `log` event must not panic. Its `Debug` rendering should
    /// mention `log` and contain the text of the original line.
    #[test]
    fn test_format_log_event() {
        let rendered = format!(
            "{:?}",
            format_log_event("2026-04-23T19:17:03.318971Z INFO test log")
        );
        for needle in ["log", "test log"] {
            assert!(rendered.contains(needle));
        }
    }
}