1use std::borrow::Cow;
2use thiserror::Error;
3use veloq_core::query::NameMatchError;
4use veloq_core::time::TimeParseError;
5use veloq_core::{ErrorCode, VeloqDiagnostic};
6
7pub type PytorchQueryResult<T> = Result<T, PytorchQueryError>;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10pub enum SqlPhase {
11 Prepare,
12 Query,
13 Read,
14}
15
16impl SqlPhase {
17 fn code(self) -> ErrorCode {
18 match self {
19 Self::Prepare => ErrorCode::new("pytorch.query.sql-prepare"),
20 Self::Query => ErrorCode::new("pytorch.query.sql-query"),
21 Self::Read => ErrorCode::new("pytorch.query.sql-read"),
22 }
23 }
24
25 fn action(self) -> &'static str {
26 match self {
27 Self::Prepare => "preparing",
28 Self::Query => "querying",
29 Self::Read => "reading",
30 }
31 }
32
33 fn row_suffix(self) -> &'static str {
34 match self {
35 Self::Prepare | Self::Query => "",
36 Self::Read => " row",
37 }
38 }
39}
40
41#[derive(Debug, Error)]
42pub enum PytorchQueryError {
43 #[error(
44 "unknown pytorch --type `{token}`; expected one of: cpu-op, annotation, step, runtime, driver, kernel, memcpy, memset, memory, python, comm, all"
45 )]
46 UnknownType { token: String },
47 #[error("--type must list at least one event type")]
48 EmptyTypeSelection,
49 #[error("pytorch trace has multiple ranks; use `--rank <n>` or `--all-ranks`")]
50 MultiRankRequiresScope,
51 #[error(
52 "--limit must be at least 1 (limit=0 would suppress total_matched too); use `--limit 1` for one row + totals"
53 )]
54 LimitTooSmall,
55 #[error("--name and --name-regex are mutually exclusive")]
56 MutuallyExclusiveNameFilters,
57 #[error("invalid --name glob `{pattern}`")]
58 InvalidNameGlob {
59 pattern: String,
60 #[source]
61 source: regex::Error,
62 },
63 #[error("invalid --name-regex `{pattern}`")]
64 InvalidNameRegex {
65 pattern: String,
66 #[source]
67 source: regex::Error,
68 },
69 #[error("`--from` and `--to` must be set together (got only one)")]
70 MissingTimeBound,
71 #[error("invalid --from `{value}`")]
72 InvalidFrom {
73 value: String,
74 #[source]
75 source: TimeParseError,
76 },
77 #[error("invalid --to `{value}`")]
78 InvalidTo {
79 value: String,
80 #[source]
81 source: TimeParseError,
82 },
83 #[error("time window end ({end} ns) must be greater than start ({start} ns)")]
84 EmptyTimeWindow { start: i64, end: i64 },
85 #[error("--interval must be greater than 0 ns")]
86 IntervalTooSmall,
87 #[error(
88 "unknown pytorch stats --group-by axis `{axis}`; expected name,type,step,rank,device,stream,shape,comm-kind,python-context,python-path"
89 )]
90 UnknownStatsGroupBy { axis: String },
91 #[error("pytorch --{axis} {value} requires {parents}; {suggestion}")]
92 LocalFilterParentRequired {
93 axis: &'static str,
94 value: i64,
95 parents: &'static str,
96 suggestion: &'static str,
97 },
98 #[error(
99 "pytorch stats --group-by {axis} requires parent axis {parents}; use a fixed scope or include `{group_by}` in --group-by"
100 )]
101 StatsGroupByParentRequired {
102 axis: &'static str,
103 parents: &'static str,
104 group_by: &'static str,
105 },
106 #[error(
107 "pytorch stats --group-by {axis} requires Python stack events, but this trace has none; re-capture with `torch.profiler.profile(..., with_stack=True)`"
108 )]
109 PythonStackMissing { axis: String },
110 #[error("unknown pytorch slices --group-by axis `{axis}`; expected name or step")]
111 UnknownSlicesGroupBy { axis: String },
112 #[error("opening in-memory DuckDB connection")]
113 SqlOpen {
114 #[source]
115 source: duckdb::Error,
116 },
117 #[error(
118 "{action} pytorch {area} SQL{row_suffix}",
119 action = phase.action(),
120 row_suffix = phase.row_suffix()
121 )]
122 Sql {
123 area: &'static str,
124 phase: SqlPhase,
125 label: String,
126 #[source]
127 source: duckdb::Error,
128 },
129 #[error("pytorch search SQL count does not fit in usize")]
130 SearchCountOverflow {
131 #[source]
132 source: std::num::TryFromIntError,
133 },
134 #[error("pytorch stats SQL count does not fit in usize")]
135 StatsCountOverflow {
136 #[source]
137 source: std::num::TryFromIntError,
138 },
139 #[error("pytorch timeline SQL count does not fit in usize")]
140 TimelineCountOverflow {
141 #[source]
142 source: std::num::TryFromIntError,
143 },
144 #[error("pytorch slices SQL count does not fit in usize")]
145 SlicesCountOverflow {
146 #[source]
147 source: std::num::TryFromIntError,
148 },
149 #[error("pytorch collectives SQL count does not fit in usize")]
150 CollectivesCountOverflow {
151 #[source]
152 source: std::num::TryFromIntError,
153 },
154 #[error("decoding pytorch inspect {field} JSON sidecar value")]
155 InspectJsonDecode {
156 field: String,
157 #[source]
158 source: serde_json::Error,
159 },
160 #[error("pytorch inspect trace_index does not fit in u32")]
161 InspectTraceIndexOverflow {
162 #[source]
163 source: std::num::TryFromIntError,
164 },
165}
166
167impl PytorchQueryError {
168 pub fn unknown_type(token: &str) -> Self {
169 Self::UnknownType {
170 token: token.to_string(),
171 }
172 }
173
174 pub fn invalid_name_glob(pattern: &str, source: regex::Error) -> Self {
175 Self::InvalidNameGlob {
176 pattern: pattern.to_string(),
177 source,
178 }
179 }
180
181 pub fn invalid_name_regex(pattern: &str, source: regex::Error) -> Self {
182 Self::InvalidNameRegex {
183 pattern: pattern.to_string(),
184 source,
185 }
186 }
187
188 pub fn invalid_from(value: &str, source: TimeParseError) -> Self {
189 Self::InvalidFrom {
190 value: value.to_string(),
191 source,
192 }
193 }
194
195 pub fn invalid_to(value: &str, source: TimeParseError) -> Self {
196 Self::InvalidTo {
197 value: value.to_string(),
198 source,
199 }
200 }
201
202 pub fn unknown_stats_group_by(axis: &str) -> Self {
203 Self::UnknownStatsGroupBy {
204 axis: axis.to_string(),
205 }
206 }
207
208 pub(crate) fn local_filter_parent_required(
209 axis: &'static str,
210 value: i64,
211 parents: &'static str,
212 suggestion: &'static str,
213 ) -> Self {
214 Self::LocalFilterParentRequired {
215 axis,
216 value,
217 parents,
218 suggestion,
219 }
220 }
221
222 pub(crate) fn stats_group_by_parent_required(
223 axis: &'static str,
224 parents: &'static str,
225 group_by: &'static str,
226 ) -> Self {
227 Self::StatsGroupByParentRequired {
228 axis,
229 parents,
230 group_by,
231 }
232 }
233
234 pub fn python_stack_missing(axis: &str) -> Self {
235 Self::PythonStackMissing {
236 axis: axis.to_string(),
237 }
238 }
239
240 pub fn unknown_slices_group_by(axis: &str) -> Self {
241 Self::UnknownSlicesGroupBy {
242 axis: axis.to_string(),
243 }
244 }
245
246 pub(crate) fn from_name_match(source: NameMatchError) -> Self {
247 match source {
248 NameMatchError::Conflict => Self::MutuallyExclusiveNameFilters,
249 NameMatchError::InvalidGlob { pattern, source } => {
250 Self::InvalidNameGlob { pattern, source }
251 }
252 NameMatchError::InvalidRegex { pattern, source } => {
253 Self::InvalidNameRegex { pattern, source }
254 }
255 }
256 }
257
258 pub(crate) fn sql_open(source: duckdb::Error) -> Self {
259 Self::SqlOpen { source }
260 }
261
262 pub(crate) fn sql(
263 area: &'static str,
264 phase: SqlPhase,
265 label: impl Into<String>,
266 source: duckdb::Error,
267 ) -> Self {
268 Self::Sql {
269 area,
270 phase,
271 label: label.into(),
272 source,
273 }
274 }
275
276 pub(crate) fn sql_prepare(
277 area: &'static str,
278 label: impl Into<String>,
279 source: duckdb::Error,
280 ) -> Self {
281 Self::sql(area, SqlPhase::Prepare, label, source)
282 }
283
284 pub(crate) fn sql_query(
285 area: &'static str,
286 label: impl Into<String>,
287 source: duckdb::Error,
288 ) -> Self {
289 Self::sql(area, SqlPhase::Query, label, source)
290 }
291
292 pub(crate) fn sql_read(
293 area: &'static str,
294 label: impl Into<String>,
295 source: duckdb::Error,
296 ) -> Self {
297 Self::sql(area, SqlPhase::Read, label, source)
298 }
299
300 pub fn sql_parts(&self) -> Option<(&'static str, SqlPhase, &str)> {
301 match self {
302 Self::Sql {
303 area, phase, label, ..
304 } => Some((*area, *phase, label.as_str())),
305 _ => None,
306 }
307 }
308
309 pub(crate) fn search_count_overflow(source: std::num::TryFromIntError) -> Self {
310 Self::SearchCountOverflow { source }
311 }
312
313 pub(crate) fn stats_count_overflow(source: std::num::TryFromIntError) -> Self {
314 Self::StatsCountOverflow { source }
315 }
316
317 pub(crate) fn timeline_count_overflow(source: std::num::TryFromIntError) -> Self {
318 Self::TimelineCountOverflow { source }
319 }
320
321 pub(crate) fn slices_count_overflow(source: std::num::TryFromIntError) -> Self {
322 Self::SlicesCountOverflow { source }
323 }
324
325 pub(crate) fn collectives_count_overflow(source: std::num::TryFromIntError) -> Self {
326 Self::CollectivesCountOverflow { source }
327 }
328
329 pub(crate) fn inspect_json_decode(field: &str, source: serde_json::Error) -> Self {
330 Self::InspectJsonDecode {
331 field: field.to_string(),
332 source,
333 }
334 }
335
336 pub(crate) fn inspect_trace_index_overflow(source: std::num::TryFromIntError) -> Self {
337 Self::InspectTraceIndexOverflow { source }
338 }
339}
340
341impl VeloqDiagnostic for PytorchQueryError {
342 fn code(&self) -> ErrorCode {
343 match self {
344 Self::UnknownType { .. } => ErrorCode::new("pytorch.query.unknown-type"),
345 Self::EmptyTypeSelection => ErrorCode::new("pytorch.query.empty-type-selection"),
346 Self::MultiRankRequiresScope => ErrorCode::new("pytorch.query.rank-scope-required"),
347 Self::LimitTooSmall => ErrorCode::new("pytorch.query.limit-too-small"),
348 Self::MutuallyExclusiveNameFilters => {
349 ErrorCode::new("pytorch.query.name-filter-conflict")
350 }
351 Self::InvalidNameGlob { .. } => ErrorCode::new("pytorch.query.invalid-name-glob"),
352 Self::InvalidNameRegex { .. } => ErrorCode::new("pytorch.query.invalid-name-regex"),
353 Self::MissingTimeBound => ErrorCode::new("pytorch.query.missing-time-bound"),
354 Self::InvalidFrom { .. } => ErrorCode::new("pytorch.query.invalid-from"),
355 Self::InvalidTo { .. } => ErrorCode::new("pytorch.query.invalid-to"),
356 Self::EmptyTimeWindow { .. } => ErrorCode::new("pytorch.query.empty-time-window"),
357 Self::IntervalTooSmall => ErrorCode::new("pytorch.query.interval-too-small"),
358 Self::UnknownStatsGroupBy { .. } => {
359 ErrorCode::new("pytorch.query.unknown-stats-group-by")
360 }
361 Self::LocalFilterParentRequired { .. } => {
362 ErrorCode::new("pytorch.query.local-filter-parent-required")
363 }
364 Self::StatsGroupByParentRequired { .. } => {
365 ErrorCode::new("pytorch.query.stats-group-by-parent-required")
366 }
367 Self::PythonStackMissing { .. } => ErrorCode::new("pytorch.query.python-stack-missing"),
368 Self::UnknownSlicesGroupBy { .. } => {
369 ErrorCode::new("pytorch.query.unknown-slices-group-by")
370 }
371 Self::SqlOpen { .. } => ErrorCode::new("pytorch.query.sql-open"),
372 Self::Sql { phase, .. } => phase.code(),
373 Self::SearchCountOverflow { .. } => {
374 ErrorCode::new("pytorch.query.search-count-overflow")
375 }
376 Self::StatsCountOverflow { .. } => ErrorCode::new("pytorch.query.stats-count-overflow"),
377 Self::TimelineCountOverflow { .. } => {
378 ErrorCode::new("pytorch.query.timeline-count-overflow")
379 }
380 Self::SlicesCountOverflow { .. } => {
381 ErrorCode::new("pytorch.query.slices-count-overflow")
382 }
383 Self::CollectivesCountOverflow { .. } => {
384 ErrorCode::new("pytorch.query.collectives-count-overflow")
385 }
386 Self::InspectJsonDecode { .. } => ErrorCode::new("pytorch.query.inspect-json-decode"),
387 Self::InspectTraceIndexOverflow { .. } => {
388 ErrorCode::new("pytorch.query.inspect-trace-index-overflow")
389 }
390 }
391 }
392
393 fn hint(&self) -> Option<Cow<'_, str>> {
394 match self {
395 Self::MultiRankRequiresScope => Some(Cow::Borrowed(
396 "Rerun with `--all-ranks` for an explicit aggregate, or `--rank 0` for one rank",
397 )),
398 _ => None,
399 }
400 }
401}