wifi_densepose_train/error.rs
1//! Error types for the WiFi-DensePose training pipeline.
2//!
3//! This module is the single source of truth for all error types in the
4//! training crate. Every module that produces an error imports its error type
5//! from here rather than defining it inline, keeping the error hierarchy
6//! centralised and consistent.
7//!
8//! ## Hierarchy
9//!
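//! ```text
//! TrainError (top-level)
//! ├── ConfigError       (config validation / file loading)
//! ├── DatasetError      (data loading, I/O, format)
//! ├── MaeError          (MAE patchify / masking — ADR-152 §2.3)
//! └── serde_json::Error (JSON (de)serialization)
//!
//! SubcarrierError       (frequency-axis resampling; standalone)
//! ```
//!
//! `SubcarrierError` is not wrapped by `TrainError`: there is no `#[from]`
//! variant for it, so it is not converted automatically by `?` and must be
//! mapped explicitly wherever it needs to become a `TrainResult` error.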
17
18use std::path::PathBuf;
19use thiserror::Error;
20
21// ---------------------------------------------------------------------------
22// TrainResult
23// ---------------------------------------------------------------------------
24
25/// Convenient `Result` alias used by orchestration-level functions.
26pub type TrainResult<T> = Result<T, TrainError>;
27
28// ---------------------------------------------------------------------------
29// TrainError — top-level aggregator
30// ---------------------------------------------------------------------------
31
32/// Top-level error type for the WiFi-DensePose training pipeline.
33///
34/// Orchestration-level functions (e.g. [`crate::trainer::Trainer`] methods)
35/// return `TrainResult<T>`. Lower-level functions in [`crate::config`] and
36/// [`crate::dataset`] return their own module-specific error types which are
37/// automatically coerced into `TrainError` via [`From`].
38#[derive(Debug, Error)]
39pub enum TrainError {
40 /// A configuration validation or loading error.
41 #[error("Configuration error: {0}")]
42 Config(#[from] ConfigError),
43
44 /// A dataset loading or access error.
45 #[error("Dataset error: {0}")]
46 Dataset(#[from] DatasetError),
47
48 /// A MAE pretraining patchify / masking error (ADR-152 §2.3).
49 #[error("MAE pretraining error: {0}")]
50 Mae(#[from] MaeError),
51
52 /// JSON (de)serialization error.
53 #[error("JSON error: {0}")]
54 Json(#[from] serde_json::Error),
55
56 /// The dataset is empty and no training can be performed.
57 #[error("Dataset is empty")]
58 EmptyDataset,
59
60 /// Index out of bounds when accessing dataset items.
61 #[error("Index {index} is out of bounds for dataset of length {len}")]
62 IndexOutOfBounds {
63 /// The out-of-range index.
64 index: usize,
65 /// The total number of items in the dataset.
66 len: usize,
67 },
68
69 /// A shape mismatch was detected between two tensors.
70 #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
71 ShapeMismatch {
72 /// Expected shape.
73 expected: Vec<usize>,
74 /// Actual shape.
75 actual: Vec<usize>,
76 },
77
78 /// A training step failed.
79 #[error("Training step failed: {0}")]
80 TrainingStep(String),
81
82 /// A checkpoint could not be saved or loaded.
83 #[error("Checkpoint error: {message} (path: {path:?})")]
84 Checkpoint {
85 /// Human-readable description.
86 message: String,
87 /// Path that was being accessed.
88 path: PathBuf,
89 },
90
91 /// Feature not yet implemented.
92 #[error("Not implemented: {0}")]
93 NotImplemented(String),
94}
95
96impl TrainError {
97 /// Construct a [`TrainError::TrainingStep`].
98 pub fn training_step<S: Into<String>>(msg: S) -> Self {
99 TrainError::TrainingStep(msg.into())
100 }
101
102 /// Construct a [`TrainError::Checkpoint`].
103 pub fn checkpoint<S: Into<String>>(msg: S, path: impl Into<PathBuf>) -> Self {
104 TrainError::Checkpoint {
105 message: msg.into(),
106 path: path.into(),
107 }
108 }
109
110 /// Construct a [`TrainError::NotImplemented`].
111 pub fn not_implemented<S: Into<String>>(msg: S) -> Self {
112 TrainError::NotImplemented(msg.into())
113 }
114
115 /// Construct a [`TrainError::ShapeMismatch`].
116 pub fn shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> Self {
117 TrainError::ShapeMismatch { expected, actual }
118 }
119}
120
121// ---------------------------------------------------------------------------
122// ConfigError
123// ---------------------------------------------------------------------------
124
125/// Errors produced when loading or validating a [`TrainingConfig`].
126///
127/// [`TrainingConfig`]: crate::config::TrainingConfig
128#[derive(Debug, Error)]
129pub enum ConfigError {
130 /// A field has an invalid value.
131 #[error("Invalid value for `{field}`: {reason}")]
132 InvalidValue {
133 /// Name of the field.
134 field: &'static str,
135 /// Human-readable reason.
136 reason: String,
137 },
138
139 /// A configuration file could not be read from disk.
140 #[error("Cannot read config file `{path}`: {source}")]
141 FileRead {
142 /// Path that was being read.
143 path: PathBuf,
144 /// Underlying I/O error.
145 #[source]
146 source: std::io::Error,
147 },
148
149 /// A configuration file contains malformed JSON.
150 #[error("Cannot parse config file `{path}`: {source}")]
151 ParseError {
152 /// Path that was being parsed.
153 path: PathBuf,
154 /// Underlying JSON parse error.
155 #[source]
156 source: serde_json::Error,
157 },
158
159 /// A path referenced in the config does not exist.
160 #[error("Path `{path}` in config does not exist")]
161 PathNotFound {
162 /// The missing path.
163 path: PathBuf,
164 },
165}
166
167impl ConfigError {
168 /// Construct a [`ConfigError::InvalidValue`].
169 pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
170 ConfigError::InvalidValue {
171 field,
172 reason: reason.into(),
173 }
174 }
175}
176
177// ---------------------------------------------------------------------------
178// DatasetError
179// ---------------------------------------------------------------------------
180
181/// Errors produced while loading or accessing dataset samples.
182///
183/// Production training code MUST NOT silently suppress these errors.
184/// If data is missing, training must fail explicitly so the user is aware.
185/// The [`SyntheticCsiDataset`] is the only source of non-file-system data
186/// and is restricted to proof/testing use.
187///
188/// [`SyntheticCsiDataset`]: crate::dataset::SyntheticCsiDataset
189#[derive(Debug, Error)]
190pub enum DatasetError {
191 /// A required data file or directory was not found on disk.
192 #[error("Data not found at `{path}`: {message}")]
193 DataNotFound {
194 /// Path that was expected to contain data.
195 path: PathBuf,
196 /// Additional context.
197 message: String,
198 },
199
200 /// A file was found but its format or shape is wrong.
201 #[error("Invalid data format in `{path}`: {message}")]
202 InvalidFormat {
203 /// Path of the malformed file.
204 path: PathBuf,
205 /// Description of the problem.
206 message: String,
207 },
208
209 /// A low-level I/O error while reading a data file.
210 #[error("I/O error reading `{path}`: {source}")]
211 IoError {
212 /// Path being read when the error occurred.
213 path: PathBuf,
214 /// Underlying I/O error.
215 #[source]
216 source: std::io::Error,
217 },
218
219 /// The number of subcarriers in the file doesn't match expectations.
220 #[error("Subcarrier count mismatch in `{path}`: file has {found}, expected {expected}")]
221 SubcarrierMismatch {
222 /// Path of the offending file.
223 path: PathBuf,
224 /// Subcarrier count found in the file.
225 found: usize,
226 /// Subcarrier count expected.
227 expected: usize,
228 },
229
230 /// A sample index is out of bounds.
231 #[error("Index {idx} out of bounds (dataset has {len} samples)")]
232 IndexOutOfBounds {
233 /// The requested index.
234 idx: usize,
235 /// Total length of the dataset.
236 len: usize,
237 },
238
239 /// A numpy array file could not be parsed.
240 #[error("NumPy read error in `{path}`: {message}")]
241 NpyReadError {
242 /// Path of the `.npy` file.
243 path: PathBuf,
244 /// Error description.
245 message: String,
246 },
247
248 /// Metadata for a subject is missing or malformed.
249 #[error("Metadata error for subject {subject_id}: {message}")]
250 MetadataError {
251 /// Subject whose metadata was invalid.
252 subject_id: u32,
253 /// Description of the problem.
254 message: String,
255 },
256
257 /// A data format error (e.g. wrong numpy shape) occurred.
258 ///
259 /// This is a convenience variant for short-form error messages where
260 /// the full path context is not available.
261 #[error("File format error: {0}")]
262 Format(String),
263
264 /// The data directory does not exist.
265 #[error("Directory not found: {path}")]
266 DirectoryNotFound {
267 /// The path that was not found.
268 path: String,
269 },
270
271 /// No subjects matching the requested IDs were found.
272 #[error("No subjects found in `{data_dir}` for IDs: {requested:?}")]
273 NoSubjectsFound {
274 /// Root data directory.
275 data_dir: PathBuf,
276 /// IDs that were requested.
277 requested: Vec<u32>,
278 },
279
280 /// An I/O error that carries no path context.
281 #[error("IO error: {0}")]
282 Io(#[from] std::io::Error),
283
284 /// A train/test split is invalid — it leaks information across the boundary
285 /// (a subject appears in both partitions, or a window is shared) or is
286 /// degenerate (an empty partition). ADR-155 §Tier-1.2.
287 #[error("Invalid split: {0}")]
288 InvalidSplit(String),
289}
290
291impl DatasetError {
292 /// Construct a [`DatasetError::DataNotFound`].
293 pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
294 DatasetError::DataNotFound {
295 path: path.into(),
296 message: msg.into(),
297 }
298 }
299
300 /// Construct a [`DatasetError::InvalidFormat`].
301 pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
302 DatasetError::InvalidFormat {
303 path: path.into(),
304 message: msg.into(),
305 }
306 }
307
308 /// Construct a [`DatasetError::IoError`].
309 pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
310 DatasetError::IoError {
311 path: path.into(),
312 source,
313 }
314 }
315
316 /// Construct a [`DatasetError::SubcarrierMismatch`].
317 pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
318 DatasetError::SubcarrierMismatch {
319 path: path.into(),
320 found,
321 expected,
322 }
323 }
324
325 /// Construct a [`DatasetError::NpyReadError`].
326 pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
327 DatasetError::NpyReadError {
328 path: path.into(),
329 message: msg.into(),
330 }
331 }
332}
333
334// ---------------------------------------------------------------------------
335// SubcarrierError
336// ---------------------------------------------------------------------------
337
338/// Errors produced by the subcarrier resampling / interpolation functions.
339#[derive(Debug, Error)]
340pub enum SubcarrierError {
341 /// The source or destination count is zero.
342 #[error("Subcarrier count must be >= 1, got {count}")]
343 ZeroCount {
344 /// The offending count.
345 count: usize,
346 },
347
348 /// The array's last dimension does not match the declared source count.
349 #[error(
350 "Subcarrier shape mismatch: last dim is {actual_sc} but src_n={expected_sc} \
351 (full shape: {shape:?})"
352 )]
353 InputShapeMismatch {
354 /// Expected subcarrier count.
355 expected_sc: usize,
356 /// Actual last-dimension size.
357 actual_sc: usize,
358 /// Full shape of the input.
359 shape: Vec<usize>,
360 },
361
362 /// The requested interpolation method is not yet implemented.
363 #[error("Interpolation method `{method}` is not implemented")]
364 MethodNotImplemented {
365 /// Name of the unsupported method.
366 method: String,
367 },
368
369 /// `src_n == dst_n` — no resampling needed.
370 #[error("src_n == dst_n == {count}; call interpolate only when counts differ")]
371 NopInterpolation {
372 /// The equal count.
373 count: usize,
374 },
375
376 /// A numerical error during interpolation.
377 #[error("Numerical error: {0}")]
378 NumericalError(String),
379}
380
381impl SubcarrierError {
382 /// Construct a [`SubcarrierError::NumericalError`].
383 pub fn numerical<S: Into<String>>(msg: S) -> Self {
384 SubcarrierError::NumericalError(msg.into())
385 }
386}
387
388// ---------------------------------------------------------------------------
389// MaeError
390// ---------------------------------------------------------------------------
391
392/// Errors produced by the MAE pretraining patchify / masking functions
393/// ([`crate::mae`], ADR-152 §2.3).
394#[derive(Debug, Error)]
395pub enum MaeError {
396 /// The flat window buffer does not match the declared `time × subc` shape.
397 #[error(
398 "Window length {actual} does not match time × subcarriers = \
399 {time} × {subc} = {expected}"
400 )]
401 WindowShapeMismatch {
402 /// Declared time dimension.
403 time: usize,
404 /// Declared subcarrier dimension.
405 subc: usize,
406 /// Expected buffer length (`time * subc`).
407 expected: usize,
408 /// Actual buffer length.
409 actual: usize,
410 },
411
412 /// A patch dimension is larger than the window along that axis.
413 #[error("Patch {axis} extent {patch} exceeds window {axis} extent {window}")]
414 PatchExceedsWindow {
415 /// Axis name (`"time"` or `"subcarrier"`).
416 axis: &'static str,
417 /// Patch extent along the axis.
418 patch: usize,
419 /// Window extent along the axis.
420 window: usize,
421 },
422
423 /// The window is not an exact multiple of the patch extent along an axis.
424 ///
425 /// Patchification never silently truncates; crop the window to `crop`
426 /// (the largest divisible extent) or change the patch size.
427 #[error(
428 "Window {axis} extent {window} is not divisible by patch {axis} extent \
429 {patch} (remainder {remainder}); crop the window to {crop} or change \
430 the patch size"
431 )]
432 NotDivisible {
433 /// Axis name (`"time"` or `"subcarrier"`).
434 axis: &'static str,
435 /// Window extent along the axis.
436 window: usize,
437 /// Patch extent along the axis.
438 patch: usize,
439 /// `window % patch`.
440 remainder: usize,
441 /// Largest divisible extent (`window - remainder`).
442 crop: usize,
443 },
444
445 /// The mask ratio is not a finite value strictly inside `(0, 1)` — the
446 /// same rule as [`MaePretrainConfig::validate`]. A NaN ratio must never
447 /// silently mask zero patches, and ratios ≤ 0 / ≥ 1 degenerate to
448 /// all-visible / all-masked grids.
449 ///
450 /// [`MaePretrainConfig::validate`]: crate::mae::MaePretrainConfig::validate
451 #[error("Invalid mask ratio {ratio}: must be finite and strictly inside (0, 1)")]
452 InvalidMaskRatio {
453 /// The offending ratio.
454 ratio: f64,
455 },
456
457 /// A NaN or ±inf CSI value was found; corrupted input must be cleaned
458 /// upstream, never masked over.
459 #[error("Non-finite CSI value {value} at (t={row}, sc={col})")]
460 NonFiniteValue {
461 /// Time index of the offending value.
462 row: usize,
463 /// Subcarrier index of the offending value.
464 col: usize,
465 /// The non-finite value itself.
466 value: f32,
467 },
468}