Skip to main content

stormchaser_api/routes/
step.rs

1use crate::db;
2use crate::{AppState, AuthClaims};
3use axum::response::sse::Event;
4use axum::{
5    extract::{Path, Query, State},
6    http::StatusCode,
7    Json,
8};
9use futures::StreamExt;
10use serde::Deserialize;
11use std::collections::HashMap;
12use std::time::Duration;
13use stormchaser_model::RunId;
14use stormchaser_model::StepInstanceId;
15use tokio::sync::mpsc;
16use tokio::time::sleep;
17use utoipa::ToSchema;
18
19#[derive(Deserialize, ToSchema)]
20pub struct LogsQuery {
21    #[schema(example = 100)]
22    pub limit: Option<usize>,
23}
24
25/// Format log event.
26pub fn format_log_event(line: &str) -> Event {
27    Event::default().event("log").data(line)
28}
29
30/// Stream step logs api.
31#[utoipa::path(
32    get,
33    path = "/api/v1/runs/{run_id}/steps/{step_id}/logs/stream",
34    params(("run_id" = stormchaser_model::RunId, Path, description="Run ID"), ("step_id" = stormchaser_model::StepInstanceId, Path, description="Step instance ID")),
35    responses(
36        (status = 200, description = "Success"),
37        (status = 400, description = "Bad Request"),
38        (status = 404, description = "Not Found"),
39        (status = 500, description = "Internal Server Error")
40    ),
41    tag = "step"
42)]
43pub async fn stream_step_logs_api(
44    AuthClaims(_claims): AuthClaims,
45    State(state): State<AppState>,
46    Path((run_id, step_id)): Path<(RunId, StepInstanceId)>,
47) -> Result<
48    axum::response::sse::Sse<
49        impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
50    >,
51    StatusCode,
52> {
53    let log_backend = match &state.log_backend {
54        Some(backend) => backend,
55        None => return Err(StatusCode::NOT_IMPLEMENTED),
56    };
57
58    let instance = db::get_step_instance_by_id(&state.pool, run_id, step_id)
59        .await
60        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
61        .ok_or(StatusCode::NOT_FOUND)?;
62
63    let rx = log_backend
64        .stream_step_logs(&instance.step_name, step_id)
65        .await
66        .map_err(|e| {
67            tracing::error!("Failed to stream logs: {}", e);
68            StatusCode::INTERNAL_SERVER_ERROR
69        })?;
70
71    let stream = tokio_stream::wrappers::ReceiverStream::new(rx).flat_map(|res| match res {
72        Ok(log_line) => {
73            let events: Vec<_> = log_line
74                .lines()
75                .map(|line| Ok(format_log_event(line)))
76                .collect();
77            tokio_stream::iter(events)
78        }
79        Err(e) => tokio_stream::iter(vec![Ok(Event::default()
80            .event("error")
81            .data(e.to_string()))]),
82    });
83
84    Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
85}
86
87/// Get step logs api.
88#[utoipa::path(
89    get,
90    path = "/api/v1/runs/{run_id}/steps/{step_id}/logs",
91    params(
92        ("run_id" = stormchaser_model::RunId, Path, description="Run ID"),
93        ("step_id" = stormchaser_model::StepInstanceId, Path, description="Step instance ID"),
94        ("limit" = Option<usize>, Query, description="Limit log lines")
95    ),
96    responses(
97        (status = 200, description = "Success", body = Vec<String>),
98        (status = 400, description = "Bad Request"),
99        (status = 404, description = "Not Found"),
100        (status = 500, description = "Internal Server Error")
101    ),
102    tag = "step"
103)]
104pub async fn get_step_logs_api(
105    AuthClaims(_claims): AuthClaims,
106    State(state): State<AppState>,
107    Path((run_id, step_id)): Path<(RunId, StepInstanceId)>,
108    Query(query): Query<LogsQuery>,
109) -> Result<Json<Vec<String>>, StatusCode> {
110    let log_backend = match &state.log_backend {
111        Some(backend) => backend,
112        None => return Err(StatusCode::NOT_IMPLEMENTED),
113    };
114
115    let instance = db::get_step_instance_by_id(&state.pool, run_id, step_id)
116        .await
117        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
118        .ok_or(StatusCode::NOT_FOUND)?;
119
120    let logs = log_backend
121        .fetch_step_logs(
122            &instance.step_name,
123            step_id,
124            instance.started_at,
125            instance.finished_at,
126            query.limit,
127        )
128        .await
129        .map_err(|e| {
130            tracing::error!("Failed to fetch logs: {}", e);
131            StatusCode::INTERNAL_SERVER_ERROR
132        })?;
133
134    Ok(Json(logs))
135}
136
137/// Streams run logs.
138#[utoipa::path(
139    get,
140    path = "/api/v1/runs/{run_id}/logs/stream",
141    params(("run_id" = stormchaser_model::RunId, Path, description="Run ID")),
142    responses(
143        (status = 200, description = "Success"),
144        (status = 400, description = "Bad Request"),
145        (status = 404, description = "Not Found"),
146        (status = 500, description = "Internal Server Error")
147    ),
148    tag = "step"
149)]
150pub async fn stream_run_logs_api(
151    AuthClaims(_claims): AuthClaims,
152    State(state): State<AppState>,
153    Path(run_id): Path<RunId>,
154) -> Result<
155    axum::response::sse::Sse<
156        impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
157    >,
158    StatusCode,
159> {
160    let log_backend = match &state.log_backend {
161        Some(backend) => backend.clone(),
162        None => return Err(StatusCode::NOT_IMPLEMENTED),
163    };
164
165    let (tx, rx) = mpsc::channel(100);
166
167    let pool = state.pool.clone();
168    tokio::spawn(async move {
169        let mut seen_steps = std::collections::HashSet::new();
170        tracing::debug!("Started run log stream task for run {}", run_id);
171
172        loop {
173            let status = db::get_workflow_run_status(&pool, run_id)
174                .await
175                .unwrap_or(None);
176
177            let is_terminal = matches!(
178                status.as_deref(),
179                Some("succeeded") | Some("failed") | Some("cancelled")
180            );
181
182            let steps = db::get_step_names(&pool, run_id).await.unwrap_or_default();
183
184            if !steps.is_empty() {
185                tracing::trace!("Found {} steps for run {}", steps.len(), run_id);
186            }
187
188            for (step_id, step_name) in steps {
189                if !seen_steps.contains(&step_id) {
190                    seen_steps.insert(step_id);
191                    tracing::debug!(
192                        "Discovered new step {} ({}) for run log stream",
193                        step_name,
194                        step_id
195                    );
196                    let tx_clone = tx.clone();
197                    let step_name_clone = step_name.clone();
198                    let log_backend = log_backend.clone();
199
200                    tokio::spawn(async move {
201                        tracing::debug!(
202                            "Starting sub-task to stream logs for step {}",
203                            step_name_clone
204                        );
205                        if let Ok(mut step_rx) = log_backend
206                            .stream_step_logs(&step_name_clone, step_id)
207                            .await
208                        {
209                            tracing::debug!(
210                                "Successfully connected to log stream for step {}",
211                                step_name_clone
212                            );
213                            while let Some(log_res) = step_rx.recv().await {
214                                match log_res {
215                                    Ok(line) => {
216                                        // SSE data cannot contain newlines. Split and send multiple events.
217                                        for fragment in line.lines() {
218                                            let prefixed =
219                                                format!("[{}] {}", step_name_clone, fragment);
220                                            if tx_clone.send(Ok(prefixed)).await.is_err() {
221                                                return; // Receiver dropped
222                                            }
223                                        }
224                                    }
225                                    Err(e) => {
226                                        tracing::warn!(
227                                            "Error in log stream for step {}: {:?}",
228                                            step_name_clone,
229                                            e
230                                        );
231                                        let _ = tx_clone.send(Err(e)).await;
232                                        break;
233                                    }
234                                }
235                            }
236                            tracing::debug!("Log stream for step {} finished", step_name_clone);
237                        } else {
238                            tracing::error!(
239                                "Failed to connect to log stream for step {}",
240                                step_name_clone
241                            );
242                        }
243                    });
244                }
245            }
246
247            if is_terminal {
248                // Allow some time for final logs to flush
249                sleep(Duration::from_secs(5)).await;
250                break;
251            }
252
253            sleep(Duration::from_secs(2)).await;
254        }
255    });
256
257    let stream = tokio_stream::wrappers::ReceiverStream::new(rx).flat_map(|res| match res {
258        Ok(log_line) => {
259            let events: Vec<_> = log_line
260                .lines()
261                .map(|line| Ok(format_log_event(line)))
262                .collect();
263            tokio_stream::iter(events)
264        }
265        Err(e) => tokio_stream::iter(vec![Ok(Event::default()
266            .event("error")
267            .data(e.to_string()))]),
268    });
269
270    Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
271}
272
273#[utoipa::path(
274    get,
275    path = "/api/v1/runs/{run_id}/status/stream",
276    params(
277        ("run_id" = stormchaser_model::RunId, Path, description = "Run ID")
278    ),
279    responses(
280        (status = 200, description = "Status stream (SSE)")
281    ),
282    security(
283        ("bearer_auth" = [])
284    ),
285    tag = "step"
286)]
287/// Stream run status api.
288pub async fn stream_run_status_api(
289    AuthClaims(_claims): AuthClaims,
290    State(state): State<AppState>,
291    Path(run_id): Path<RunId>,
292) -> Result<
293    axum::response::sse::Sse<
294        impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
295    >,
296    StatusCode,
297> {
298    let (tx, rx) = mpsc::channel::<Result<Event, std::convert::Infallible>>(100);
299    let pool = state.pool.clone();
300
301    tokio::spawn(async move {
302        let mut last_run_status: Option<String> = None;
303        let mut last_step_statuses: std::collections::HashMap<
304            stormchaser_model::StepInstanceId,
305            String,
306        > = HashMap::new();
307
308        loop {
309            // Check workflow run status
310            let current_run_status = db::get_combined_run_status(&pool, run_id)
311                .await
312                .unwrap_or(None);
313
314            let is_terminal = matches!(
315                current_run_status.as_deref(),
316                Some("succeeded") | Some("failed") | Some("cancelled")
317            );
318
319            if current_run_status != last_run_status {
320                if let Some(ref status) = current_run_status {
321                    let data = serde_json::json!({ "status": status }).to_string();
322                    let event = Event::default().event("run_status").data(data);
323                    if tx.send(Ok(event)).await.is_err() {
324                        break;
325                    }
326                }
327                last_run_status = current_run_status.clone();
328            }
329
330            // Check step instances statuses
331            let steps = db::get_combined_step_statuses(&pool, run_id)
332                .await
333                .unwrap_or_default();
334
335            for (step_id, step_name, status) in steps {
336                let should_emit = match last_step_statuses.get(&step_id) {
337                    Some(last_status) => last_status != &status,
338                    None => true,
339                };
340
341                if should_emit {
342                    let data = serde_json::json!({
343                        "step_id": step_id,
344                        "step_name": step_name,
345                        "status": status,
346                    })
347                    .to_string();
348                    let event = Event::default().event("step_status").data(data);
349                    if tx.send(Ok(event)).await.is_err() {
350                        return; // Receiver dropped, break out of spawn
351                    }
352                    last_step_statuses.insert(step_id, status);
353                }
354            }
355
356            if is_terminal {
357                // Allow some time for final steps to be flushed and updated
358                sleep(Duration::from_secs(2)).await;
359                break;
360            }
361
362            sleep(Duration::from_secs(1)).await;
363        }
364    });
365
366    let stream = tokio_stream::wrappers::ReceiverStream::new(rx).map(|res| match res {
367        Ok(event) => Ok(event),
368        Err(_) => unreachable!(),
369    });
370
371    Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a log event must not fail, and its `Debug` rendering must
    /// expose both the event name and the log text.
    #[test]
    fn test_format_log_event() {
        let event = format_log_event("2026-04-23T19:17:03.318971Z INFO test log");
        // `Debug` shows the event's internal fields, which is enough to check
        // that name and payload were recorded.
        let rendered = format!("{:?}", event);
        for expected in ["log", "test log"] {
            assert!(rendered.contains(expected));
        }
    }
}