Skip to main content

wifi_densepose_train/
error.rs

//! Error types for the WiFi-DensePose training pipeline.
//!
//! This module is the single source of truth for all error types in the
//! training crate. Every module that produces an error imports its error type
//! from here rather than defining it inline, keeping the error hierarchy
//! centralised and consistent.
//!
//! ## Hierarchy
//!
//! ```text
//! TrainError (top-level; also wraps serde_json::Error)
//! ├── ConfigError      (config validation / file loading)
//! ├── DatasetError     (data loading, I/O, format)
//! └── MaeError         (MAE patchify / masking — ADR-152 §2.3)
//!
//! SubcarrierError      (frequency-axis resampling; standalone — TrainError
//!                       has no variant wrapping it)
//! ```
17
18use std::path::PathBuf;
19use thiserror::Error;
20
21// ---------------------------------------------------------------------------
22// TrainResult
23// ---------------------------------------------------------------------------
24
25/// Convenient `Result` alias used by orchestration-level functions.
26pub type TrainResult<T> = Result<T, TrainError>;
27
28// ---------------------------------------------------------------------------
29// TrainError — top-level aggregator
30// ---------------------------------------------------------------------------
31
32/// Top-level error type for the WiFi-DensePose training pipeline.
33///
34/// Orchestration-level functions (e.g. [`crate::trainer::Trainer`] methods)
35/// return `TrainResult<T>`. Lower-level functions in [`crate::config`] and
36/// [`crate::dataset`] return their own module-specific error types which are
37/// automatically coerced into `TrainError` via [`From`].
38#[derive(Debug, Error)]
39pub enum TrainError {
40    /// A configuration validation or loading error.
41    #[error("Configuration error: {0}")]
42    Config(#[from] ConfigError),
43
44    /// A dataset loading or access error.
45    #[error("Dataset error: {0}")]
46    Dataset(#[from] DatasetError),
47
48    /// A MAE pretraining patchify / masking error (ADR-152 §2.3).
49    #[error("MAE pretraining error: {0}")]
50    Mae(#[from] MaeError),
51
52    /// JSON (de)serialization error.
53    #[error("JSON error: {0}")]
54    Json(#[from] serde_json::Error),
55
56    /// The dataset is empty and no training can be performed.
57    #[error("Dataset is empty")]
58    EmptyDataset,
59
60    /// Index out of bounds when accessing dataset items.
61    #[error("Index {index} is out of bounds for dataset of length {len}")]
62    IndexOutOfBounds {
63        /// The out-of-range index.
64        index: usize,
65        /// The total number of items in the dataset.
66        len: usize,
67    },
68
69    /// A shape mismatch was detected between two tensors.
70    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
71    ShapeMismatch {
72        /// Expected shape.
73        expected: Vec<usize>,
74        /// Actual shape.
75        actual: Vec<usize>,
76    },
77
78    /// A training step failed.
79    #[error("Training step failed: {0}")]
80    TrainingStep(String),
81
82    /// A checkpoint could not be saved or loaded.
83    #[error("Checkpoint error: {message} (path: {path:?})")]
84    Checkpoint {
85        /// Human-readable description.
86        message: String,
87        /// Path that was being accessed.
88        path: PathBuf,
89    },
90
91    /// Feature not yet implemented.
92    #[error("Not implemented: {0}")]
93    NotImplemented(String),
94}
95
96impl TrainError {
97    /// Construct a [`TrainError::TrainingStep`].
98    pub fn training_step<S: Into<String>>(msg: S) -> Self {
99        TrainError::TrainingStep(msg.into())
100    }
101
102    /// Construct a [`TrainError::Checkpoint`].
103    pub fn checkpoint<S: Into<String>>(msg: S, path: impl Into<PathBuf>) -> Self {
104        TrainError::Checkpoint {
105            message: msg.into(),
106            path: path.into(),
107        }
108    }
109
110    /// Construct a [`TrainError::NotImplemented`].
111    pub fn not_implemented<S: Into<String>>(msg: S) -> Self {
112        TrainError::NotImplemented(msg.into())
113    }
114
115    /// Construct a [`TrainError::ShapeMismatch`].
116    pub fn shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> Self {
117        TrainError::ShapeMismatch { expected, actual }
118    }
119}
120
121// ---------------------------------------------------------------------------
122// ConfigError
123// ---------------------------------------------------------------------------
124
125/// Errors produced when loading or validating a [`TrainingConfig`].
126///
127/// [`TrainingConfig`]: crate::config::TrainingConfig
128#[derive(Debug, Error)]
129pub enum ConfigError {
130    /// A field has an invalid value.
131    #[error("Invalid value for `{field}`: {reason}")]
132    InvalidValue {
133        /// Name of the field.
134        field: &'static str,
135        /// Human-readable reason.
136        reason: String,
137    },
138
139    /// A configuration file could not be read from disk.
140    #[error("Cannot read config file `{path}`: {source}")]
141    FileRead {
142        /// Path that was being read.
143        path: PathBuf,
144        /// Underlying I/O error.
145        #[source]
146        source: std::io::Error,
147    },
148
149    /// A configuration file contains malformed JSON.
150    #[error("Cannot parse config file `{path}`: {source}")]
151    ParseError {
152        /// Path that was being parsed.
153        path: PathBuf,
154        /// Underlying JSON parse error.
155        #[source]
156        source: serde_json::Error,
157    },
158
159    /// A path referenced in the config does not exist.
160    #[error("Path `{path}` in config does not exist")]
161    PathNotFound {
162        /// The missing path.
163        path: PathBuf,
164    },
165}
166
167impl ConfigError {
168    /// Construct a [`ConfigError::InvalidValue`].
169    pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
170        ConfigError::InvalidValue {
171            field,
172            reason: reason.into(),
173        }
174    }
175}
176
177// ---------------------------------------------------------------------------
178// DatasetError
179// ---------------------------------------------------------------------------
180
181/// Errors produced while loading or accessing dataset samples.
182///
183/// Production training code MUST NOT silently suppress these errors.
184/// If data is missing, training must fail explicitly so the user is aware.
185/// The [`SyntheticCsiDataset`] is the only source of non-file-system data
186/// and is restricted to proof/testing use.
187///
188/// [`SyntheticCsiDataset`]: crate::dataset::SyntheticCsiDataset
189#[derive(Debug, Error)]
190pub enum DatasetError {
191    /// A required data file or directory was not found on disk.
192    #[error("Data not found at `{path}`: {message}")]
193    DataNotFound {
194        /// Path that was expected to contain data.
195        path: PathBuf,
196        /// Additional context.
197        message: String,
198    },
199
200    /// A file was found but its format or shape is wrong.
201    #[error("Invalid data format in `{path}`: {message}")]
202    InvalidFormat {
203        /// Path of the malformed file.
204        path: PathBuf,
205        /// Description of the problem.
206        message: String,
207    },
208
209    /// A low-level I/O error while reading a data file.
210    #[error("I/O error reading `{path}`: {source}")]
211    IoError {
212        /// Path being read when the error occurred.
213        path: PathBuf,
214        /// Underlying I/O error.
215        #[source]
216        source: std::io::Error,
217    },
218
219    /// The number of subcarriers in the file doesn't match expectations.
220    #[error("Subcarrier count mismatch in `{path}`: file has {found}, expected {expected}")]
221    SubcarrierMismatch {
222        /// Path of the offending file.
223        path: PathBuf,
224        /// Subcarrier count found in the file.
225        found: usize,
226        /// Subcarrier count expected.
227        expected: usize,
228    },
229
230    /// A sample index is out of bounds.
231    #[error("Index {idx} out of bounds (dataset has {len} samples)")]
232    IndexOutOfBounds {
233        /// The requested index.
234        idx: usize,
235        /// Total length of the dataset.
236        len: usize,
237    },
238
239    /// A numpy array file could not be parsed.
240    #[error("NumPy read error in `{path}`: {message}")]
241    NpyReadError {
242        /// Path of the `.npy` file.
243        path: PathBuf,
244        /// Error description.
245        message: String,
246    },
247
248    /// Metadata for a subject is missing or malformed.
249    #[error("Metadata error for subject {subject_id}: {message}")]
250    MetadataError {
251        /// Subject whose metadata was invalid.
252        subject_id: u32,
253        /// Description of the problem.
254        message: String,
255    },
256
257    /// A data format error (e.g. wrong numpy shape) occurred.
258    ///
259    /// This is a convenience variant for short-form error messages where
260    /// the full path context is not available.
261    #[error("File format error: {0}")]
262    Format(String),
263
264    /// The data directory does not exist.
265    #[error("Directory not found: {path}")]
266    DirectoryNotFound {
267        /// The path that was not found.
268        path: String,
269    },
270
271    /// No subjects matching the requested IDs were found.
272    #[error("No subjects found in `{data_dir}` for IDs: {requested:?}")]
273    NoSubjectsFound {
274        /// Root data directory.
275        data_dir: PathBuf,
276        /// IDs that were requested.
277        requested: Vec<u32>,
278    },
279
280    /// An I/O error that carries no path context.
281    #[error("IO error: {0}")]
282    Io(#[from] std::io::Error),
283
284    /// A train/test split is invalid — it leaks information across the boundary
285    /// (a subject appears in both partitions, or a window is shared) or is
286    /// degenerate (an empty partition). ADR-155 §Tier-1.2.
287    #[error("Invalid split: {0}")]
288    InvalidSplit(String),
289}
290
291impl DatasetError {
292    /// Construct a [`DatasetError::DataNotFound`].
293    pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
294        DatasetError::DataNotFound {
295            path: path.into(),
296            message: msg.into(),
297        }
298    }
299
300    /// Construct a [`DatasetError::InvalidFormat`].
301    pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
302        DatasetError::InvalidFormat {
303            path: path.into(),
304            message: msg.into(),
305        }
306    }
307
308    /// Construct a [`DatasetError::IoError`].
309    pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
310        DatasetError::IoError {
311            path: path.into(),
312            source,
313        }
314    }
315
316    /// Construct a [`DatasetError::SubcarrierMismatch`].
317    pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
318        DatasetError::SubcarrierMismatch {
319            path: path.into(),
320            found,
321            expected,
322        }
323    }
324
325    /// Construct a [`DatasetError::NpyReadError`].
326    pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
327        DatasetError::NpyReadError {
328            path: path.into(),
329            message: msg.into(),
330        }
331    }
332}
333
334// ---------------------------------------------------------------------------
335// SubcarrierError
336// ---------------------------------------------------------------------------
337
338/// Errors produced by the subcarrier resampling / interpolation functions.
339#[derive(Debug, Error)]
340pub enum SubcarrierError {
341    /// The source or destination count is zero.
342    #[error("Subcarrier count must be >= 1, got {count}")]
343    ZeroCount {
344        /// The offending count.
345        count: usize,
346    },
347
348    /// The array's last dimension does not match the declared source count.
349    #[error(
350        "Subcarrier shape mismatch: last dim is {actual_sc} but src_n={expected_sc} \
351         (full shape: {shape:?})"
352    )]
353    InputShapeMismatch {
354        /// Expected subcarrier count.
355        expected_sc: usize,
356        /// Actual last-dimension size.
357        actual_sc: usize,
358        /// Full shape of the input.
359        shape: Vec<usize>,
360    },
361
362    /// The requested interpolation method is not yet implemented.
363    #[error("Interpolation method `{method}` is not implemented")]
364    MethodNotImplemented {
365        /// Name of the unsupported method.
366        method: String,
367    },
368
369    /// `src_n == dst_n` — no resampling needed.
370    #[error("src_n == dst_n == {count}; call interpolate only when counts differ")]
371    NopInterpolation {
372        /// The equal count.
373        count: usize,
374    },
375
376    /// A numerical error during interpolation.
377    #[error("Numerical error: {0}")]
378    NumericalError(String),
379}
380
381impl SubcarrierError {
382    /// Construct a [`SubcarrierError::NumericalError`].
383    pub fn numerical<S: Into<String>>(msg: S) -> Self {
384        SubcarrierError::NumericalError(msg.into())
385    }
386}
387
388// ---------------------------------------------------------------------------
389// MaeError
390// ---------------------------------------------------------------------------
391
392/// Errors produced by the MAE pretraining patchify / masking functions
393/// ([`crate::mae`], ADR-152 §2.3).
394#[derive(Debug, Error)]
395pub enum MaeError {
396    /// The flat window buffer does not match the declared `time × subc` shape.
397    #[error(
398        "Window length {actual} does not match time × subcarriers = \
399         {time} × {subc} = {expected}"
400    )]
401    WindowShapeMismatch {
402        /// Declared time dimension.
403        time: usize,
404        /// Declared subcarrier dimension.
405        subc: usize,
406        /// Expected buffer length (`time * subc`).
407        expected: usize,
408        /// Actual buffer length.
409        actual: usize,
410    },
411
412    /// A patch dimension is larger than the window along that axis.
413    #[error("Patch {axis} extent {patch} exceeds window {axis} extent {window}")]
414    PatchExceedsWindow {
415        /// Axis name (`"time"` or `"subcarrier"`).
416        axis: &'static str,
417        /// Patch extent along the axis.
418        patch: usize,
419        /// Window extent along the axis.
420        window: usize,
421    },
422
423    /// The window is not an exact multiple of the patch extent along an axis.
424    ///
425    /// Patchification never silently truncates; crop the window to `crop`
426    /// (the largest divisible extent) or change the patch size.
427    #[error(
428        "Window {axis} extent {window} is not divisible by patch {axis} extent \
429         {patch} (remainder {remainder}); crop the window to {crop} or change \
430         the patch size"
431    )]
432    NotDivisible {
433        /// Axis name (`"time"` or `"subcarrier"`).
434        axis: &'static str,
435        /// Window extent along the axis.
436        window: usize,
437        /// Patch extent along the axis.
438        patch: usize,
439        /// `window % patch`.
440        remainder: usize,
441        /// Largest divisible extent (`window - remainder`).
442        crop: usize,
443    },
444
445    /// The mask ratio is not a finite value strictly inside `(0, 1)` — the
446    /// same rule as [`MaePretrainConfig::validate`]. A NaN ratio must never
447    /// silently mask zero patches, and ratios ≤ 0 / ≥ 1 degenerate to
448    /// all-visible / all-masked grids.
449    ///
450    /// [`MaePretrainConfig::validate`]: crate::mae::MaePretrainConfig::validate
451    #[error("Invalid mask ratio {ratio}: must be finite and strictly inside (0, 1)")]
452    InvalidMaskRatio {
453        /// The offending ratio.
454        ratio: f64,
455    },
456
457    /// A NaN or ±inf CSI value was found; corrupted input must be cleaned
458    /// upstream, never masked over.
459    #[error("Non-finite CSI value {value} at (t={row}, sc={col})")]
460    NonFiniteValue {
461        /// Time index of the offending value.
462        row: usize,
463        /// Subcarrier index of the offending value.
464        col: usize,
465        /// The non-finite value itself.
466        value: f32,
467    },
468}