// wb_cache/test/simulation/sim_app.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::fmt::Display;
4use std::io::BufWriter;
5use std::io::Read;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::Duration;
9use std::time::Instant;
10
11use clap::error::ErrorKind;
12use clap::CommandFactory;
13use clap::Parser;
14use fieldx::fxstruct;
15use fieldx_plus::agent_build;
16use fieldx_plus::fx_plus;
17use garde::Validate;
18use indicatif::ProgressBar;
19use indicatif::ProgressStyle;
20use postcard::to_io;
21use sea_orm::entity::*;
22use sea_orm::query::*;
23use sea_orm::EntityTrait;
24use sea_orm::QueryOrder;
25use sea_orm_migration::MigratorTrait;
26use tokio::sync::Barrier;
27use tokio::task::JoinSet;
28use tokio_stream::StreamExt;
29use tracing::instrument;
30
31use super::actor::TestActor;
32use super::db;
33#[cfg(feature = "pg")]
34use super::db::driver::pg::Pg;
35#[cfg(feature = "sqlite")]
36use super::db::driver::sqlite::Sqlite;
37use super::db::driver::DatabaseDriver;
38use super::db::entity::Customers;
39use super::db::entity::InventoryRecords;
40use super::db::entity::Orders;
41use super::db::entity::Products;
42use super::db::entity::Sessions;
43use super::db::migrations::Migrator;
44use super::progress::MaybeProgress;
45use super::progress::POrder;
46use super::progress::PStyle;
47use super::progress::ProgressUI;
48use super::scriptwriter::steps::Step;
49use super::scriptwriter::ScriptWriter;
50use super::types::simerr;
51use super::types::Result;
52use super::types::SimError;
53use super::types::SimErrorAny;
54use super::SimulationApp;
55
/// Name of the single entry inside a saved script archive.
///
/// `save_script` writes the postcard-serialized steps under this name and `load_script` reads
/// them back from it.
const INNER_ZIP_NAME: &str = "__script.postcard";
57
58#[derive(Debug, Clone, clap::Parser, Validate)]
59#[fxstruct(no_new, get(copy))]
60#[clap(about, version, author, name = "company")]
61pub(crate) struct Cli {
62    /// File name of the script.
63    #[fieldx(get(clone))]
64    #[garde(skip)]
65    script: Option<PathBuf>,
66
67    /// Silence the output
68    #[clap(long, short, default_value_t = false)]
69    #[garde(skip)]
70    quiet: bool,
71
72    /// Simulation period in "days".
73    #[clap(long, default_value_t = 365)]
74    #[garde(range(min = 1))]
75    period: u32,
76
77    /// Number of products to "offer"
78    #[clap(long, default_value_t = 10)]
79    #[garde(range(min = 1))]
80    products: u32,
81
82    /// The number of customers we have on day 1.
83    #[clap(long, default_value_t = 1)]
84    #[garde(range(min = 1))]
85    initial_customers: u32,
86
87    /// The maximum number of customers the company can have.
88    #[clap(long, default_value_t = 1_000)]
89    #[garde(range(min = 1))]
90    market_capacity: u32,
91
92    /// Where customer base growth reaches its peak.
93    #[clap(long, default_value_t = 400)]
94    #[garde(range(min = 1), custom(Self::less_than("market-capacity", &self.market_capacity)))]
95    inflection_point: u32,
96
97    /// Company's "success" rate – how fast the customer base grows
98    #[clap(long, default_value_t = 0.05)]
99    #[garde(range(min = 0.0))]
100    growth_rate: f32,
101
102    /// Minimal number of orders per customer per day. Values below 1 indicate that a customer makes a purchase less
103    /// than once a day.
104    #[clap(long, default_value_t = 0.15)]
105    #[garde(range(min = 0.0), custom(Self::less_than("max-customer-orders", &self.max_customer_orders)))]
106    min_customer_orders: f32,
107
108    /// Maximum number of orders per customer per day. This is not a hard limit but an expectation that 90% of the
109    /// customers will fall within this range.  The remaining 10% may exhibit less restrained behavior.
110    #[clap(long, default_value_t = 3.0)]
111    #[garde(range(min = 0.0))]
112    max_customer_orders: f32,
113
114    /// The period of time we allow for a purchase to be returned.
115    #[clap(long, default_value_t = 30)]
116    #[garde(skip)]
117    return_window: u32,
118
119    /// Save the script to a file.
120    #[clap(long, short)]
121    #[garde(custom(Self::with_file(&self.script)))]
122    save: bool,
123
124    /// Load the script from a file.
125    #[clap(long, short)]
126    #[garde(custom(Self::with_file(&self.script)))]
127    // This field is only used when either sqlite or pg features are enabled.
128    #[fieldx(get(attributes_fn(allow(unused))))]
129    load: bool,
130
131    /// Test the results of the simulation by comparing two databases.
132    #[clap(long)]
133    #[garde(skip)]
134    // This field is only used when either sqlite or pg features are enabled.
135    #[fieldx(get(attributes_fn(allow(unused))))]
136    test: bool,
137
138    #[cfg_attr(feature = "sqlite", clap(long))]
139    #[fieldx(get(copy, attributes_fn(cfg(feature = "sqlite"))))]
140    #[cfg(feature = "sqlite")]
141    #[garde(skip)]
142    /// Use SQLite as the database backend.
143    sqlite: bool,
144
145    /// Path to the directory where the SQLite database is stored.
146    /// If not provided, a temporary directory will be used.
147    #[cfg_attr(feature = "sqlite", clap(long, env = "WBCACHE_SQLITE_PATH"))]
148    #[fieldx(get(clone, attributes_fn(cfg(feature = "sqlite"))))]
149    #[cfg(feature = "sqlite")]
150    #[garde(skip)]
151    sqlite_path: Option<PathBuf>,
152
153    #[cfg_attr(feature = "pg", clap(long))]
154    #[fieldx(get(copy, attributes_fn(cfg(feature = "pg"))))]
155    #[garde(skip)]
156    #[cfg(feature = "pg")]
157    /// Use PostgreSQL as the database backend.
158    pg: bool,
159
160    #[cfg_attr(feature = "pg", clap(long, env = "WBCACHE_PG_HOST", default_value = "localhost"))]
161    #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
162    #[garde(skip)]
163    #[cfg(feature = "pg")]
164    pg_host: String,
165
166    #[cfg_attr(feature = "pg", clap(long, env = "WBCACHE_PG_PORT", default_value_t = 5432))]
167    #[fieldx(get(copy, attributes_fn(cfg(feature = "pg"))))]
168    #[garde(skip)]
169    #[cfg(feature = "pg")]
170    pg_port: u16,
171
172    #[cfg_attr(feature = "pg", clap(long, env = "WBCACHE_PG_USER", default_value = "wbcache"))]
173    #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
174    #[garde(skip)]
175    #[cfg(feature = "pg")]
176    pg_user: String,
177
178    #[cfg_attr(
179        feature = "pg",
180        clap(long, env = "WBCACHE_PG_PASSWORD", hide_env_values = true, default_value = "wbcache")
181    )]
182    #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
183    #[garde(skip)]
184    #[cfg(feature = "pg")]
185    pg_password: String,
186
187    #[cfg_attr(
188        feature = "pg",
189        clap(long, env = "WBCACHE_PG_DB_PREFIX", default_value = "wbcache_test")
190    )]
191    #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
192    #[garde(skip)]
193    #[cfg(feature = "pg")]
194    pg_db_prefix: String,
195
196    /// File to send log into
197    #[cfg_attr(feature = "log", clap(long, env = "WBCACHE_LOG_FILE"))]
198    #[fieldx(get(clone, attributes_fn(cfg(feature = "log"), allow(unused))))]
199    #[garde(skip)]
200    #[cfg(feature = "log")]
201    log_file: Option<PathBuf>,
202
203    /// URL of the Loki server for tracing.
204    #[cfg_attr(
205        all(feature = "tracing", feature = "tracing-loki"),
206        clap(long, env = "WBCACHE_LOKI_URL", default_value = "https://127.0.0.1:3100")
207    )]
208    #[fieldx(get(
209        clone,
210        attributes_fn(cfg(all(feature = "tracing", feature = "tracing-loki")), allow(unused))
211    ))]
212    #[garde(skip)]
213    #[cfg(all(feature = "tracing", feature = "tracing-loki"))]
214    loki_url: tracing_loki::url::Url,
215}
216
217impl Cli {
218    fn less_than<'a, T: PartialOrd + Display>(
219        max_name: &'static str,
220        max: &'a T,
221    ) -> impl FnOnce(&'a T, &()) -> garde::Result {
222        move |value, _| {
223            if value > max {
224                Err(garde::Error::new(format!(
225                    "{} is more than {max_name} ({})",
226                    *value, *max
227                )))
228            }
229            else {
230                Ok(())
231            }
232        }
233    }
234
235    fn with_file<'a>(file: &'a Option<PathBuf>) -> impl FnOnce(&'a bool, &()) -> garde::Result {
236        move |value, _| {
237            if *value && file.is_none() {
238                Err(garde::Error::new("Script file name is required"))
239            }
240            else {
241                Ok(())
242            }
243        }
244    }
245}
246
247#[fx_plus(
248    app,
249    rc,
250    new(private),
251    sync,
252    get,
253    fallible(off, error(SimErrorAny)),
254    builder(vis(pub))
255)]
256pub struct EcommerceApp {
257    #[fieldx(inner_mut, clearer, builder("_cli_args"))]
258    cli_args: Vec<String>,
259
260    #[fieldx(lazy, private, fallible(error(clap::Error)), get(clone))]
261    cli: Cli,
262
263    #[fieldx(lazy, get, clearer, fallible)]
264    script_writer: Arc<ScriptWriter>,
265
266    // This field is only used when either sqlite or pg features are enabled.
267    #[fieldx(lazy, private, get(attributes_fn(allow(unused))), fallible)]
268    tempdir: tempfile::TempDir,
269
270    #[fieldx(lazy, fallible, get, clearer)]
271    progress_ui: ProgressUI,
272
273    // This field is only used when either sqlite or pg features are enabled.
274    #[fieldx(
275        lock,
276        private,
277        get(copy, attributes_fn(allow(unused))),
278        set("_set_plain_per_sec"),
279        default(0.0)
280    )]
281    plain_per_sec: f64,
282
283    // This field is only used when either sqlite or pg features are enabled.
284    #[fieldx(
285        lock,
286        private,
287        get(copy, attributes_fn(allow(unused))),
288        set("_set_cached_per_sec"),
289        default(0.0)
290    )]
291    cached_per_sec: f64,
292}
293
294impl EcommerceApp {
295    fn build_cli(&self) -> Result<Cli, clap::Error> {
296        Ok(if let Some(custom_args) = self.clear_cli_args() {
297            Cli::try_parse_from(custom_args.into_iter())?
298        }
299        else {
300            Cli::try_parse()?
301        })
302    }
303
304    fn build_script_writer(&self) -> Result<Arc<ScriptWriter>> {
305        let cli = self.cli()?;
306        Ok(ScriptWriter::builder()
307            .quiet(cli.quiet())
308            .period(cli.period() as i32)
309            .product_count(cli.products() as i32)
310            .initial_customers(cli.initial_customers())
311            .market_capacity(cli.market_capacity())
312            .inflection_point(cli.inflection_point())
313            .growth_rate(cli.growth_rate() as f64)
314            .min_customer_orders(cli.min_customer_orders() as f64)
315            .max_customer_orders(cli.max_customer_orders() as f64)
316            .return_window(cli.return_window() as i32)
317            .build()?)
318    }
319
320    fn build_tempdir(&self) -> Result<tempfile::TempDir, SimErrorAny> {
321        Ok(tempfile::Builder::new().prefix("wb-cache-simulation").tempdir()?)
322    }
323
324    fn build_progress_ui(&self) -> Result<ProgressUI, SimErrorAny> {
325        Ok(ProgressUI::builder().quiet(self.cli()?.quiet()).build()?)
326    }
327
328    fn validate(&self) -> Result<(), SimErrorAny> {
329        if let Err(err) = self.cli()?.validate() {
330            let mut cmd = Cli::command();
331            let err = cmd.error(ErrorKind::InvalidValue, err);
332
333            err.exit();
334        }
335
336        Ok(())
337    }
338
339    async fn db_prepare<D: DatabaseDriver>(&self, dbd: &D) -> Result<()> {
340        dbd.configure().await?;
341        let db = dbd.connection();
342        Migrator::down(&db, None).await?;
343        Migrator::up(&db, None).await?;
344        Ok(())
345    }
346
347    async fn compare_tables<E>(
348        &self,
349        table: &str,
350        key: E::Column,
351        name1: &str,
352        db1: Arc<impl DatabaseDriver>,
353        name2: &str,
354        db2: Arc<impl DatabaseDriver>,
355    ) -> Result<(), SimErrorAny>
356    where
357        E: EntityTrait,
358        E::Model: FromQueryResult + Sized + Send + Sync + PartialEq + Debug,
359    {
360        let conn1 = db1.connection();
361        let conn2 = db2.connection();
362
363        let mut paginator1 = E::find().order_by_asc(key).paginate(&conn1, 1000).into_stream();
364        let mut paginator2 = E::find().order_by_asc(key).paginate(&conn2, 1000).into_stream();
365
366        loop {
367            let page1 = paginator1.next().await;
368            let page2 = paginator2.next().await;
369
370            if page1.is_none() && page2.is_none() {
371                break;
372            }
373
374            if page1.is_none() {
375                return Err(simerr!("Table '{table}': {name2} has more records than {name1}"));
376            }
377            if page2.is_none() {
378                return Err(simerr!("Table '{table}': {name1} has more records than {name2}"));
379            }
380
381            let page1 = page1.unwrap()?;
382            let page2 = page2.unwrap()?;
383
384            if page1.len() != page2.len() {
385                return Err(simerr!(
386                    "Table '{table}': {name1} has {} records, {name2} has {} records",
387                    page1.len(),
388                    page2.len()
389                ));
390            }
391
392            for (record1, record2) in page1.iter().zip(page2.iter()) {
393                if record1 != record2 {
394                    return Err(simerr!(
395                        "Table '{table}': Records do not match: {name1} = {:?}, {name2} = {:?}",
396                        record1,
397                        record2
398                    ));
399                }
400            }
401        }
402
403        Ok(())
404    }
405
406    // Implement the most straightforward test by comparing all records in all
407    // tables in both databases.
408    async fn test_db<D: DatabaseDriver>(&self, db_plain: Arc<D>, db_cached: Arc<D>) -> Result<(), SimErrorAny> {
409        self.compare_tables::<Customers>(
410            "customers",
411            db::entity::customer::Column::Id,
412            "plain",
413            db_plain.clone(),
414            "cached",
415            db_cached.clone(),
416        )
417        .await?;
418
419        self.compare_tables::<InventoryRecords>(
420            "inventory_records",
421            db::entity::inventory_record::Column::ProductId,
422            "plain",
423            db_plain.clone(),
424            "cached",
425            db_cached.clone(),
426        )
427        .await?;
428
429        self.compare_tables::<Products>(
430            "products",
431            db::entity::product::Column::Id,
432            "plain",
433            db_plain.clone(),
434            "cached",
435            db_cached.clone(),
436        )
437        .await?;
438
439        self.compare_tables::<Orders>(
440            "orders",
441            db::entity::order::Column::Id,
442            "plain",
443            db_plain.clone(),
444            "cached",
445            db_cached.clone(),
446        )
447        .await?;
448
449        self.compare_tables::<Sessions>(
450            "sessions",
451            db::entity::session::Column::Id,
452            "plain",
453            db_plain.clone(),
454            "cached",
455            db_cached.clone(),
456        )
457        .await?;
458
459        Ok(())
460    }
461
462    #[instrument(level = "trace", skip(self, db_plain, db_cached, screenplay))]
463    async fn execute_script<D: DatabaseDriver>(
464        &self,
465        db_plain: Arc<D>,
466        db_cached: Arc<D>,
467        screenplay: Arc<Vec<Step>>,
468    ) -> Result<(), SimErrorAny> {
469        let barrier = Arc::new(Barrier::new(2));
470
471        let message_progress = self.progress_ui()?.acquire_progress(PStyle::Message, None);
472        message_progress.maybe_set_prefix("Rate");
473
474        let mut tasks = JoinSet::<Result<(&'static str, Duration), SimError>>::new();
475
476        let myself = self.myself().unwrap();
477        tokio::spawn(async move {
478            let mut interval = tokio::time::interval(Duration::from_millis(100));
479            loop {
480                interval.tick().await;
481
482                let rate = if myself.plain_per_sec() > 0.0 {
483                    myself.cached_per_sec() / myself.plain_per_sec()
484                }
485                else {
486                    0.0
487                };
488
489                message_progress.maybe_set_message(format!(
490                    "{rate:.2}x | Average: cached {:.2}/s, plain {:.2}/s",
491                    myself.cached_per_sec(),
492                    myself.plain_per_sec()
493                ));
494                message_progress.maybe_inc(1);
495            }
496        });
497
498        // Spawn the plain actor
499        let myself = self.myself().unwrap();
500        let s1 = screenplay.clone();
501        let b1 = barrier.clone();
502        let db_plain_async = db_plain.clone();
503        tasks.spawn(async move {
504            myself.db_prepare(&*db_plain_async).await?;
505            b1.wait().await;
506            let started = Instant::now();
507            let plain_actor = agent_build!(
508                myself, crate::test::simulation::company_plain::TestCompany<Self, D> {
509                    db: db_plain_async
510                }
511            )?;
512            plain_actor.act(&s1).await.inspect_err(|err| {
513                err.context("Plain actor");
514            })?;
515            Ok(("plain", Instant::now().duration_since(started)))
516        });
517
518        // Spawn the cached actor
519        let s2 = screenplay.clone();
520        let b2 = barrier.clone();
521        let myself = self.myself().unwrap();
522        let db_cached_async = db_cached.clone();
523        tasks.spawn(async move {
524            myself.db_prepare(&*db_cached_async).await?;
525            b2.wait().await;
526            let started = Instant::now();
527            let cached_actor = agent_build!(
528                myself, crate::test::simulation::company_cached::TestCompany<Self, D> {
529                    db: db_cached_async
530                }
531            )?;
532            cached_actor.act(&s2).await.inspect_err(|err| {
533                err.context("Cached actor");
534            })?;
535            myself.report_debug("Cached actor completed.");
536            Ok(("cached", Instant::now().duration_since(started)))
537        });
538
539        let mut all_success = true;
540        let mut outcomes = HashMap::new();
541
542        while let Some(res) = tasks.join_next().await {
543            match res {
544                Ok(Ok((label, duration))) => {
545                    self.report_info(format!("{} actor completed in {:.2}s", label, duration.as_secs_f64()));
546                    outcomes.insert(label.to_string(), duration);
547                }
548                Ok(Err(err)) => {
549                    all_success = false;
550                    self.report_error(err.to_string_with_backtrace("An error occurred during actor execution"));
551                    tasks.abort_all();
552                }
553                Err(err) => {
554                    all_success = false;
555                    let err = SimErrorAny::from(err);
556                    self.report_error(err.to_string_with_backtrace("Actor errorred out"));
557                    tasks.abort_all();
558                }
559            }
560            self.report_info(format!("Tasks left: {}", tasks.len()));
561        }
562
563        if all_success {
564            let plain = outcomes.get("plain").unwrap();
565            let cached = outcomes.get("cached").unwrap();
566            self.report_info(format!("{:>11} | {:>11}", "plain", "cached"));
567            self.report_info(format!(
568                "{:>10.2}s | {:>10.2}s",
569                plain.as_secs_f64(),
570                cached.as_secs_f64()
571            ));
572            self.report_info(format!(
573                "{:>10.2}x | {:>10.2}x",
574                plain.as_secs_f64() / cached.as_secs_f64(),
575                1.0
576            ));
577        }
578
579        if self.cli()?.test() {
580            self.test_db(db_plain, db_cached).await?;
581        }
582
583        Ok(())
584    }
585
586    fn save_script(&self) -> Result<(), SimErrorAny> {
587        let script = self.script_writer()?.create()?;
588        let script_file = self.cli()?.script().unwrap();
589
590        let out = std::fs::File::create(&script_file)?;
591        let mut zip = zip::ZipWriter::new(out);
592        zip.start_file(INNER_ZIP_NAME, zip::write::SimpleFileOptions::default())?;
593        let pb = ProgressBar::no_length()
594            .with_message(format!("Saving script to {}", script_file.display()))
595            .with_style(ProgressStyle::default_spinner().template("[{binary_bytes:.yellow}] {msg}")?);
596
597        let mut zip = BufWriter::with_capacity(128 * 1024, zip);
598        to_io(&script, pb.wrap_write(&mut zip))?;
599        pb.finish_with_message("Script saved successfully.");
600        zip.into_inner()?.finish()?;
601
602        Ok(())
603    }
604
605    fn load_script(&self) -> Result<Vec<Step>, SimErrorAny> {
606        let script_file = self.cli()?.script().unwrap();
607        let file = std::fs::File::open(&script_file)?;
608        let mut zip = zip::ZipArchive::new(file)?;
609        let zip_file = zip.by_name(INNER_ZIP_NAME)?;
610
611        let size = zip_file.size();
612        let mut buf = vec![0; size as usize];
613
614        let pb = ProgressBar::new(size)
615            .with_message(format!("Loading script from {}", script_file.display()))
616            .with_style(ProgressStyle::default_spinner().template("[{binary_bytes:.yellow}] {msg}")?);
617
618        pb.wrap_read(zip_file).read_exact(&mut buf[..size as usize])?;
619        pb.set_message("Script file loaded successfully.");
620        let script: Vec<Step> = postcard::from_bytes(&buf)?;
621        pb.finish_with_message("Script extracted successfully.");
622
623        Ok(script)
624    }
625
626    #[cfg(feature = "sqlite")]
627    fn db_dir(&self) -> Result<PathBuf, SimErrorAny> {
628        self.cli()?
629            .sqlite_path()
630            .as_ref()
631            .cloned()
632            .map_or_else(|| self.tempdir().map(|t| t.path().to_path_buf()), Ok)
633    }
634
635    #[cfg(any(feature = "pg", feature = "sqlite"))]
636    #[instrument(level = "trace", skip(script, self))]
637    async fn execute_per_db(&self, script: Vec<Step>) -> Result<(), SimErrorAny> {
638        let cli = self.cli()?;
639        let script = Arc::new(script);
640
641        #[cfg(feature = "sqlite")]
642        if cli.sqlite() {
643            let db_plain = Sqlite::connect(&self.db_dir()?, "test_company_plan.db").await?;
644            let db_cached = Sqlite::connect(&self.db_dir()?, "test_company_cached.db").await?;
645            self.execute_script(db_plain, db_cached, script.clone()).await?;
646        }
647
648        #[cfg(feature = "pg")]
649        if cli.pg() {
650            let db_plain = Pg::builder()
651                .host(cli.pg_host())
652                .port(cli.pg_port())
653                .user(cli.pg_user())
654                .password(cli.pg_password())
655                .database(format!("{}_plain", cli.pg_db_prefix()))
656                .build()?;
657            db_plain.connect().await?;
658            let db_cached = Pg::builder()
659                .host(cli.pg_host())
660                .port(cli.pg_port())
661                .user(cli.pg_user())
662                .password(cli.pg_password())
663                .database(format!("{}_cached", cli.pg_db_prefix()))
664                .build()?;
665            db_cached.connect().await?;
666            self.execute_script(db_plain, db_cached, script.clone()).await?;
667        }
668
669        Ok(())
670    }
671
672    #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
673    fn resource() -> opentelemetry_sdk::Resource {
674        use opentelemetry::KeyValue;
675        use opentelemetry_semantic_conventions::attribute::DEPLOYMENT_ENVIRONMENT_NAME;
676        use opentelemetry_semantic_conventions::attribute::SERVICE_VERSION;
677        use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
678        use opentelemetry_semantic_conventions::SCHEMA_URL;
679
680        opentelemetry_sdk::Resource::builder()
681            .with_service_name(env!("CARGO_PKG_NAME"))
682            .with_schema_url(
683                [
684                    KeyValue::new(SERVICE_NAME, "wb_cache::company"),
685                    KeyValue::new(SERVICE_VERSION, env!("CARGO_PKG_VERSION")),
686                    KeyValue::new(DEPLOYMENT_ENVIRONMENT_NAME, "develop"),
687                ],
688                SCHEMA_URL,
689            )
690            .build()
691    }
692
693    #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
694    fn init_meter_provider(&self) -> Result<opentelemetry_sdk::metrics::SdkMeterProvider, SimErrorAny> {
695        use opentelemetry::global;
696        use opentelemetry_sdk::metrics::MeterProviderBuilder;
697        use opentelemetry_sdk::metrics::PeriodicReader;
698
699        let exporter = opentelemetry_otlp::MetricExporter::builder()
700            .with_tonic()
701            .with_temporality(opentelemetry_sdk::metrics::Temporality::default())
702            .build()
703            .unwrap();
704
705        let reader = PeriodicReader::builder(exporter)
706            .with_interval(std::time::Duration::from_secs(30))
707            .build();
708
709        // For debugging in development
710        // let stdout_reader = PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build();
711
712        let meter_provider = MeterProviderBuilder::default()
713            .with_resource(Self::resource())
714            .with_reader(reader)
715            // .with_reader(stdout_reader)
716            .build();
717
718        global::set_meter_provider(meter_provider.clone());
719
720        Ok(meter_provider)
721    }
722
723    #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
724    fn init_tracer_provider(&self) -> Result<opentelemetry_sdk::trace::SdkTracerProvider, SimErrorAny> {
725        use opentelemetry_sdk::trace::RandomIdGenerator;
726        use opentelemetry_sdk::trace::Sampler;
727        use opentelemetry_sdk::trace::SdkTracerProvider;
728
729        let exporter = opentelemetry_otlp::SpanExporter::builder().with_tonic().build()?;
730
731        Ok(SdkTracerProvider::builder()
732            // Customize sampling strategy
733            .with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(1.0))))
734            // If export trace to AWS X-Ray, you can use XrayIdGenerator
735            .with_id_generator(RandomIdGenerator::default())
736            .with_resource(Self::resource())
737            .with_batch_exporter(exporter)
738            .build())
739    }
740
741    #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
742    #[allow(clippy::type_complexity)]
743    fn setup_tracing_otlp<R>(
744        &self,
745        registry: R,
746    ) -> Result<
747        tracing_subscriber::layer::Layered<
748            tracing_opentelemetry::MetricsLayer<
749                tracing_subscriber::layer::Layered<
750                    tracing_opentelemetry::OpenTelemetryLayer<R, opentelemetry_sdk::trace::Tracer>,
751                    R,
752                >,
753            >,
754            tracing_subscriber::layer::Layered<
755                tracing_opentelemetry::OpenTelemetryLayer<R, opentelemetry_sdk::trace::Tracer>,
756                R,
757            >,
758        >,
759        SimErrorAny,
760    >
761    where
762        R: tracing_subscriber::layer::SubscriberExt + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
763    {
764        use opentelemetry::trace::TracerProvider;
765        use tracing_opentelemetry::MetricsLayer;
766        use tracing_opentelemetry::OpenTelemetryLayer;
767        use tracing_subscriber::layer::SubscriberExt;
768
769        let meter_provider = self.init_meter_provider()?;
770        let otlp_exporter = opentelemetry_otlp::SpanExporter::builder().with_tonic().build()?;
771        let _ = opentelemetry_sdk::trace::SdkTracerProvider::builder()
772            .with_simple_exporter(otlp_exporter)
773            .build();
774
775        let tracer_provider = self.init_tracer_provider()?;
776        let tracer = tracer_provider.tracer("wb_cache::company");
777
778        Ok(registry
779            .with(OpenTelemetryLayer::new(tracer))
780            .with(MetricsLayer::new(meter_provider.clone())))
781    }
782
783    #[cfg(all(feature = "tracing", feature = "tracing-loki"))]
784    fn setup_tracing_loki<R>(
785        &self,
786        registry: R,
787    ) -> Result<tracing_subscriber::layer::Layered<tracing_loki::Layer, R>, SimErrorAny>
788    where
789        R: tracing_subscriber::layer::SubscriberExt + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
790    {
791        use std::process;
792
793        let url = self.cli()?.loki_url();
794
795        let (loki, loki_task) = tracing_loki::builder()
796            .label("app", "wb_cache::company")?
797            .extra_field("pid", format!("{}", process::id()))?
798            .build_url(url)?;
799
800        tokio::spawn(loki_task);
801
802        Ok(registry.with(loki))
803    }
804
805    #[cfg(all(feature = "tracing", feature = "tracing-file"))]
806    #[allow(clippy::type_complexity)]
807    fn setup_tracing_file<R>(
808        &self,
809        registry: R,
810    ) -> Result<
811        tracing_subscriber::layer::Layered<
812            tracing_subscriber::fmt::Layer<
813                R,
814                tracing_subscriber::fmt::format::DefaultFields,
815                tracing_subscriber::fmt::format::Format,
816                ::std::sync::Mutex<Box<dyn std::io::Write + Send + 'static>>,
817            >,
818            R,
819        >,
820        SimErrorAny,
821    >
822    where
823        R: tracing_subscriber::layer::SubscriberExt + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
824    {
825        use std::io;
826        use std::sync::Mutex;
827        use tracing_subscriber::fmt::format::FmtSpan;
828
829        let cli = self.cli()?;
830
831        let dest_writer = Mutex::new(if let Some(log_file) = cli.log_file() {
832            let file = std::fs::OpenOptions::new()
833                .create(true)
834                .write(true)
835                .truncate(true)
836                .open(log_file)?;
837            Box::new(file) as Box<dyn io::Write + Send>
838        }
839        else {
840            Box::new(io::stdout()) as Box<dyn io::Write + Send>
841        });
842
843        Ok(registry.with(
844            tracing_subscriber::fmt::layer()
845                .with_writer(dest_writer)
846                .with_span_events(FmtSpan::FULL),
847        ))
848    }
849
850    #[cfg(feature = "tracing")]
851    fn setup_tracing(&self) -> Result<(), SimErrorAny> {
852        use tracing::info;
853        use tracing_subscriber::layer::SubscriberExt;
854        use tracing_subscriber::util::SubscriberInitExt;
855
856        let filter = tracing_subscriber::EnvFilter::from_default_env();
857
858        let tracing_registry = tracing_subscriber::registry();
859        let tracing_registry = tracing_registry.with(filter);
860
861        #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
862        let tracing_registry = self.setup_tracing_otlp(tracing_registry)?;
863
864        #[cfg(all(feature = "tracing", feature = "tracing-loki"))]
865        let tracing_registry = self.setup_tracing_loki(tracing_registry)?;
866
867        #[cfg(all(feature = "tracing", feature = "tracing-file"))]
868        let tracing_registry = self.setup_tracing_file(tracing_registry)?;
869
870        tracing_registry.try_init()?;
871
872        info!("Tracing initialized");
873
874        Ok(())
875    }
876
877    pub async fn execute(&self) -> Result<(), SimErrorAny> {
878        let cli = match self.cli() {
879            Ok(cli) => cli,
880            Err(err) => match err.kind() {
881                ErrorKind::DisplayHelp | ErrorKind::DisplayVersion => {
882                    let mut cmd = Cli::command();
883                    // let mut cmd = cmd.color(clap::ColorChoice::Always);
884                    cmd.print_help().unwrap();
885                    return Ok(());
886                }
887                _ => {
888                    return Err(err.into());
889                }
890            },
891        };
892
893        self.validate()?;
894
895        #[cfg(feature = "tracing")]
896        self.setup_tracing()?;
897
898        if cli.save() {
899            let myself = self.myself().unwrap();
900            return tokio::task::spawn_blocking(move || myself.save_script()).await?;
901        }
902
903        #[cfg(any(feature = "pg", feature = "sqlite"))]
904        {
905            let script = if cli.load() {
906                self.load_script()?
907            }
908            else {
909                let s = self.script_writer()?.create()?;
910                self.clear_script_writer();
911                s
912            };
913
914            self.execute_per_db(script).await?;
915        }
916
917        Ok(())
918    }
919
920    pub async fn run() -> Result<(), SimErrorAny> {
921        EcommerceApp::__fieldx_new().execute().await
922    }
923}
924
925impl EcommerceAppBuilder {
926    pub fn cli_args<S: ToString>(self, args: Vec<S>) -> Self {
927        self._cli_args(args.into_iter().map(|s| s.to_string()).collect())
928    }
929}
930
931impl Debug for EcommerceApp {
932    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
933        write!(f, "SimApp {{ ... }}")
934    }
935}
936
937impl SimulationApp for EcommerceApp {
938    fn acquire_progress<'a>(
939        &'a self,
940        style: PStyle,
941        order: Option<POrder<'a>>,
942    ) -> Result<Option<ProgressBar>, SimErrorAny> {
943        Ok(self.progress_ui()?.acquire_progress(style, order))
944    }
945
946    fn set_cached_per_sec(&self, step: f64) {
947        self._set_cached_per_sec(step);
948    }
949
950    fn set_plain_per_sec(&self, step: f64) {
951        self._set_plain_per_sec(step);
952    }
953
954    fn report_info<S: ToString>(&self, msg: S) {
955        self.progress_ui().unwrap().report_info(msg.to_string());
956    }
957
958    fn report_debug<S: ToString>(&self, msg: S) {
959        self.progress_ui().unwrap().report_debug(msg.to_string());
960    }
961
962    fn report_warn<S: ToString>(&self, msg: S) {
963        self.progress_ui().unwrap().report_warn(msg.to_string());
964    }
965
966    fn report_error<S: ToString>(&self, msg: S) {
967        self.progress_ui().unwrap().report_error(msg.to_string());
968    }
969}
970
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags and numeric options given on the command line end up in the parsed `Cli`.
    #[test]
    fn test_cli_parsing() {
        let argv = ["cmd", "--quiet", "--test", "--products", "5", "--period", "30"];
        let parsed = Cli::try_parse_from(argv).expect("Failed to parse CLI arguments");

        assert!(parsed.quiet());
        assert!(parsed.test());
        assert_eq!(parsed.period(), 30);
        assert_eq!(parsed.products(), 5);
    }
}