wb_cache/test/simulation/actor.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use fieldx_plus::Agent;
6use indicatif::ProgressBar;
7use sea_orm::prelude::*;
8use tokio::time::Instant;
9use tracing::debug;
10use tracing::instrument;
11
12use super::db::cache::DBProvider;
13use super::db::driver::DatabaseDriver;
14use super::db::entity::InventoryRecord;
15use super::db::prelude::*;
16use super::progress::traits::MaybeProgress;
17use super::scriptwriter::steps::ScriptTitle;
18use super::scriptwriter::steps::Step;
19use super::types::simerr;
20use super::types::SimError;
21use super::types::SimErrorAny;
22use super::SimulationApp;
23
24#[async_trait]
25pub trait TestActor<APP>: DBProvider + Agent<RcApp = Result<Arc<APP>, SimErrorAny>> + Debug
26where
27    APP: SimulationApp + Send + Sync + 'static,
28{
29    fn progress(&self) -> Result<Arc<Option<ProgressBar>>, SimError>;
30    fn current_day(&self) -> i32;
31    fn set_title(&self, title: &ScriptTitle) -> Result<(), SimError>;
32    fn prelude(&self) -> Result<(), SimError>;
33
34    async fn set_current_day(&self, day: i32) -> Result<(), SimError>;
35    async fn add_customer(&self, db: &DatabaseConnection, customer: &Customer) -> Result<(), SimError>;
36    async fn add_inventory_record(
37        &self,
38        db: &DatabaseConnection,
39        inventory_record: &InventoryRecord,
40    ) -> Result<(), SimError>;
41    async fn add_order(&self, db: &DatabaseConnection, order: &Order) -> Result<(), SimError>;
42    async fn add_product(&self, db: &DatabaseConnection, product: &Product) -> Result<(), SimError>;
43    async fn add_session(&self, db: &DatabaseConnection, session: &Session) -> Result<(), SimError>;
44    async fn check_inventory(
45        &self,
46        db: &DatabaseConnection,
47        product_id: i32,
48        stock: i64,
49        comment: &str,
50    ) -> Result<(), SimError>;
51    async fn update_inventory_record(
52        &self,
53        db: &DatabaseConnection,
54        product_id: i32,
55        quantity: i64,
56    ) -> Result<(), SimError>;
57    async fn collect_sessions(&self, db: &DatabaseConnection) -> Result<(), SimError>;
58    async fn update_order(&self, db: &DatabaseConnection, order: &Order) -> Result<(), SimError>;
59    async fn update_product_view_count(&self, db: &DatabaseConnection, product_id: i32) -> Result<(), SimError>;
60    async fn update_session(&self, db: &DatabaseConnection, session: &Session) -> Result<(), SimError>;
61    async fn step_complete(&self, db: &DatabaseConnection, step_num: usize) -> Result<(), SimError>;
62
63    fn inv_rec_compare(
64        &self,
65        inventory_record: &Option<InventoryRecord>,
66        product_id: i32,
67        stock: i64,
68        comment: &str,
69    ) -> Result<(), SimError> {
70        if let Some(inventory_record) = inventory_record {
71            if inventory_record.stock != stock {
72                return Err(simerr!(
73                    "Inventory check '{}' failed for product ID {}: expected {}, but found {}",
74                    comment,
75                    product_id,
76                    stock,
77                    inventory_record.stock
78                )
79                .into());
80            }
81        }
82        else {
83            return Err(simerr!("Inventory record not found for product ID {}", product_id).into());
84        }
85
86        Ok(())
87    }
88
89    #[instrument(level = "trace", skip(screenplay))]
90    async fn act(&self, screenplay: &[Step]) -> Result<(), SimError> {
91        self.prelude()?;
92
93        let db = self.db_connection()?;
94        let progress = self.progress()?;
95
96        progress.maybe_set_length(screenplay.len() as u64);
97
98        let mut checkpoint = Instant::now();
99        let mut err = Ok(());
100        let mut post_steps = 0;
101
102        if !matches!(screenplay.first(), Some(Step::Title { .. })) {
103            Err(simerr!("Screenplay must start with a title"))?;
104        }
105
106        for (step_num, step) in screenplay.iter().enumerate() {
107            if err.is_err() {
108                if post_steps < 10 {
109                    post_steps += 1;
110                    continue;
111                }
112                return err;
113            }
114            if Instant::now().duration_since(checkpoint).as_secs() > 5 {
115                checkpoint = Instant::now();
116                self.db_driver()?.checkpoint().await?;
117            }
118
119            let step_name = format!("{step}");
120
121            debug!("Executing step {step_name}: {step:?}");
122
123            err = match step {
124                Step::Title(title) => self.set_title(title),
125                Step::Day(day) => {
126                    self.set_current_day(*day).await?;
127                    Ok(())
128                }
129                Step::AddProduct(product) => self.add_product(&db, product).await,
130                Step::AddInventoryRecord(inventory_record) => self.add_inventory_record(&db, inventory_record).await,
131                Step::AddCustomer(customer) => self.add_customer(&db, customer).await,
132                Step::AddOrder(order) => self.add_order(&db, order).await,
133                Step::AddStock(shipment) => {
134                    self.update_inventory_record(&db, shipment.product_id, shipment.batch_size as i64)
135                        .await
136                }
137                Step::UpdateOrder(order) => self.update_order(&db, order).await,
138                Step::CheckInventory(checkpoint) => {
139                    self.check_inventory(&db, checkpoint.product_id, checkpoint.stock, &checkpoint.comment)
140                        .await
141                }
142                Step::AddSession(session) => self.add_session(&db, session).await,
143                Step::UpdateSession(session) => self.update_session(&db, session).await,
144                Step::ViewProduct(product_id) => self.update_product_view_count(&db, *product_id).await,
145                Step::CollectSessions => self.collect_sessions(&db).await,
146            }
147            .inspect_err(|err| {
148                err.context(format!("step {step_name}"));
149            });
150
151            progress.maybe_inc(1);
152            self.step_complete(&db, step_num).await?;
153        }
154
155        self.curtain_call().await?;
156
157        Ok(())
158    }
159
160    #[instrument(level = "trace", skip(self))]
161    async fn curtain_call(&self) -> Result<(), SimError> {
162        Ok(())
163    }
164
165    /// Collect sessions that are older than the current day and have no customer ID
166    #[instrument(level = "trace", skip(self, db))]
167    async fn collect_session_stubs(&self, db: &DatabaseConnection) -> Result<(), SimError> {
168        let res = Sessions::delete_many()
169            .filter(
170                super::db::entity::session::Column::ExpiresOn
171                    .lte(self.current_day())
172                    .and(super::db::entity::session::Column::CustomerId.is_null()),
173            )
174            .exec(db)
175            .await?;
176
177        if res.rows_affected == 0 {
178            self.progress()?.maybe_set_message("");
179        }
180        else {
181            self.progress()?
182                .maybe_set_message(format!("Collected {} sessions", res.rows_affected));
183        }
184
185        Ok(())
186    }
187}