Skip to main content

veloq_pytorch_query/
error.rs

1use std::borrow::Cow;
2use thiserror::Error;
3use veloq_core::query::NameMatchError;
4use veloq_core::time::TimeParseError;
5use veloq_core::{ErrorCode, VeloqDiagnostic};
6
7pub type PytorchQueryResult<T> = Result<T, PytorchQueryError>;
8
/// The stage of a DuckDB interaction at which a pytorch query SQL error
/// occurred.
///
/// Stored in `PytorchQueryError::Sql` and returned by
/// `PytorchQueryError::sql_parts`. The phase selects both the diagnostic code
/// and the verb used in the rendered error message.
///
/// `Hash` is derived alongside `Eq` so callers can key maps or sets by
/// phase, for example to tally failures per phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlPhase {
    /// Failure while preparing a statement (message verb: `preparing`).
    Prepare,
    /// Failure while querying (message verb: `querying`).
    Query,
    /// Failure while reading a result row (message verb: `reading`, with a
    /// trailing ` row`).
    Read,
}
15
16impl SqlPhase {
17    fn code(self) -> ErrorCode {
18        match self {
19            Self::Prepare => ErrorCode::new("pytorch.query.sql-prepare"),
20            Self::Query => ErrorCode::new("pytorch.query.sql-query"),
21            Self::Read => ErrorCode::new("pytorch.query.sql-read"),
22        }
23    }
24
25    fn action(self) -> &'static str {
26        match self {
27            Self::Prepare => "preparing",
28            Self::Query => "querying",
29            Self::Read => "reading",
30        }
31    }
32
33    fn row_suffix(self) -> &'static str {
34        match self {
35            Self::Prepare | Self::Query => "",
36            Self::Read => " row",
37        }
38    }
39}
40
41#[derive(Debug, Error)]
42pub enum PytorchQueryError {
43    #[error(
44        "unknown pytorch --type `{token}`; expected one of: cpu-op, annotation, step, runtime, driver, kernel, memcpy, memset, memory, python, comm, all"
45    )]
46    UnknownType { token: String },
47    #[error("--type must list at least one event type")]
48    EmptyTypeSelection,
49    #[error("pytorch trace has multiple ranks; use `--rank <n>` or `--all-ranks`")]
50    MultiRankRequiresScope,
51    #[error(
52        "--limit must be at least 1 (limit=0 would suppress total_matched too); use `--limit 1` for one row + totals"
53    )]
54    LimitTooSmall,
55    #[error("--name and --name-regex are mutually exclusive")]
56    MutuallyExclusiveNameFilters,
57    #[error("invalid --name glob `{pattern}`")]
58    InvalidNameGlob {
59        pattern: String,
60        #[source]
61        source: regex::Error,
62    },
63    #[error("invalid --name-regex `{pattern}`")]
64    InvalidNameRegex {
65        pattern: String,
66        #[source]
67        source: regex::Error,
68    },
69    #[error("`--from` and `--to` must be set together (got only one)")]
70    MissingTimeBound,
71    #[error("invalid --from `{value}`")]
72    InvalidFrom {
73        value: String,
74        #[source]
75        source: TimeParseError,
76    },
77    #[error("invalid --to `{value}`")]
78    InvalidTo {
79        value: String,
80        #[source]
81        source: TimeParseError,
82    },
83    #[error("time window end ({end} ns) must be greater than start ({start} ns)")]
84    EmptyTimeWindow { start: i64, end: i64 },
85    #[error("--interval must be greater than 0 ns")]
86    IntervalTooSmall,
87    #[error(
88        "unknown pytorch stats --group-by axis `{axis}`; expected name,type,step,rank,device,stream,shape,comm-kind,python-context,python-path"
89    )]
90    UnknownStatsGroupBy { axis: String },
91    #[error("pytorch --{axis} {value} requires {parents}; {suggestion}")]
92    LocalFilterParentRequired {
93        axis: &'static str,
94        value: i64,
95        parents: &'static str,
96        suggestion: &'static str,
97    },
98    #[error(
99        "pytorch stats --group-by {axis} requires parent axis {parents}; use a fixed scope or include `{group_by}` in --group-by"
100    )]
101    StatsGroupByParentRequired {
102        axis: &'static str,
103        parents: &'static str,
104        group_by: &'static str,
105    },
106    #[error(
107        "pytorch stats --group-by {axis} requires Python stack events, but this trace has none; re-capture with `torch.profiler.profile(..., with_stack=True)`"
108    )]
109    PythonStackMissing { axis: String },
110    #[error("unknown pytorch slices --group-by axis `{axis}`; expected name or step")]
111    UnknownSlicesGroupBy { axis: String },
112    #[error("opening in-memory DuckDB connection")]
113    SqlOpen {
114        #[source]
115        source: duckdb::Error,
116    },
117    #[error(
118        "{action} pytorch {area} SQL{row_suffix}",
119        action = phase.action(),
120        row_suffix = phase.row_suffix()
121    )]
122    Sql {
123        area: &'static str,
124        phase: SqlPhase,
125        label: String,
126        #[source]
127        source: duckdb::Error,
128    },
129    #[error("pytorch search SQL count does not fit in usize")]
130    SearchCountOverflow {
131        #[source]
132        source: std::num::TryFromIntError,
133    },
134    #[error("pytorch stats SQL count does not fit in usize")]
135    StatsCountOverflow {
136        #[source]
137        source: std::num::TryFromIntError,
138    },
139    #[error("pytorch timeline SQL count does not fit in usize")]
140    TimelineCountOverflow {
141        #[source]
142        source: std::num::TryFromIntError,
143    },
144    #[error("pytorch slices SQL count does not fit in usize")]
145    SlicesCountOverflow {
146        #[source]
147        source: std::num::TryFromIntError,
148    },
149    #[error("pytorch collectives SQL count does not fit in usize")]
150    CollectivesCountOverflow {
151        #[source]
152        source: std::num::TryFromIntError,
153    },
154    #[error("decoding pytorch inspect {field} JSON sidecar value")]
155    InspectJsonDecode {
156        field: String,
157        #[source]
158        source: serde_json::Error,
159    },
160    #[error("pytorch inspect trace_index does not fit in u32")]
161    InspectTraceIndexOverflow {
162        #[source]
163        source: std::num::TryFromIntError,
164    },
165}
166
167impl PytorchQueryError {
168    pub fn unknown_type(token: &str) -> Self {
169        Self::UnknownType {
170            token: token.to_string(),
171        }
172    }
173
174    pub fn invalid_name_glob(pattern: &str, source: regex::Error) -> Self {
175        Self::InvalidNameGlob {
176            pattern: pattern.to_string(),
177            source,
178        }
179    }
180
181    pub fn invalid_name_regex(pattern: &str, source: regex::Error) -> Self {
182        Self::InvalidNameRegex {
183            pattern: pattern.to_string(),
184            source,
185        }
186    }
187
188    pub fn invalid_from(value: &str, source: TimeParseError) -> Self {
189        Self::InvalidFrom {
190            value: value.to_string(),
191            source,
192        }
193    }
194
195    pub fn invalid_to(value: &str, source: TimeParseError) -> Self {
196        Self::InvalidTo {
197            value: value.to_string(),
198            source,
199        }
200    }
201
202    pub fn unknown_stats_group_by(axis: &str) -> Self {
203        Self::UnknownStatsGroupBy {
204            axis: axis.to_string(),
205        }
206    }
207
208    pub(crate) fn local_filter_parent_required(
209        axis: &'static str,
210        value: i64,
211        parents: &'static str,
212        suggestion: &'static str,
213    ) -> Self {
214        Self::LocalFilterParentRequired {
215            axis,
216            value,
217            parents,
218            suggestion,
219        }
220    }
221
222    pub(crate) fn stats_group_by_parent_required(
223        axis: &'static str,
224        parents: &'static str,
225        group_by: &'static str,
226    ) -> Self {
227        Self::StatsGroupByParentRequired {
228            axis,
229            parents,
230            group_by,
231        }
232    }
233
234    pub fn python_stack_missing(axis: &str) -> Self {
235        Self::PythonStackMissing {
236            axis: axis.to_string(),
237        }
238    }
239
240    pub fn unknown_slices_group_by(axis: &str) -> Self {
241        Self::UnknownSlicesGroupBy {
242            axis: axis.to_string(),
243        }
244    }
245
246    pub(crate) fn from_name_match(source: NameMatchError) -> Self {
247        match source {
248            NameMatchError::Conflict => Self::MutuallyExclusiveNameFilters,
249            NameMatchError::InvalidGlob { pattern, source } => {
250                Self::InvalidNameGlob { pattern, source }
251            }
252            NameMatchError::InvalidRegex { pattern, source } => {
253                Self::InvalidNameRegex { pattern, source }
254            }
255        }
256    }
257
258    pub(crate) fn sql_open(source: duckdb::Error) -> Self {
259        Self::SqlOpen { source }
260    }
261
262    pub(crate) fn sql(
263        area: &'static str,
264        phase: SqlPhase,
265        label: impl Into<String>,
266        source: duckdb::Error,
267    ) -> Self {
268        Self::Sql {
269            area,
270            phase,
271            label: label.into(),
272            source,
273        }
274    }
275
276    pub(crate) fn sql_prepare(
277        area: &'static str,
278        label: impl Into<String>,
279        source: duckdb::Error,
280    ) -> Self {
281        Self::sql(area, SqlPhase::Prepare, label, source)
282    }
283
284    pub(crate) fn sql_query(
285        area: &'static str,
286        label: impl Into<String>,
287        source: duckdb::Error,
288    ) -> Self {
289        Self::sql(area, SqlPhase::Query, label, source)
290    }
291
292    pub(crate) fn sql_read(
293        area: &'static str,
294        label: impl Into<String>,
295        source: duckdb::Error,
296    ) -> Self {
297        Self::sql(area, SqlPhase::Read, label, source)
298    }
299
300    pub fn sql_parts(&self) -> Option<(&'static str, SqlPhase, &str)> {
301        match self {
302            Self::Sql {
303                area, phase, label, ..
304            } => Some((*area, *phase, label.as_str())),
305            _ => None,
306        }
307    }
308
309    pub(crate) fn search_count_overflow(source: std::num::TryFromIntError) -> Self {
310        Self::SearchCountOverflow { source }
311    }
312
313    pub(crate) fn stats_count_overflow(source: std::num::TryFromIntError) -> Self {
314        Self::StatsCountOverflow { source }
315    }
316
317    pub(crate) fn timeline_count_overflow(source: std::num::TryFromIntError) -> Self {
318        Self::TimelineCountOverflow { source }
319    }
320
321    pub(crate) fn slices_count_overflow(source: std::num::TryFromIntError) -> Self {
322        Self::SlicesCountOverflow { source }
323    }
324
325    pub(crate) fn collectives_count_overflow(source: std::num::TryFromIntError) -> Self {
326        Self::CollectivesCountOverflow { source }
327    }
328
329    pub(crate) fn inspect_json_decode(field: &str, source: serde_json::Error) -> Self {
330        Self::InspectJsonDecode {
331            field: field.to_string(),
332            source,
333        }
334    }
335
336    pub(crate) fn inspect_trace_index_overflow(source: std::num::TryFromIntError) -> Self {
337        Self::InspectTraceIndexOverflow { source }
338    }
339}
340
341impl VeloqDiagnostic for PytorchQueryError {
342    fn code(&self) -> ErrorCode {
343        match self {
344            Self::UnknownType { .. } => ErrorCode::new("pytorch.query.unknown-type"),
345            Self::EmptyTypeSelection => ErrorCode::new("pytorch.query.empty-type-selection"),
346            Self::MultiRankRequiresScope => ErrorCode::new("pytorch.query.rank-scope-required"),
347            Self::LimitTooSmall => ErrorCode::new("pytorch.query.limit-too-small"),
348            Self::MutuallyExclusiveNameFilters => {
349                ErrorCode::new("pytorch.query.name-filter-conflict")
350            }
351            Self::InvalidNameGlob { .. } => ErrorCode::new("pytorch.query.invalid-name-glob"),
352            Self::InvalidNameRegex { .. } => ErrorCode::new("pytorch.query.invalid-name-regex"),
353            Self::MissingTimeBound => ErrorCode::new("pytorch.query.missing-time-bound"),
354            Self::InvalidFrom { .. } => ErrorCode::new("pytorch.query.invalid-from"),
355            Self::InvalidTo { .. } => ErrorCode::new("pytorch.query.invalid-to"),
356            Self::EmptyTimeWindow { .. } => ErrorCode::new("pytorch.query.empty-time-window"),
357            Self::IntervalTooSmall => ErrorCode::new("pytorch.query.interval-too-small"),
358            Self::UnknownStatsGroupBy { .. } => {
359                ErrorCode::new("pytorch.query.unknown-stats-group-by")
360            }
361            Self::LocalFilterParentRequired { .. } => {
362                ErrorCode::new("pytorch.query.local-filter-parent-required")
363            }
364            Self::StatsGroupByParentRequired { .. } => {
365                ErrorCode::new("pytorch.query.stats-group-by-parent-required")
366            }
367            Self::PythonStackMissing { .. } => ErrorCode::new("pytorch.query.python-stack-missing"),
368            Self::UnknownSlicesGroupBy { .. } => {
369                ErrorCode::new("pytorch.query.unknown-slices-group-by")
370            }
371            Self::SqlOpen { .. } => ErrorCode::new("pytorch.query.sql-open"),
372            Self::Sql { phase, .. } => phase.code(),
373            Self::SearchCountOverflow { .. } => {
374                ErrorCode::new("pytorch.query.search-count-overflow")
375            }
376            Self::StatsCountOverflow { .. } => ErrorCode::new("pytorch.query.stats-count-overflow"),
377            Self::TimelineCountOverflow { .. } => {
378                ErrorCode::new("pytorch.query.timeline-count-overflow")
379            }
380            Self::SlicesCountOverflow { .. } => {
381                ErrorCode::new("pytorch.query.slices-count-overflow")
382            }
383            Self::CollectivesCountOverflow { .. } => {
384                ErrorCode::new("pytorch.query.collectives-count-overflow")
385            }
386            Self::InspectJsonDecode { .. } => ErrorCode::new("pytorch.query.inspect-json-decode"),
387            Self::InspectTraceIndexOverflow { .. } => {
388                ErrorCode::new("pytorch.query.inspect-trace-index-overflow")
389            }
390        }
391    }
392
393    fn hint(&self) -> Option<Cow<'_, str>> {
394        match self {
395            Self::MultiRankRequiresScope => Some(Cow::Borrowed(
396                "Rerun with `--all-ranks` for an explicit aggregate, or `--rank 0` for one rank",
397            )),
398            _ => None,
399        }
400    }
401}