wifi_densepose_train/lib.rs
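//! # WiFi-DensePose Training Infrastructure
//!
//! This crate provides the complete training pipeline for the WiFi-DensePose pose
//! estimation model. It includes configuration management, dataset loading with
//! subcarrier interpolation, loss functions, evaluation metrics, and the training
//! loop orchestrator, plus the MERIDIAN (ADR-027) components in the `domain`,
//! `eval`, `geometry`, `rapid_adapt`, and `virtual_aug` modules.
//!
//! The `losses`, `metrics`, `model`, `proof`, and `trainer` modules use the `tch`
//! PyTorch bindings and are compiled only when the `tch-backend` feature is enabled.
//! Configuration, dataset, and subcarrier utilities are always available.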
7//!
8//! ## Architecture
9//!
//! ```text
//! TrainingConfig ──► Trainer ──► Model
//!       │               │
//!       │           DataLoader
//!       │               │
//!       │           CsiDataset (MmFiDataset | SyntheticCsiDataset)
//!       │               │
//!       │           subcarrier::interpolate_subcarriers
//!       │
//!       └──► losses / metrics
//! ```
21//!
22//! ## Quick Start
23//!
24//! ```rust,no_run
25//! use wifi_densepose_train::config::TrainingConfig;
26//! use wifi_densepose_train::dataset::{SyntheticCsiDataset, SyntheticConfig, CsiDataset};
27//!
28//! // Build config
29//! let config = TrainingConfig::default();
30//! config.validate().expect("config is valid");
31//!
32//! // Create a synthetic dataset (deterministic, fixed-seed)
33//! let syn_cfg = SyntheticConfig::default();
34//! let dataset = SyntheticCsiDataset::new(200, syn_cfg);
35//!
36//! // Load one sample
37//! let sample = dataset.get(0).unwrap();
38//! println!("amplitude shape: {:?}", sample.amplitude.shape());
39//! ```
40
// Crate-level lint configuration.
//
// `#![forbid(unsafe_code)]` is intentionally not used. Note that the `tch`
// dependency (PyTorch Rust bindings) using unsafe FFI internally is *not*, by
// itself, a reason to omit it: lints only inspect the crate being compiled, so
// unsafe code inside dependencies never trips this crate's `unsafe_code` lint.
// What would trip it is an `unsafe { .. }` block written in this crate, e.g. to
// call an `unsafe fn` exposed by a dependency.
//
// This crate's own code is intended to contain no unsafe blocks.
// `warn(unsafe_code)` surfaces any that appear without breaking existing builds.
// A module that genuinely needs unsafe can opt out locally with
// `#[allow(unsafe_code)]` plus a `// SAFETY:` comment.
#![warn(missing_docs)]
#![warn(unsafe_code)]
45
46pub mod config;
47pub mod dataset;
48pub mod domain;
49pub mod error;
50pub mod eval;
51pub mod geometry;
52pub mod rapid_adapt;
53pub mod ruview_metrics;
54pub mod subcarrier;
55pub mod virtual_aug;
56
57// The following modules use `tch` (PyTorch Rust bindings) for GPU-accelerated
58// training and are only compiled when the `tch-backend` feature is enabled.
59// Without the feature the crate still provides the dataset / config / subcarrier
60// APIs needed for data preprocessing and proof verification.
61#[cfg(feature = "tch-backend")]
62pub mod losses;
63#[cfg(feature = "tch-backend")]
64pub mod metrics;
65#[cfg(feature = "tch-backend")]
66pub mod model;
67#[cfg(feature = "tch-backend")]
68pub mod proof;
69#[cfg(feature = "tch-backend")]
70pub mod trainer;
71
72// Convenient re-exports at the crate root.
73pub use config::TrainingConfig;
74pub use dataset::{CsiDataset, CsiSample, DataLoader, MmFiDataset, SyntheticCsiDataset, SyntheticConfig};
75pub use error::{ConfigError, DatasetError, SubcarrierError, TrainError};
76// TrainResult<T> is the generic Result alias from error.rs; the concrete
77// TrainResult struct from trainer.rs is accessed via trainer::TrainResult.
78pub use error::TrainResult as TrainResultAlias;
79pub use subcarrier::{compute_interp_weights, interpolate_subcarriers, select_subcarriers_by_variance};
80
81// MERIDIAN (ADR-027) re-exports.
82pub use domain::{
83 AdversarialSchedule, DomainClassifier, DomainFactorizer, GradientReversalLayer,
84};
85pub use eval::CrossDomainEvaluator;
86pub use geometry::{FilmLayer, FourierPositionalEncoding, GeometryEncoder, MeridianGeometryConfig};
87pub use rapid_adapt::{AdaptError, AdaptationLoss, AdaptationResult, RapidAdaptation};
88pub use virtual_aug::VirtualDomainAugmentor;
89
/// Version string of this crate.
///
/// Captured at compile time from the `CARGO_PKG_VERSION` environment variable,
/// which Cargo sets when it builds the crate.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");