1use crate::db;
2use crate::{AppState, AuthClaims};
3use axum::response::sse::Event;
4use axum::{
5 extract::{Path, Query, State},
6 http::StatusCode,
7 Json,
8};
9use futures::StreamExt;
10use serde::Deserialize;
11use std::collections::HashMap;
12use std::time::Duration;
13use stormchaser_model::step::StepStatus;
14use stormchaser_model::RunId;
15use stormchaser_model::RunStatus;
16use stormchaser_model::StepInstanceId;
17use tokio::sync::mpsc;
18use tokio::time::sleep;
19use utoipa::ToSchema;
20
21#[derive(Deserialize, ToSchema)]
22pub struct LogsQuery {
23 #[schema(example = 100)]
24 pub limit: Option<usize>,
25}
26
27pub fn format_log_event(line: &str) -> Event {
29 Event::default().event("log").data(line)
30}
31
32#[utoipa::path(
34 get,
35 path = "/api/v1/runs/{run_id}/steps/{step_id}/logs/stream",
36 params(("run_id" = RunId, Path, description="Run ID"), ("step_id" = StepInstanceId, Path, description="Step instance ID")),
37 responses(
38 (status = 200, description = "Success"),
39 (status = 400, description = "Bad Request"),
40 (status = 404, description = "Not Found"),
41 (status = 500, description = "Internal Server Error")
42 ),
43 tag = "step"
44)]
45pub async fn stream_step_logs_api(
46 AuthClaims(_claims): AuthClaims,
47 State(state): State<AppState>,
48 Path((run_id, step_id)): Path<(RunId, StepInstanceId)>,
49) -> Result<
50 axum::response::sse::Sse<
51 impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
52 >,
53 StatusCode,
54> {
55 let log_backend = match &state.log_backend {
56 Some(backend) => backend,
57 None => return Err(StatusCode::NOT_IMPLEMENTED),
58 };
59
60 let instance = db::get_step_instance_by_id(&state.pool, run_id, step_id)
61 .await
62 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
63 .ok_or(StatusCode::NOT_FOUND)?;
64
65 let rx = log_backend
66 .stream_step_logs(&instance.step_name, step_id)
67 .await
68 .map_err(|e| {
69 tracing::error!("Failed to stream logs: {}", e);
70 StatusCode::INTERNAL_SERVER_ERROR
71 })?;
72
73 let stream = tokio_stream::wrappers::ReceiverStream::new(rx).flat_map(|res| match res {
74 Ok(log_line) => {
75 let events: Vec<_> = log_line
76 .lines()
77 .map(|line| Ok(format_log_event(line)))
78 .collect();
79 tokio_stream::iter(events)
80 }
81 Err(e) => tokio_stream::iter(vec![Ok(Event::default()
82 .event("error")
83 .data(e.to_string()))]),
84 });
85
86 Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
87}
88
89#[utoipa::path(
91 get,
92 path = "/api/v1/runs/{run_id}/steps/{step_id}/logs",
93 params(
94 ("run_id" = RunId, Path, description="Run ID"),
95 ("step_id" = StepInstanceId, Path, description="Step instance ID"),
96 ("limit" = Option<usize>, Query, description="Limit log lines")
97 ),
98 responses(
99 (status = 200, description = "Success", body = Vec<String>),
100 (status = 400, description = "Bad Request"),
101 (status = 404, description = "Not Found"),
102 (status = 500, description = "Internal Server Error")
103 ),
104 tag = "step"
105)]
106pub async fn get_step_logs_api(
107 AuthClaims(_claims): AuthClaims,
108 State(state): State<AppState>,
109 Path((run_id, step_id)): Path<(RunId, StepInstanceId)>,
110 Query(query): Query<LogsQuery>,
111) -> Result<Json<Vec<String>>, StatusCode> {
112 let log_backend = match &state.log_backend {
113 Some(backend) => backend,
114 None => return Err(StatusCode::NOT_IMPLEMENTED),
115 };
116
117 let instance = db::get_step_instance_by_id(&state.pool, run_id, step_id)
118 .await
119 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
120 .ok_or(StatusCode::NOT_FOUND)?;
121
122 let logs = log_backend
123 .fetch_step_logs(
124 &instance.step_name,
125 step_id,
126 instance.started_at,
127 instance.finished_at,
128 query.limit,
129 )
130 .await
131 .map_err(|e| {
132 tracing::error!("Failed to fetch logs: {}", e);
133 StatusCode::INTERNAL_SERVER_ERROR
134 })?;
135
136 Ok(Json(logs))
137}
138
139#[utoipa::path(
141 get,
142 path = "/api/v1/runs/{run_id}/logs/stream",
143 params(("run_id" = RunId, Path, description="Run ID")),
144 responses(
145 (status = 200, description = "Success"),
146 (status = 400, description = "Bad Request"),
147 (status = 404, description = "Not Found"),
148 (status = 500, description = "Internal Server Error")
149 ),
150 tag = "step"
151)]
152pub async fn stream_run_logs_api(
153 AuthClaims(_claims): AuthClaims,
154 State(state): State<AppState>,
155 Path(run_id): Path<RunId>,
156) -> Result<
157 axum::response::sse::Sse<
158 impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
159 >,
160 StatusCode,
161> {
162 let log_backend = match &state.log_backend {
163 Some(backend) => backend.clone(),
164 None => return Err(StatusCode::NOT_IMPLEMENTED),
165 };
166
167 let (tx, rx) = mpsc::channel(100);
168
169 let pool = state.pool.clone();
170 tokio::spawn(async move {
171 let mut seen_steps = std::collections::HashSet::new();
172 tracing::debug!("Started run log stream task for run {}", run_id);
173
174 loop {
175 let status = db::get_workflow_run_status(&pool, run_id)
176 .await
177 .unwrap_or(None);
178
179 let is_terminal = matches!(
180 status,
181 Some(RunStatus::Succeeded) | Some(RunStatus::Failed) | Some(RunStatus::Aborted)
182 );
183
184 let steps = db::get_step_names(&pool, run_id).await.unwrap_or_default();
185
186 if !steps.is_empty() {
187 tracing::trace!("Found {} steps for run {}", steps.len(), run_id);
188 }
189
190 for (step_id, step_name) in steps {
191 if !seen_steps.contains(&step_id) {
192 seen_steps.insert(step_id);
193 tracing::debug!(
194 "Discovered new step {} ({}) for run log stream",
195 step_name,
196 step_id
197 );
198 let tx_clone = tx.clone();
199 let step_name_clone = step_name.clone();
200 let log_backend = log_backend.clone();
201
202 tokio::spawn(async move {
203 tracing::debug!(
204 "Starting sub-task to stream logs for step {}",
205 step_name_clone
206 );
207 if let Ok(mut step_rx) = log_backend
208 .stream_step_logs(&step_name_clone, step_id)
209 .await
210 {
211 tracing::debug!(
212 "Successfully connected to log stream for step {}",
213 step_name_clone
214 );
215 while let Some(log_res) = step_rx.recv().await {
216 match log_res {
217 Ok(line) => {
218 for fragment in line.lines() {
220 let prefixed =
221 format!("[{}] {}", step_name_clone, fragment);
222 if tx_clone.send(Ok(prefixed)).await.is_err() {
223 return; }
225 }
226 }
227 Err(e) => {
228 tracing::warn!(
229 "Error in log stream for step {}: {:?}",
230 step_name_clone,
231 e
232 );
233 let _ = tx_clone.send(Err(e)).await;
234 break;
235 }
236 }
237 }
238 tracing::debug!("Log stream for step {} finished", step_name_clone);
239 } else {
240 tracing::error!(
241 "Failed to connect to log stream for step {}",
242 step_name_clone
243 );
244 }
245 });
246 }
247 }
248
249 if is_terminal {
250 sleep(Duration::from_secs(5)).await;
252 break;
253 }
254
255 sleep(Duration::from_secs(2)).await;
256 }
257 });
258
259 let stream = tokio_stream::wrappers::ReceiverStream::new(rx).flat_map(|res| match res {
260 Ok(log_line) => {
261 let events: Vec<_> = log_line
262 .lines()
263 .map(|line| Ok(format_log_event(line)))
264 .collect();
265 tokio_stream::iter(events)
266 }
267 Err(e) => tokio_stream::iter(vec![Ok(Event::default()
268 .event("error")
269 .data(e.to_string()))]),
270 });
271
272 Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
273}
274
275#[utoipa::path(
276 get,
277 path = "/api/v1/runs/{run_id}/status/stream",
278 params(
279 ("run_id" = RunId, Path, description = "Run ID")
280 ),
281 responses(
282 (status = 200, description = "Status stream (SSE)")
283 ),
284 security(
285 ("bearer_auth" = [])
286 ),
287 tag = "step"
288)]
289pub async fn stream_run_status_api(
291 AuthClaims(_claims): AuthClaims,
292 State(state): State<AppState>,
293 Path(run_id): Path<RunId>,
294) -> Result<
295 axum::response::sse::Sse<
296 impl futures::stream::Stream<Item = Result<Event, std::convert::Infallible>>,
297 >,
298 StatusCode,
299> {
300 let (tx, rx) = mpsc::channel::<Result<Event, std::convert::Infallible>>(100);
301 let pool = state.pool.clone();
302
303 tokio::spawn(async move {
304 let mut last_run_status: Option<RunStatus> = None;
305 let mut last_step_statuses: std::collections::HashMap<StepInstanceId, StepStatus> =
306 HashMap::new();
307
308 loop {
309 let current_run_status = db::get_combined_run_status(&pool, run_id)
311 .await
312 .unwrap_or(None);
313
314 let is_terminal = matches!(
315 current_run_status,
316 Some(RunStatus::Succeeded) | Some(RunStatus::Failed) | Some(RunStatus::Aborted)
317 );
318
319 if current_run_status != last_run_status {
320 if let Some(ref status) = current_run_status {
321 let data = serde_json::json!({ "status": status }).to_string();
322 let event = Event::default().event("run_status").data(data);
323 if tx.send(Ok(event)).await.is_err() {
324 break;
325 }
326 }
327 last_run_status = current_run_status.clone();
328 }
329
330 let steps = db::get_combined_step_statuses(&pool, run_id)
332 .await
333 .unwrap_or_default();
334
335 for (step_id, step_name, status) in steps {
336 let should_emit = match last_step_statuses.get(&step_id) {
337 Some(last_status) => last_status != &status,
338 None => true,
339 };
340
341 if should_emit {
342 let data = serde_json::json!({
343 "step_id": step_id,
344 "step_name": step_name,
345 "status": status,
346 })
347 .to_string();
348 let event = Event::default().event("step_status").data(data);
349 if tx.send(Ok(event)).await.is_err() {
350 return; }
352 last_step_statuses.insert(step_id, status);
353 }
354 }
355
356 if is_terminal {
357 sleep(Duration::from_secs(2)).await;
359 break;
360 }
361
362 sleep(Duration::from_secs(1)).await;
363 }
364 });
365
366 let stream = tokio_stream::wrappers::ReceiverStream::new(rx).map(|res| match res {
367 Ok(event) => Ok(event),
368 Err(_) => unreachable!(),
369 });
370
371 Ok(axum::response::sse::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    // The event produced for a plain log line should name the `log` event type
    // and contain the line's text.
    #[test]
    fn test_format_log_event() {
        let rendered = format!(
            "{:?}",
            format_log_event("2026-04-23T19:17:03.318971Z INFO test log")
        );
        for needle in ["log", "test log"] {
            assert!(
                rendered.contains(needle),
                "expected {:?} in {}",
                needle,
                rendered
            );
        }
    }
}