stepflow_action/action/
action_set_data.rs

1use stepflow_base::ObjectStoreFiltered;
2use stepflow_data::{StateDataFiltered, var::{Var, VarId}};
3use super::{ActionResult, Action, ActionId, Step, StateData, ActionError};
4
5
/// Action that sets output data after a set number of attempts
///
/// Each call to [`Action::start`] made before the threshold is reached returns
/// [`ActionResult::CannotFulfill`]. Once `after_attempt` such calls have been made, that call
/// and every later one return [`ActionResult::Finished`] with a clone of the stored data.
#[derive(Debug)]
pub struct SetDataAction {
  /// Identifier returned by [`Action::id`].
  id: ActionId,
  /// Number of `start` calls that have returned `CannotFulfill` so far; stops advancing once it
  /// reaches `after_attempt`.
  count: u64,
  /// How many calls to `start` must return `CannotFulfill` before the data is emitted.
  after_attempt: u64,
  /// Data handed back (cloned on every successful call) inside `ActionResult::Finished`.
  data: StateData,
}
14
15impl SetDataAction {
16  /// `data` is returned as [`ActionResult::Finished`] after `after_attempt` number of tries.
17  /// If `after_attempt` is set to zero, it will set the data on the first call to [`start`](SetDataAction::start).
18  pub fn new(id: ActionId, data: StateData, after_attempt: u64) -> Self {
19    SetDataAction {
20      id,
21      count: 0,
22      after_attempt,
23      data,
24    }
25  }
26
27  pub fn boxed(self) -> Box<dyn Action + Sync + Send> {
28    Box::new(self)
29  }
30}
31
32impl Action for SetDataAction {
33  fn id(&self) -> &ActionId {
34    &self.id
35  }
36
37  fn start(&mut self, _step: &Step, _step_name: Option<&str>, _step_data: &StateDataFiltered, _vars: &ObjectStoreFiltered<Box<dyn Var + Send + Sync>, VarId>)
38    -> Result<ActionResult, ActionError>
39  {
40    if self.count >= self.after_attempt {
41      Ok(ActionResult::Finished(self.data.clone()))
42    } else {
43      self.count += 1;
44      Ok(ActionResult::CannotFulfill)
45    }
46  }
47}
48
49
50
#[cfg(test)]
mod tests {
  use std::collections::HashSet;
  use stepflow_base::ObjectStoreFiltered;
  use stepflow_data::{StateData, StateDataFiltered};
  use stepflow_test_util::test_id;
  use crate::{ActionResult, Action, ActionId};
  use super::SetDataAction;
  use super::super::test_action_setup;

  /// Data must be emitted immediately when `after_attempt` is zero. Otherwise it must appear
  /// only after the configured number of `CannotFulfill` results.
  #[test]
  fn on_attempts() {
    let (step, state_data, var_store, var_id, val) = test_action_setup();

    // Only the single test var is visible to the action; no step data is exposed.
    let allowed_ids: HashSet<_> = std::iter::once(var_id.clone()).collect();
    let vars = ObjectStoreFiltered::new(&var_store, allowed_ids);
    let step_data_filtered = StateDataFiltered::new(&state_data, HashSet::new());

    // What the action should hand back: the test var mapped to its value.
    let expected_output = {
      let mut data = StateData::new();
      let var = vars.get(&var_id).unwrap();
      data.insert(var, val.clone()).unwrap();
      data
    };

    // Every attempt below runs against the same step, step data and vars.
    let run = |action: &mut SetDataAction| {
      action.start(&step, None, &step_data_filtered, &vars)
    };

    // after_attempt == 0: the very first call finishes.
    let mut immediate = SetDataAction::new(test_id!(ActionId), state_data.clone(), 0);
    assert!(matches!(
      run(&mut immediate),
      Ok(ActionResult::Finished(output)) if output == expected_output));

    // after_attempt == 3: three refusals, then the data on the fourth call.
    let mut delayed = SetDataAction::new(test_id!(ActionId), state_data.clone(), 3);
    for attempt in 1..=3 {
      assert_eq!(
        run(&mut delayed),
        Ok(ActionResult::CannotFulfill),
        "attempt {} should not fulfill yet",
        attempt);
    }
    assert!(matches!(
      run(&mut delayed),
      Ok(ActionResult::Finished(output)) if output == expected_output));
  }
}