1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::fmt::Display;
4use std::io::BufWriter;
5use std::io::Read;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::Duration;
9use std::time::Instant;
10
11use clap::error::ErrorKind;
12use clap::CommandFactory;
13use clap::Parser;
14use fieldx::fxstruct;
15use fieldx_plus::agent_build;
16use fieldx_plus::fx_plus;
17use garde::Validate;
18use indicatif::ProgressBar;
19use indicatif::ProgressStyle;
20use postcard::to_io;
21use sea_orm::entity::*;
22use sea_orm::query::*;
23use sea_orm::EntityTrait;
24use sea_orm::QueryOrder;
25use sea_orm_migration::MigratorTrait;
26use tokio::sync::Barrier;
27use tokio::task::JoinSet;
28use tokio_stream::StreamExt;
29use tracing::instrument;
30
31use super::actor::TestActor;
32use super::db;
33#[cfg(feature = "pg")]
34use super::db::driver::pg::Pg;
35#[cfg(feature = "sqlite")]
36use super::db::driver::sqlite::Sqlite;
37use super::db::driver::DatabaseDriver;
38use super::db::entity::Customers;
39use super::db::entity::InventoryRecords;
40use super::db::entity::Orders;
41use super::db::entity::Products;
42use super::db::entity::Sessions;
43use super::db::migrations::Migrator;
44use super::progress::MaybeProgress;
45use super::progress::POrder;
46use super::progress::PStyle;
47use super::progress::ProgressUI;
48use super::scriptwriter::steps::Step;
49use super::scriptwriter::ScriptWriter;
50use super::types::simerr;
51use super::types::Result;
52use super::types::SimError;
53use super::types::SimErrorAny;
54use super::SimulationApp;
55
56const INNER_ZIP_NAME: &str = "__script.postcard";
57
58#[derive(Debug, Clone, clap::Parser, Validate)]
59#[fxstruct(no_new, get(copy))]
60#[clap(about, version, author, name = "company")]
61pub(crate) struct Cli {
62 #[fieldx(get(clone))]
64 #[garde(skip)]
65 script: Option<PathBuf>,
66
67 #[clap(long, short, default_value_t = false)]
69 #[garde(skip)]
70 quiet: bool,
71
72 #[clap(long, default_value_t = 365)]
74 #[garde(range(min = 1))]
75 period: u32,
76
77 #[clap(long, default_value_t = 10)]
79 #[garde(range(min = 1))]
80 products: u32,
81
82 #[clap(long, default_value_t = 1)]
84 #[garde(range(min = 1))]
85 initial_customers: u32,
86
87 #[clap(long, default_value_t = 1_000)]
89 #[garde(range(min = 1))]
90 market_capacity: u32,
91
92 #[clap(long, default_value_t = 400)]
94 #[garde(range(min = 1), custom(Self::less_than("market-capacity", &self.market_capacity)))]
95 inflection_point: u32,
96
97 #[clap(long, default_value_t = 0.05)]
99 #[garde(range(min = 0.0))]
100 growth_rate: f32,
101
102 #[clap(long, default_value_t = 0.15)]
105 #[garde(range(min = 0.0), custom(Self::less_than("max-customer-orders", &self.max_customer_orders)))]
106 min_customer_orders: f32,
107
108 #[clap(long, default_value_t = 3.0)]
111 #[garde(range(min = 0.0))]
112 max_customer_orders: f32,
113
114 #[clap(long, default_value_t = 30)]
116 #[garde(skip)]
117 return_window: u32,
118
119 #[clap(long, short)]
121 #[garde(custom(Self::with_file(&self.script)))]
122 save: bool,
123
124 #[clap(long, short)]
126 #[garde(custom(Self::with_file(&self.script)))]
127 #[fieldx(get(attributes_fn(allow(unused))))]
129 load: bool,
130
131 #[clap(long)]
133 #[garde(skip)]
134 #[fieldx(get(attributes_fn(allow(unused))))]
136 test: bool,
137
138 #[cfg_attr(feature = "sqlite", clap(long))]
139 #[fieldx(get(copy, attributes_fn(cfg(feature = "sqlite"))))]
140 #[cfg(feature = "sqlite")]
141 #[garde(skip)]
142 sqlite: bool,
144
145 #[cfg_attr(feature = "sqlite", clap(long, env = "WBCACHE_SQLITE_PATH"))]
148 #[fieldx(get(clone, attributes_fn(cfg(feature = "sqlite"))))]
149 #[cfg(feature = "sqlite")]
150 #[garde(skip)]
151 sqlite_path: Option<PathBuf>,
152
153 #[cfg_attr(feature = "pg", clap(long))]
154 #[fieldx(get(copy, attributes_fn(cfg(feature = "pg"))))]
155 #[garde(skip)]
156 #[cfg(feature = "pg")]
157 pg: bool,
159
160 #[cfg_attr(feature = "pg", clap(long, env = "WBCACHE_PG_HOST", default_value = "localhost"))]
161 #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
162 #[garde(skip)]
163 #[cfg(feature = "pg")]
164 pg_host: String,
165
166 #[cfg_attr(feature = "pg", clap(long, env = "WBCACHE_PG_PORT", default_value_t = 5432))]
167 #[fieldx(get(copy, attributes_fn(cfg(feature = "pg"))))]
168 #[garde(skip)]
169 #[cfg(feature = "pg")]
170 pg_port: u16,
171
172 #[cfg_attr(feature = "pg", clap(long, env = "WBCACHE_PG_USER", default_value = "wbcache"))]
173 #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
174 #[garde(skip)]
175 #[cfg(feature = "pg")]
176 pg_user: String,
177
178 #[cfg_attr(
179 feature = "pg",
180 clap(long, env = "WBCACHE_PG_PASSWORD", hide_env_values = true, default_value = "wbcache")
181 )]
182 #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
183 #[garde(skip)]
184 #[cfg(feature = "pg")]
185 pg_password: String,
186
187 #[cfg_attr(
188 feature = "pg",
189 clap(long, env = "WBCACHE_PG_DB_PREFIX", default_value = "wbcache_test")
190 )]
191 #[fieldx(get(clone, attributes_fn(cfg(feature = "pg"))))]
192 #[garde(skip)]
193 #[cfg(feature = "pg")]
194 pg_db_prefix: String,
195
196 #[cfg_attr(feature = "log", clap(long, env = "WBCACHE_LOG_FILE"))]
198 #[fieldx(get(clone, attributes_fn(cfg(feature = "log"), allow(unused))))]
199 #[garde(skip)]
200 #[cfg(feature = "log")]
201 log_file: Option<PathBuf>,
202
203 #[cfg_attr(
205 all(feature = "tracing", feature = "tracing-loki"),
206 clap(long, env = "WBCACHE_LOKI_URL", default_value = "https://127.0.0.1:3100")
207 )]
208 #[fieldx(get(
209 clone,
210 attributes_fn(cfg(all(feature = "tracing", feature = "tracing-loki")), allow(unused))
211 ))]
212 #[garde(skip)]
213 #[cfg(all(feature = "tracing", feature = "tracing-loki"))]
214 loki_url: tracing_loki::url::Url,
215}
216
217impl Cli {
218 fn less_than<'a, T: PartialOrd + Display>(
219 max_name: &'static str,
220 max: &'a T,
221 ) -> impl FnOnce(&'a T, &()) -> garde::Result {
222 move |value, _| {
223 if value > max {
224 Err(garde::Error::new(format!(
225 "{} is more than {max_name} ({})",
226 *value, *max
227 )))
228 }
229 else {
230 Ok(())
231 }
232 }
233 }
234
235 fn with_file<'a>(file: &'a Option<PathBuf>) -> impl FnOnce(&'a bool, &()) -> garde::Result {
236 move |value, _| {
237 if *value && file.is_none() {
238 Err(garde::Error::new("Script file name is required"))
239 }
240 else {
241 Ok(())
242 }
243 }
244 }
245}
246
/// Application object of the e-commerce simulation.
///
/// Parses the CLI, then either saves a generated script or runs a generated/loaded script
/// against a "plain" and a "cached" company implementation on the selected database backend(s).
#[fx_plus(
    app,
    rc,
    new(private),
    sync,
    get,
    fallible(off, error(SimErrorAny)),
    builder(vis(pub))
)]
pub struct EcommerceApp {
    /// Arguments supplied through the builder. They are taken (cleared) by `build_cli`. When
    /// they are absent, the process arguments are parsed instead.
    #[fieldx(inner_mut, clearer, builder("_cli_args"))]
    cli_args: Vec<String>,

