1use std::sync::Arc;
2
3use async_trait::async_trait;
4use redis::AsyncCommands;
5
6use rustvello_core::error::{RustvelloError, RustvelloResult};
7use rustvello_core::state_backend::{
8 StateBackendCore, StateBackendQuery, StateBackendRunner, StoredRunnerContext,
9};
10use rustvello_proto::call::CallDTO;
11use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
12use rustvello_proto::invocation::{InvocationDTO, InvocationHistory, WorkflowIdentity};
13
14use crate::connection::{redis_err, scan_keys, RedisPool};
15use rustvello_core::error::TaskError;
16
/// Builds a Redis key by appending `suffix` directly after `prefix`.
///
/// No separator is inserted; callers pass prefixes that already end with
/// the delimiter they want.
fn prefixed_key(prefix: &str, suffix: &str) -> String {
    [prefix, suffix].concat()
}
23
24#[non_exhaustive]
29pub struct RedisStateBackend {
30 pool: Arc<RedisPool>,
31 inv_prefix: String,
32 call_prefix: String,
33 result_prefix: String,
34 error_prefix: String,
35 history_prefix: String,
36 wf_prefix: String,
37 child_prefix: String,
38 wf_types_key: String,
39 wf_runs_prefix: String,
40 wf_data_prefix: String,
41 app_infos_key: String,
42 wf_sub_prefix: String,
43 runner_prefix: String,
44 runner_inv_prefix: String,
45 history_ts_key: String,
46}
47
48impl RedisStateBackend {
49 pub fn new(pool: Arc<RedisPool>) -> Self {
50 let p = pool.prefix();
51 Self {
52 inv_prefix: format!("{p}state:inv:"),
53 call_prefix: format!("{p}state:call:"),
54 result_prefix: format!("{p}state:result:"),
55 error_prefix: format!("{p}state:error:"),
56 history_prefix: format!("{p}state:history:"),
57 wf_prefix: format!("{p}state:wf:"),
58 child_prefix: format!("{p}state:child:"),
59 wf_types_key: format!("{p}state:wf_types"),
60 wf_runs_prefix: format!("{p}state:wf_runs:"),
61 wf_data_prefix: format!("{p}state:wf_data:"),
62 app_infos_key: format!("{p}state:app_infos"),
63 wf_sub_prefix: format!("{p}state:wf_sub:"),
64 runner_prefix: format!("{p}state:runner:"),
65 runner_inv_prefix: format!("{p}state:runner_inv:"),
66 history_ts_key: format!("{p}state:history_ts"),
67 pool,
68 }
69 }
70}
71
72#[async_trait]
73impl StateBackendCore for RedisStateBackend {
74 async fn upsert_invocation(
75 &self,
76 invocation: &InvocationDTO,
77 call: &CallDTO,
78 ) -> RustvelloResult<()> {
79 let mut conn = self.pool.conn().await?;
80 let inv_json =
81 serde_json::to_string(invocation).map_err(|e| RustvelloError::Serialization {
82 message: e.to_string(),
83 })?;
84 let call_json = serde_json::to_string(call).map_err(|e| RustvelloError::Serialization {
85 message: e.to_string(),
86 })?;
87
88 conn.set::<_, _, ()>(
89 prefixed_key(&self.inv_prefix, invocation.invocation_id.as_ref()),
90 &inv_json,
91 )
92 .await
93 .map_err(redis_err)?;
94
95 conn.set::<_, _, ()>(
96 prefixed_key(&self.call_prefix, &call.call_id.to_string()),
97 &call_json,
98 )
99 .await
100 .map_err(redis_err)?;
101
102 if let Some(wf) = &invocation.workflow {
104 conn.sadd::<_, _, ()>(
105 prefixed_key(&self.wf_prefix, wf.workflow_id.as_ref()),
106 invocation.invocation_id.as_str(),
107 )
108 .await
109 .map_err(redis_err)?;
110 }
111
112 if let Some(parent_id) = &invocation.parent_invocation_id {
114 conn.sadd::<_, _, ()>(
115 prefixed_key(&self.child_prefix, parent_id.as_ref()),
116 invocation.invocation_id.as_str(),
117 )
118 .await
119 .map_err(redis_err)?;
120 }
121
122 Ok(())
123 }
124
125 async fn get_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<InvocationDTO> {
126 let mut conn = self.pool.conn().await?;
127 let val: Option<String> = conn
128 .get(prefixed_key(&self.inv_prefix, invocation_id.as_ref()))
129 .await
130 .map_err(redis_err)?;
131 match val {
132 Some(s) => serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
133 message: e.to_string(),
134 }),
135 None => Err(RustvelloError::InvocationNotFound {
136 invocation_id: invocation_id.clone(),
137 }),
138 }
139 }
140
141 async fn get_call(&self, call_id: &CallId) -> RustvelloResult<CallDTO> {
142 let mut conn = self.pool.conn().await?;
143 let val: Option<String> = conn
144 .get(prefixed_key(&self.call_prefix, &call_id.to_string()))
145 .await
146 .map_err(redis_err)?;
147 match val {
148 Some(s) => serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
149 message: e.to_string(),
150 }),
151 None => Err(RustvelloError::state_backend(format!(
152 "call not found: {}",
153 call_id
154 ))),
155 }
156 }
157
158 async fn store_result(
159 &self,
160 invocation_id: &InvocationId,
161 result: &str,
162 ) -> RustvelloResult<()> {
163 let mut conn = self.pool.conn().await?;
164 conn.set::<_, _, ()>(
165 prefixed_key(&self.result_prefix, invocation_id.as_ref()),
166 result,
167 )
168 .await
169 .map_err(redis_err)
170 }
171
172 async fn get_result(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<String>> {
173 let mut conn = self.pool.conn().await?;
174 conn.get(prefixed_key(&self.result_prefix, invocation_id.as_ref()))
175 .await
176 .map_err(redis_err)
177 }
178
179 async fn store_error(
180 &self,
181 invocation_id: &InvocationId,
182 error: &TaskError,
183 ) -> RustvelloResult<()> {
184 let mut conn = self.pool.conn().await?;
185 let json = serde_json::to_string(error).map_err(|e| RustvelloError::Serialization {
186 message: e.to_string(),
187 })?;
188 conn.set::<_, _, ()>(
189 prefixed_key(&self.error_prefix, invocation_id.as_ref()),
190 &json,
191 )
192 .await
193 .map_err(redis_err)
194 }
195
196 async fn get_error(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<TaskError>> {
197 let mut conn = self.pool.conn().await?;
198 let val: Option<String> = conn
199 .get(prefixed_key(&self.error_prefix, invocation_id.as_ref()))
200 .await
201 .map_err(redis_err)?;
202 match val {
203 Some(s) => {
204 let err: TaskError =
205 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
206 message: e.to_string(),
207 })?;
208 Ok(Some(err))
209 }
210 None => Ok(None),
211 }
212 }
213
214 async fn add_history(&self, history: &InvocationHistory) -> RustvelloResult<()> {
215 let mut conn = self.pool.conn().await?;
216 let json = serde_json::to_string(history).map_err(|e| RustvelloError::Serialization {
217 message: e.to_string(),
218 })?;
219 conn.rpush::<_, _, ()>(
221 prefixed_key(&self.history_prefix, history.invocation_id.as_ref()),
222 &json,
223 )
224 .await
225 .map_err(redis_err)?;
226 let ts = history
228 .history_timestamp
229 .unwrap_or(history.status_record.timestamp);
230 conn.zadd::<_, _, _, ()>(&self.history_ts_key, &json, ts.timestamp_millis() as f64)
231 .await
232 .map_err(redis_err)?;
233 let rid = history
235 .runner_id
236 .as_ref()
237 .or(history.status_record.runner_id.as_ref());
238 if let Some(r) = rid {
239 conn.sadd::<_, _, ()>(
240 prefixed_key(&self.runner_inv_prefix, r.as_str()),
241 history.invocation_id.as_str(),
242 )
243 .await
244 .map_err(redis_err)?;
245 }
246 Ok(())
247 }
248
249 async fn get_history(
250 &self,
251 invocation_id: &InvocationId,
252 ) -> RustvelloResult<Vec<InvocationHistory>> {
253 let mut conn = self.pool.conn().await?;
254 let vals: Vec<String> = conn
255 .lrange(
256 prefixed_key(&self.history_prefix, invocation_id.as_ref()),
257 0,
258 -1,
259 )
260 .await
261 .map_err(redis_err)?;
262 vals.into_iter()
263 .map(|s| {
264 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
265 message: e.to_string(),
266 })
267 })
268 .collect()
269 }
270
271 async fn purge(&self) -> RustvelloResult<()> {
272 let prefixes = [
273 &self.inv_prefix,
274 &self.call_prefix,
275 &self.result_prefix,
276 &self.error_prefix,
277 &self.history_prefix,
278 &self.wf_prefix,
279 &self.child_prefix,
280 &self.wf_runs_prefix,
281 &self.wf_data_prefix,
282 &self.wf_sub_prefix,
283 &self.runner_prefix,
284 &self.runner_inv_prefix,
285 ];
286 let mut conn = self.pool.conn().await?;
287 for prefix in prefixes {
288 let keys = scan_keys(&mut conn, &format!("{}*", prefix)).await?;
289 if !keys.is_empty() {
290 conn.del::<_, ()>(keys).await.map_err(redis_err)?;
291 }
292 }
293 for key in [
294 &self.wf_types_key,
295 &self.app_infos_key,
296 &self.history_ts_key,
297 ] {
298 conn.del::<_, ()>(key).await.map_err(redis_err)?;
299 }
300 Ok(())
301 }
302}
303
304#[async_trait]
305impl StateBackendQuery for RedisStateBackend {
306 async fn get_workflow_invocations(
307 &self,
308 workflow_id: &InvocationId,
309 ) -> RustvelloResult<Vec<InvocationId>> {
310 let mut conn = self.pool.conn().await?;
311 let members: Vec<String> = conn
312 .smembers(prefixed_key(&self.wf_prefix, workflow_id.as_ref()))
313 .await
314 .map_err(redis_err)?;
315 Ok(members.into_iter().map(InvocationId::from_string).collect())
316 }
317
318 async fn get_child_invocations(
319 &self,
320 parent_invocation_id: &InvocationId,
321 ) -> RustvelloResult<Vec<InvocationId>> {
322 let mut conn = self.pool.conn().await?;
323 let members: Vec<String> = conn
324 .smembers(format!("{}{}", &self.child_prefix, parent_invocation_id))
325 .await
326 .map_err(redis_err)?;
327 Ok(members.into_iter().map(InvocationId::from_string).collect())
328 }
329
330 async fn store_workflow_run(&self, workflow: &WorkflowIdentity) -> RustvelloResult<()> {
331 let mut conn = self.pool.conn().await?;
332 let type_key = workflow.workflow_type.to_string();
333 conn.sadd::<_, _, ()>(&self.wf_types_key, &type_key)
335 .await
336 .map_err(redis_err)?;
337 let json = serde_json::to_string(workflow).map_err(|e| RustvelloError::Serialization {
339 message: e.to_string(),
340 })?;
341 conn.hset::<_, _, _, ()>(
342 prefixed_key(&self.wf_runs_prefix, &type_key),
343 workflow.workflow_id.as_str(),
344 &json,
345 )
346 .await
347 .map_err(redis_err)?;
348 Ok(())
349 }
350
351 async fn get_all_workflow_types(&self) -> RustvelloResult<Vec<TaskId>> {
352 let mut conn = self.pool.conn().await?;
353 let members: Vec<String> = conn.smembers(&self.wf_types_key).await.map_err(redis_err)?;
354 members
355 .into_iter()
356 .map(|s| {
357 s.parse::<TaskId>()
358 .map_err(|e| RustvelloError::state_backend(format!("invalid task_id: {e}")))
359 })
360 .collect()
361 }
362
363 async fn get_workflow_runs(
364 &self,
365 workflow_type: &TaskId,
366 ) -> RustvelloResult<Vec<WorkflowIdentity>> {
367 let mut conn = self.pool.conn().await?;
368 let vals: Vec<String> = conn
369 .hvals(prefixed_key(
370 &self.wf_runs_prefix,
371 &workflow_type.to_string(),
372 ))
373 .await
374 .map_err(redis_err)?;
375 vals.into_iter()
376 .map(|s| {
377 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
378 message: e.to_string(),
379 })
380 })
381 .collect()
382 }
383
384 async fn set_workflow_data(
385 &self,
386 workflow_id: &InvocationId,
387 key: &str,
388 value: &str,
389 ) -> RustvelloResult<()> {
390 let mut conn = self.pool.conn().await?;
391 conn.hset::<_, _, _, ()>(
392 prefixed_key(&self.wf_data_prefix, workflow_id.as_ref()),
393 key,
394 value,
395 )
396 .await
397 .map_err(redis_err)?;
398 Ok(())
399 }
400
401 async fn get_workflow_data(
402 &self,
403 workflow_id: &InvocationId,
404 key: &str,
405 ) -> RustvelloResult<Option<String>> {
406 let mut conn = self.pool.conn().await?;
407 conn.hget(
408 prefixed_key(&self.wf_data_prefix, workflow_id.as_ref()),
409 key,
410 )
411 .await
412 .map_err(redis_err)
413 }
414
415 async fn store_app_info(&self, app_id: &str, info_json: &str) -> RustvelloResult<()> {
416 let mut conn = self.pool.conn().await?;
417 conn.hset::<_, _, _, ()>(&self.app_infos_key, app_id, info_json)
418 .await
419 .map_err(redis_err)?;
420 Ok(())
421 }
422
423 async fn get_app_info(&self, app_id: &str) -> RustvelloResult<Option<String>> {
424 let mut conn = self.pool.conn().await?;
425 conn.hget(&self.app_infos_key, app_id)
426 .await
427 .map_err(redis_err)
428 }
429
430 async fn get_all_app_infos(&self) -> RustvelloResult<Vec<(String, String)>> {
431 let mut conn = self.pool.conn().await?;
432 let map: Vec<(String, String)> =
433 conn.hgetall(&self.app_infos_key).await.map_err(redis_err)?;
434 Ok(map)
435 }
436
437 async fn store_workflow_sub_invocation(
438 &self,
439 workflow_id: &InvocationId,
440 sub_inv_id: &InvocationId,
441 ) -> RustvelloResult<()> {
442 let mut conn = self.pool.conn().await?;
443 conn.sadd::<_, _, ()>(
444 prefixed_key(&self.wf_sub_prefix, workflow_id.as_ref()),
445 sub_inv_id.as_str(),
446 )
447 .await
448 .map_err(redis_err)?;
449 Ok(())
450 }
451
452 async fn get_workflow_sub_invocations(
453 &self,
454 workflow_id: &InvocationId,
455 ) -> RustvelloResult<Vec<InvocationId>> {
456 let mut conn = self.pool.conn().await?;
457 let members: Vec<String> = conn
458 .smembers(prefixed_key(&self.wf_sub_prefix, workflow_id.as_ref()))
459 .await
460 .map_err(redis_err)?;
461 Ok(members.into_iter().map(InvocationId::from_string).collect())
462 }
463
464 async fn get_all_workflow_runs(&self) -> RustvelloResult<Vec<WorkflowIdentity>> {
465 let mut conn = self.pool.conn().await?;
466 let types: Vec<String> = conn.smembers(&self.wf_types_key).await.map_err(redis_err)?;
467 let mut all = Vec::new();
468 for t in &types {
469 let vals: Vec<String> = conn
470 .hvals(prefixed_key(&self.wf_runs_prefix, t))
471 .await
472 .map_err(redis_err)?;
473 for s in vals {
474 let wf: WorkflowIdentity =
475 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
476 message: e.to_string(),
477 })?;
478 all.push(wf);
479 }
480 }
481 Ok(all)
482 }
483}
484
485#[async_trait]
486impl StateBackendRunner for RedisStateBackend {
487 async fn store_runner_context(&self, context: &StoredRunnerContext) -> RustvelloResult<()> {
488 let mut conn = self.pool.conn().await?;
489 let json = serde_json::to_string(context).map_err(|e| RustvelloError::Serialization {
490 message: e.to_string(),
491 })?;
492 conn.set::<_, _, ()>(prefixed_key(&self.runner_prefix, &context.runner_id), &json)
493 .await
494 .map_err(redis_err)?;
495 Ok(())
496 }
497
498 async fn get_runner_context(
499 &self,
500 runner_id: &str,
501 ) -> RustvelloResult<Option<StoredRunnerContext>> {
502 let mut conn = self.pool.conn().await?;
503 let val: Option<String> = conn
504 .get(prefixed_key(&self.runner_prefix, runner_id))
505 .await
506 .map_err(redis_err)?;
507 match val {
508 Some(s) => {
509 let ctx: StoredRunnerContext =
510 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
511 message: e.to_string(),
512 })?;
513 Ok(Some(ctx))
514 }
515 None => Ok(None),
516 }
517 }
518
519 async fn get_runner_contexts_by_parent(
520 &self,
521 parent_runner_id: &str,
522 ) -> RustvelloResult<Vec<StoredRunnerContext>> {
523 let mut conn = self.pool.conn().await?;
524 let keys = scan_keys(&mut conn, &format!("{}*", &self.runner_prefix)).await?;
525 let mut result = Vec::new();
526 for key in keys {
527 let val: Option<String> = conn.get(&key).await.map_err(redis_err)?;
528 if let Some(s) = val {
529 let ctx: StoredRunnerContext =
530 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
531 message: e.to_string(),
532 })?;
533 if ctx.parent_runner_id.as_deref() == Some(parent_runner_id) {
534 result.push(ctx);
535 }
536 }
537 }
538 Ok(result)
539 }
540
541 async fn get_invocation_ids_by_runner(
542 &self,
543 runner_id: &str,
544 limit: usize,
545 offset: usize,
546 ) -> RustvelloResult<Vec<InvocationId>> {
547 let mut conn = self.pool.conn().await?;
548 let members: Vec<String> = conn
549 .smembers(prefixed_key(&self.runner_inv_prefix, runner_id))
550 .await
551 .map_err(redis_err)?;
552 let iter = members.into_iter().skip(offset);
553 let ids: Vec<InvocationId> = if limit > 0 {
554 iter.take(limit).map(InvocationId::from_string).collect()
555 } else {
556 iter.map(InvocationId::from_string).collect()
557 };
558 Ok(ids)
559 }
560
561 async fn count_invocations_by_runner(&self, runner_id: &str) -> RustvelloResult<usize> {
562 let mut conn = self.pool.conn().await?;
563 let count: usize = conn
564 .scard(prefixed_key(&self.runner_inv_prefix, runner_id))
565 .await
566 .map_err(redis_err)?;
567 Ok(count)
568 }
569
570 async fn get_history_in_timerange(
571 &self,
572 start: chrono::DateTime<chrono::Utc>,
573 end: chrono::DateTime<chrono::Utc>,
574 limit: usize,
575 offset: usize,
576 ) -> RustvelloResult<Vec<InvocationHistory>> {
577 let mut conn = self.pool.conn().await?;
578 let min = start.timestamp_millis() as f64;
579 let max = end.timestamp_millis() as f64;
580 let vals: Vec<String> = redis::cmd("ZRANGEBYSCORE")
582 .arg(&self.history_ts_key)
583 .arg(min)
584 .arg(max)
585 .query_async(&mut conn)
586 .await
587 .map_err(redis_err)?;
588 let iter = vals.into_iter().skip(offset);
589 let selected: Vec<String> = if limit > 0 {
590 iter.take(limit).collect()
591 } else {
592 iter.collect()
593 };
594 selected
595 .into_iter()
596 .map(|s| {
597 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
598 message: e.to_string(),
599 })
600 })
601 .collect()
602 }
603
604 async fn get_matching_runner_contexts(
605 &self,
606 partial_id: &str,
607 ) -> RustvelloResult<Vec<StoredRunnerContext>> {
608 let mut conn = self.pool.conn().await?;
609 let pattern = format!("{}*{}*", &self.runner_prefix, partial_id);
610 let keys = scan_keys(&mut conn, &pattern).await?;
611 let mut result = Vec::new();
612 for key in keys {
613 let val: Option<String> = conn.get(&key).await.map_err(redis_err)?;
614 if let Some(s) = val {
615 let ctx: StoredRunnerContext =
616 serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
617 message: e.to_string(),
618 })?;
619 result.push(ctx);
620 }
621 }
622 Ok(result)
623 }
624}