    /// Parsed command line, built lazily by `build_cli`.
    #[fieldx(lazy, private, fallible(error(clap::Error)), get(clone))]
    cli: Cli,

    /// Script generator configured from the CLI. It is cleared once the script is created.
    #[fieldx(lazy, get, clearer, fallible)]
    script_writer: Arc<ScriptWriter>,

    /// Temporary directory used for SQLite files when no explicit path is configured.
    #[fieldx(lazy, private, get(attributes_fn(allow(unused))), fallible)]
    tempdir: tempfile::TempDir,

    /// Progress/report UI; honours `--quiet`.
    #[fieldx(lazy, fallible, get, clearer)]
    progress_ui: ProgressUI,

    /// Latest rate reported by the plain actor through `SimulationApp::set_plain_per_sec`.
    #[fieldx(
        lock,
        private,
        get(copy, attributes_fn(allow(unused))),
        set("_set_plain_per_sec"),
        default(0.0)
    )]
    plain_per_sec: f64,

    /// Latest rate reported by the cached actor through `SimulationApp::set_cached_per_sec`.
    #[fieldx(
        lock,
        private,
        get(copy, attributes_fn(allow(unused))),
        set("_set_cached_per_sec"),
        default(0.0)
    )]
    cached_per_sec: f64,
}
293
294impl EcommerceApp {
295 fn build_cli(&self) -> Result<Cli, clap::Error> {
296 Ok(if let Some(custom_args) = self.clear_cli_args() {
297 Cli::try_parse_from(custom_args.into_iter())?
298 }
299 else {
300 Cli::try_parse()?
301 })
302 }
303
304 fn build_script_writer(&self) -> Result<Arc<ScriptWriter>> {
305 let cli = self.cli()?;
306 Ok(ScriptWriter::builder()
307 .quiet(cli.quiet())
308 .period(cli.period() as i32)
309 .product_count(cli.products() as i32)
310 .initial_customers(cli.initial_customers())
311 .market_capacity(cli.market_capacity())
312 .inflection_point(cli.inflection_point())
313 .growth_rate(cli.growth_rate() as f64)
314 .min_customer_orders(cli.min_customer_orders() as f64)
315 .max_customer_orders(cli.max_customer_orders() as f64)
316 .return_window(cli.return_window() as i32)
317 .build()?)
318 }
319
320 fn build_tempdir(&self) -> Result<tempfile::TempDir, SimErrorAny> {
321 Ok(tempfile::Builder::new().prefix("wb-cache-simulation").tempdir()?)
322 }
323
324 fn build_progress_ui(&self) -> Result<ProgressUI, SimErrorAny> {
325 Ok(ProgressUI::builder().quiet(self.cli()?.quiet()).build()?)
326 }
327
328 fn validate(&self) -> Result<(), SimErrorAny> {
329 if let Err(err) = self.cli()?.validate() {
330 let mut cmd = Cli::command();
331 let err = cmd.error(ErrorKind::InvalidValue, err);
332
333 err.exit();
334 }
335
336 Ok(())
337 }
338
339 async fn db_prepare<D: DatabaseDriver>(&self, dbd: &D) -> Result<()> {
340 dbd.configure().await?;
341 let db = dbd.connection();
342 Migrator::down(&db, None).await?;
343 Migrator::up(&db, None).await?;
344 Ok(())
345 }
346
347 async fn compare_tables<E>(
348 &self,
349 table: &str,
350 key: E::Column,
351 name1: &str,
352 db1: Arc<impl DatabaseDriver>,
353 name2: &str,
354 db2: Arc<impl DatabaseDriver>,
355 ) -> Result<(), SimErrorAny>
356 where
357 E: EntityTrait,
358 E::Model: FromQueryResult + Sized + Send + Sync + PartialEq + Debug,
359 {
360 let conn1 = db1.connection();
361 let conn2 = db2.connection();
362
363 let mut paginator1 = E::find().order_by_asc(key).paginate(&conn1, 1000).into_stream();
364 let mut paginator2 = E::find().order_by_asc(key).paginate(&conn2, 1000).into_stream();
365
366 loop {
367 let page1 = paginator1.next().await;
368 let page2 = paginator2.next().await;
369
370 if page1.is_none() && page2.is_none() {
371 break;
372 }
373
374 if page1.is_none() {
375 return Err(simerr!("Table '{table}': {name2} has more records than {name1}"));
376 }
377 if page2.is_none() {
378 return Err(simerr!("Table '{table}': {name1} has more records than {name2}"));
379 }
380
381 let page1 = page1.unwrap()?;
382 let page2 = page2.unwrap()?;
383
384 if page1.len() != page2.len() {
385 return Err(simerr!(
386 "Table '{table}': {name1} has {} records, {name2} has {} records",
387 page1.len(),
388 page2.len()
389 ));
390 }
391
392 for (record1, record2) in page1.iter().zip(page2.iter()) {
393 if record1 != record2 {
394 return Err(simerr!(
395 "Table '{table}': Records do not match: {name1} = {:?}, {name2} = {:?}",
396 record1,
397 record2
398 ));
399 }
400 }
401 }
402
403 Ok(())
404 }
405
406 async fn test_db<D: DatabaseDriver>(&self, db_plain: Arc<D>, db_cached: Arc<D>) -> Result<(), SimErrorAny> {
409 self.compare_tables::<Customers>(
410 "customers",
411 db::entity::customer::Column::Id,
412 "plain",
413 db_plain.clone(),
414 "cached",
415 db_cached.clone(),
416 )
417 .await?;
418
419 self.compare_tables::<InventoryRecords>(
420 "inventory_records",
421 db::entity::inventory_record::Column::ProductId,
422 "plain",
423 db_plain.clone(),
424 "cached",
425 db_cached.clone(),
426 )
427 .await?;
428
429 self.compare_tables::<Products>(
430 "products",
431 db::entity::product::Column::Id,
432 "plain",
433 db_plain.clone(),
434 "cached",
435 db_cached.clone(),
436 )
437 .await?;
438
439 self.compare_tables::<Orders>(
440 "orders",
441 db::entity::order::Column::Id,
442 "plain",
443 db_plain.clone(),
444 "cached",
445 db_cached.clone(),
446 )
447 .await?;
448
449 self.compare_tables::<Sessions>(
450 "sessions",
451 db::entity::session::Column::Id,
452 "plain",
453 db_plain.clone(),
454 "cached",
455 db_cached.clone(),
456 )
457 .await?;
458
459 Ok(())
460 }
461
462 #[instrument(level = "trace", skip(self, db_plain, db_cached, screenplay))]
463 async fn execute_script<D: DatabaseDriver>(
464 &self,
465 db_plain: Arc<D>,
466 db_cached: Arc<D>,
467 screenplay: Arc<Vec<Step>>,
468 ) -> Result<(), SimErrorAny> {
469 let barrier = Arc::new(Barrier::new(2));
470
471 let message_progress = self.progress_ui()?.acquire_progress(PStyle::Message, None);
472 message_progress.maybe_set_prefix("Rate");
473
474 let mut tasks = JoinSet::<Result<(&'static str, Duration), SimError>>::new();
475
476 let myself = self.myself().unwrap();
477 tokio::spawn(async move {
478 let mut interval = tokio::time::interval(Duration::from_millis(100));
479 loop {
480 interval.tick().await;
481
482 let rate = if myself.plain_per_sec() > 0.0 {
483 myself.cached_per_sec() / myself.plain_per_sec()
484 }
485 else {
486 0.0
487 };
488
489 message_progress.maybe_set_message(format!(
490 "{rate:.2}x | Average: cached {:.2}/s, plain {:.2}/s",
491 myself.cached_per_sec(),
492 myself.plain_per_sec()
493 ));
494 message_progress.maybe_inc(1);
495 }
496 });
497
498 let myself = self.myself().unwrap();
500 let s1 = screenplay.clone();
501 let b1 = barrier.clone();
502 let db_plain_async = db_plain.clone();
503 tasks.spawn(async move {
504 myself.db_prepare(&*db_plain_async).await?;
505 b1.wait().await;
506 let started = Instant::now();
507 let plain_actor = agent_build!(
508 myself, crate::test::simulation::company_plain::TestCompany<Self, D> {
509 db: db_plain_async
510 }
511 )?;
512 plain_actor.act(&s1).await.inspect_err(|err| {
513 err.context("Plain actor");
514 })?;
515 Ok(("plain", Instant::now().duration_since(started)))
516 });
517
518 let s2 = screenplay.clone();
520 let b2 = barrier.clone();
521 let myself = self.myself().unwrap();
522 let db_cached_async = db_cached.clone();
523 tasks.spawn(async move {
524 myself.db_prepare(&*db_cached_async).await?;
525 b2.wait().await;
526 let started = Instant::now();
527 let cached_actor = agent_build!(
528 myself, crate::test::simulation::company_cached::TestCompany<Self, D> {
529 db: db_cached_async
530 }
531 )?;
532 cached_actor.act(&s2).await.inspect_err(|err| {
533 err.context("Cached actor");
534 })?;
535 myself.report_debug("Cached actor completed.");
536 Ok(("cached", Instant::now().duration_since(started)))
537 });
538
539 let mut all_success = true;
540 let mut outcomes = HashMap::new();
541
542 while let Some(res) = tasks.join_next().await {
543 match res {
544 Ok(Ok((label, duration))) => {
545 self.report_info(format!("{} actor completed in {:.2}s", label, duration.as_secs_f64()));
546 outcomes.insert(label.to_string(), duration);
547 }
548 Ok(Err(err)) => {
549 all_success = false;
550 self.report_error(err.to_string_with_backtrace("An error occurred during actor execution"));
551 tasks.abort_all();
552 }
553 Err(err) => {
554 all_success = false;
555 let err = SimErrorAny::from(err);
556 self.report_error(err.to_string_with_backtrace("Actor errorred out"));
557 tasks.abort_all();
558 }
559 }
560 self.report_info(format!("Tasks left: {}", tasks.len()));
561 }
562
563 if all_success {
564 let plain = outcomes.get("plain").unwrap();
565 let cached = outcomes.get("cached").unwrap();
566 self.report_info(format!("{:>11} | {:>11}", "plain", "cached"));
567 self.report_info(format!(
568 "{:>10.2}s | {:>10.2}s",
569 plain.as_secs_f64(),
570 cached.as_secs_f64()
571 ));
572 self.report_info(format!(
573 "{:>10.2}x | {:>10.2}x",
574 plain.as_secs_f64() / cached.as_secs_f64(),
575 1.0
576 ));
577 }
578
579 if self.cli()?.test() {
580 self.test_db(db_plain, db_cached).await?;
581 }
582
583 Ok(())
584 }
585
586 fn save_script(&self) -> Result<(), SimErrorAny> {
587 let script = self.script_writer()?.create()?;
588 let script_file = self.cli()?.script().unwrap();
589
590 let out = std::fs::File::create(&script_file)?;
591 let mut zip = zip::ZipWriter::new(out);
592 zip.start_file(INNER_ZIP_NAME, zip::write::SimpleFileOptions::default())?;
593 let pb = ProgressBar::no_length()
594 .with_message(format!("Saving script to {}", script_file.display()))
595 .with_style(ProgressStyle::default_spinner().template("[{binary_bytes:.yellow}] {msg}")?);
596
597 let mut zip = BufWriter::with_capacity(128 * 1024, zip);
598 to_io(&script, pb.wrap_write(&mut zip))?;
599 pb.finish_with_message("Script saved successfully.");
600 zip.into_inner()?.finish()?;
601
602 Ok(())
603 }
604
605 fn load_script(&self) -> Result<Vec<Step>, SimErrorAny> {
606 let script_file = self.cli()?.script().unwrap();
607 let file = std::fs::File::open(&script_file)?;
608 let mut zip = zip::ZipArchive::new(file)?;
609 let zip_file = zip.by_name(INNER_ZIP_NAME)?;
610
611 let size = zip_file.size();
612 let mut buf = vec![0; size as usize];
613
614 let pb = ProgressBar::new(size)
615 .with_message(format!("Loading script from {}", script_file.display()))
616 .with_style(ProgressStyle::default_spinner().template("[{binary_bytes:.yellow}] {msg}")?);
617
618 pb.wrap_read(zip_file).read_exact(&mut buf[..size as usize])?;
619 pb.set_message("Script file loaded successfully.");
620 let script: Vec<Step> = postcard::from_bytes(&buf)?;
621 pb.finish_with_message("Script extracted successfully.");
622
623 Ok(script)
624 }
625
626 #[cfg(feature = "sqlite")]
627 fn db_dir(&self) -> Result<PathBuf, SimErrorAny> {
628 self.cli()?
629 .sqlite_path()
630 .as_ref()
631 .cloned()
632 .map_or_else(|| self.tempdir().map(|t| t.path().to_path_buf()), Ok)
633 }
634
635 #[cfg(any(feature = "pg", feature = "sqlite"))]
636 #[instrument(level = "trace", skip(script, self))]
637 async fn execute_per_db(&self, script: Vec<Step>) -> Result<(), SimErrorAny> {
638 let cli = self.cli()?;
639 let script = Arc::new(script);
640
641 #[cfg(feature = "sqlite")]
642 if cli.sqlite() {
643 let db_plain = Sqlite::connect(&self.db_dir()?, "test_company_plan.db").await?;
644 let db_cached = Sqlite::connect(&self.db_dir()?, "test_company_cached.db").await?;
645 self.execute_script(db_plain, db_cached, script.clone()).await?;
646 }
647
648 #[cfg(feature = "pg")]
649 if cli.pg() {
650 let db_plain = Pg::builder()
651 .host(cli.pg_host())
652 .port(cli.pg_port())
653 .user(cli.pg_user())
654 .password(cli.pg_password())
655 .database(format!("{}_plain", cli.pg_db_prefix()))
656 .build()?;
657 db_plain.connect().await?;
658 let db_cached = Pg::builder()
659 .host(cli.pg_host())
660 .port(cli.pg_port())
661 .user(cli.pg_user())
662 .password(cli.pg_password())
663 .database(format!("{}_cached", cli.pg_db_prefix()))
664 .build()?;
665 db_cached.connect().await?;
666 self.execute_script(db_plain, db_cached, script.clone()).await?;
667 }
668
669 Ok(())
670 }
671
672 #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
673 fn resource() -> opentelemetry_sdk::Resource {
674 use opentelemetry::KeyValue;
675 use opentelemetry_semantic_conventions::attribute::DEPLOYMENT_ENVIRONMENT_NAME;
676 use opentelemetry_semantic_conventions::attribute::SERVICE_VERSION;
677 use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
678 use opentelemetry_semantic_conventions::SCHEMA_URL;
679
680 opentelemetry_sdk::Resource::builder()
681 .with_service_name(env!("CARGO_PKG_NAME"))
682 .with_schema_url(
683 [
684 KeyValue::new(SERVICE_NAME, "wb_cache::company"),
685 KeyValue::new(SERVICE_VERSION, env!("CARGO_PKG_VERSION")),
686 KeyValue::new(DEPLOYMENT_ENVIRONMENT_NAME, "develop"),
687 ],
688 SCHEMA_URL,
689 )
690 .build()
691 }
692
693 #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
694 fn init_meter_provider(&self) -> Result<opentelemetry_sdk::metrics::SdkMeterProvider, SimErrorAny> {
695 use opentelemetry::global;
696 use opentelemetry_sdk::metrics::MeterProviderBuilder;
697 use opentelemetry_sdk::metrics::PeriodicReader;
698
699 let exporter = opentelemetry_otlp::MetricExporter::builder()
700 .with_tonic()
701 .with_temporality(opentelemetry_sdk::metrics::Temporality::default())
702 .build()
703 .unwrap();
704
705 let reader = PeriodicReader::builder(exporter)
706 .with_interval(std::time::Duration::from_secs(30))
707 .build();
708
709 let meter_provider = MeterProviderBuilder::default()
713 .with_resource(Self::resource())
714 .with_reader(reader)
715 .build();
717
718 global::set_meter_provider(meter_provider.clone());
719
720 Ok(meter_provider)
721 }
722
723 #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
724 fn init_tracer_provider(&self) -> Result<opentelemetry_sdk::trace::SdkTracerProvider, SimErrorAny> {
725 use opentelemetry_sdk::trace::RandomIdGenerator;
726 use opentelemetry_sdk::trace::Sampler;
727 use opentelemetry_sdk::trace::SdkTracerProvider;
728
729 let exporter = opentelemetry_otlp::SpanExporter::builder().with_tonic().build()?;
730
731 Ok(SdkTracerProvider::builder()
732 .with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(1.0))))
734 .with_id_generator(RandomIdGenerator::default())
736 .with_resource(Self::resource())
737 .with_batch_exporter(exporter)
738 .build())
739 }
740
741 #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
742 #[allow(clippy::type_complexity)]
743 fn setup_tracing_otlp<R>(
744 &self,
745 registry: R,
746 ) -> Result<
747 tracing_subscriber::layer::Layered<
748 tracing_opentelemetry::MetricsLayer<
749 tracing_subscriber::layer::Layered<
750 tracing_opentelemetry::OpenTelemetryLayer<R, opentelemetry_sdk::trace::Tracer>,
751 R,
752 >,
753 >,
754 tracing_subscriber::layer::Layered<
755 tracing_opentelemetry::OpenTelemetryLayer<R, opentelemetry_sdk::trace::Tracer>,
756 R,
757 >,
758 >,
759 SimErrorAny,
760 >
761 where
762 R: tracing_subscriber::layer::SubscriberExt + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
763 {
764 use opentelemetry::trace::TracerProvider;
765 use tracing_opentelemetry::MetricsLayer;
766 use tracing_opentelemetry::OpenTelemetryLayer;
767 use tracing_subscriber::layer::SubscriberExt;
768
769 let meter_provider = self.init_meter_provider()?;
770 let otlp_exporter = opentelemetry_otlp::SpanExporter::builder().with_tonic().build()?;
771 let _ = opentelemetry_sdk::trace::SdkTracerProvider::builder()
772 .with_simple_exporter(otlp_exporter)
773 .build();
774
775 let tracer_provider = self.init_tracer_provider()?;
776 let tracer = tracer_provider.tracer("wb_cache::company");
777
778 Ok(registry
779 .with(OpenTelemetryLayer::new(tracer))
780 .with(MetricsLayer::new(meter_provider.clone())))
781 }
782
783 #[cfg(all(feature = "tracing", feature = "tracing-loki"))]
784 fn setup_tracing_loki<R>(
785 &self,
786 registry: R,
787 ) -> Result<tracing_subscriber::layer::Layered<tracing_loki::Layer, R>, SimErrorAny>
788 where
789 R: tracing_subscriber::layer::SubscriberExt + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
790 {
791 use std::process;
792
793 let url = self.cli()?.loki_url();
794
795 let (loki, loki_task) = tracing_loki::builder()
796 .label("app", "wb_cache::company")?
797 .extra_field("pid", format!("{}", process::id()))?
798 .build_url(url)?;
799
800 tokio::spawn(loki_task);
801
802 Ok(registry.with(loki))
803 }
804
805 #[cfg(all(feature = "tracing", feature = "tracing-file"))]
806 #[allow(clippy::type_complexity)]
807 fn setup_tracing_file<R>(
808 &self,
809 registry: R,
810 ) -> Result<
811 tracing_subscriber::layer::Layered<
812 tracing_subscriber::fmt::Layer<
813 R,
814 tracing_subscriber::fmt::format::DefaultFields,
815 tracing_subscriber::fmt::format::Format,
816 ::std::sync::Mutex<Box<dyn std::io::Write + Send + 'static>>,
817 >,
818 R,
819 >,
820 SimErrorAny,
821 >
822 where
823 R: tracing_subscriber::layer::SubscriberExt + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
824 {
825 use std::io;
826 use std::sync::Mutex;
827 use tracing_subscriber::fmt::format::FmtSpan;
828
829 let cli = self.cli()?;
830
831 let dest_writer = Mutex::new(if let Some(log_file) = cli.log_file() {
832 let file = std::fs::OpenOptions::new()
833 .create(true)
834 .write(true)
835 .truncate(true)
836 .open(log_file)?;
837 Box::new(file) as Box<dyn io::Write + Send>
838 }
839 else {
840 Box::new(io::stdout()) as Box<dyn io::Write + Send>
841 });
842
843 Ok(registry.with(
844 tracing_subscriber::fmt::layer()
845 .with_writer(dest_writer)
846 .with_span_events(FmtSpan::FULL),
847 ))
848 }
849
850 #[cfg(feature = "tracing")]
851 fn setup_tracing(&self) -> Result<(), SimErrorAny> {
852 use tracing::info;
853 use tracing_subscriber::layer::SubscriberExt;
854 use tracing_subscriber::util::SubscriberInitExt;
855
856 let filter = tracing_subscriber::EnvFilter::from_default_env();
857
858 let tracing_registry = tracing_subscriber::registry();
859 let tracing_registry = tracing_registry.with(filter);
860
861 #[cfg(all(feature = "tracing", feature = "tracing-otlp"))]
862 let tracing_registry = self.setup_tracing_otlp(tracing_registry)?;
863
864 #[cfg(all(feature = "tracing", feature = "tracing-loki"))]
865 let tracing_registry = self.setup_tracing_loki(tracing_registry)?;
866
867 #[cfg(all(feature = "tracing", feature = "tracing-file"))]
868 let tracing_registry = self.setup_tracing_file(tracing_registry)?;
869
870 tracing_registry.try_init()?;
871
872 info!("Tracing initialized");
873
874 Ok(())
875 }
876
877 pub async fn execute(&self) -> Result<(), SimErrorAny> {
878 let cli = match self.cli() {
879 Ok(cli) => cli,
880 Err(err) => match err.kind() {
881 ErrorKind::DisplayHelp | ErrorKind::DisplayVersion => {
882 let mut cmd = Cli::command();
883 cmd.print_help().unwrap();
885 return Ok(());
886 }
887 _ => {
888 return Err(err.into());
889 }
890 },
891 };
892
893 self.validate()?;
894
895 #[cfg(feature = "tracing")]
896 self.setup_tracing()?;
897
898 if cli.save() {
899 let myself = self.myself().unwrap();
900 return tokio::task::spawn_blocking(move || myself.save_script()).await?;
901 }
902
903 #[cfg(any(feature = "pg", feature = "sqlite"))]
904 {
905 let script = if cli.load() {
906 self.load_script()?
907 }
908 else {
909 let s = self.script_writer()?.create()?;
910 self.clear_script_writer();
911 s
912 };
913
914 self.execute_per_db(script).await?;
915 }
916
917 Ok(())
918 }
919
920 pub async fn run() -> Result<(), SimErrorAny> {
921 EcommerceApp::__fieldx_new().execute().await
922 }
923}
924
925impl EcommerceAppBuilder {
926 pub fn cli_args<S: ToString>(self, args: Vec<S>) -> Self {
927 self._cli_args(args.into_iter().map(|s| s.to_string()).collect())
928 }
929}
930
931impl Debug for EcommerceApp {
932 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
933 write!(f, "SimApp {{ ... }}")
934 }
935}
936
937impl SimulationApp for EcommerceApp {
938 fn acquire_progress<'a>(
939 &'a self,
940 style: PStyle,
941 order: Option<POrder<'a>>,
942 ) -> Result<Option<ProgressBar>, SimErrorAny> {
943 Ok(self.progress_ui()?.acquire_progress(style, order))
944 }
945
946 fn set_cached_per_sec(&self, step: f64) {
947 self._set_cached_per_sec(step);
948 }
949
950 fn set_plain_per_sec(&self, step: f64) {
951 self._set_plain_per_sec(step);
952 }
953
954 fn report_info<S: ToString>(&self, msg: S) {
955 self.progress_ui().unwrap().report_info(msg.to_string());
956 }
957
958 fn report_debug<S: ToString>(&self, msg: S) {
959 self.progress_ui().unwrap().report_debug(msg.to_string());
960 }
961
962 fn report_warn<S: ToString>(&self, msg: S) {
963 self.progress_ui().unwrap().report_warn(msg.to_string());
964 }
965
966 fn report_error<S: ToString>(&self, msg: S) {
967 self.progress_ui().unwrap().report_error(msg.to_string());
968 }
969}
970
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cli_parsing() {
        let args = vec!["cmd", "--quiet", "--test", "--products", "5", "--period", "30"];
        let cli = Cli::try_parse_from(args).expect("Failed to parse CLI arguments");
        assert_eq!(cli.products(), 5);
        assert_eq!(cli.period(), 30);
        assert!(cli.quiet());
        assert!(cli.test());
    }

    #[test]
    fn test_cli_defaults_are_valid() {
        let cli = Cli::try_parse_from(["cmd"]).expect("Failed to parse CLI arguments");
        assert_eq!(cli.period(), 365);
        assert_eq!(cli.products(), 10);
        assert_eq!(cli.initial_customers(), 1);
        assert_eq!(cli.market_capacity(), 1_000);
        assert_eq!(cli.inflection_point(), 400);
        assert_eq!(cli.return_window(), 30);
        assert!(!cli.quiet());
        assert!(!cli.save());
        assert!(cli.script().is_none());
        assert!(cli.validate().is_ok());
    }

    #[test]
    fn test_cli_rejects_inflection_point_above_capacity() {
        let cli = Cli::try_parse_from(["cmd", "--market-capacity", "10", "--inflection-point", "20"])
            .expect("Failed to parse CLI arguments");
        assert!(cli.validate().is_err());
    }

    #[test]
    fn test_cli_rejects_min_orders_above_max() {
        let cli = Cli::try_parse_from(["cmd", "--min-customer-orders", "5", "--max-customer-orders", "1"])
            .expect("Failed to parse CLI arguments");
        assert!(cli.validate().is_err());
    }

    #[test]
    fn test_cli_save_requires_script() {
        let cli = Cli::try_parse_from(["cmd", "--save"]).expect("Failed to parse CLI arguments");
        assert!(cli.validate().is_err());

        let cli = Cli::try_parse_from(["cmd", "--save", "script.zip"]).expect("Failed to parse CLI arguments");
        assert!(cli.validate().is_ok());
        assert_eq!(cli.script().as_deref(), Some(std::path::Path::new("script.zip")));
    }
}