1use sqlparser::ast::{
48 BinaryOperator, Expr, GroupByExpr, ObjectName, Query, Select, SelectItem, SetExpr, Statement,
49 TableFactor, UnaryOperator, Value,
50};
51use sqlparser::dialect::GenericDialect;
52use sqlparser::parser::Parser;
53
/// Errors returned by the SELECT parser, the row evaluator and the format drivers.
#[derive(Debug, thiserror::Error)]
pub enum SelectError {
    /// The SQL text failed to parse, or did not contain exactly one statement.
    #[error("SQL parse error: {0}")]
    Parse(String),
    /// The SQL parsed, but uses a clause, operator or expression this engine rejects.
    #[error("unsupported SQL feature: {0}")]
    UnsupportedFeature(String),
    /// The object body could not be decoded (bad CSV / JSON / UTF-8). The
    /// output writers in this file also report their encoding failures here.
    #[error("input format error: {0}")]
    InputFormat(String),
    /// A row could not be evaluated, e.g. a projected column does not exist.
    #[error("row evaluation error: {0}")]
    RowEval(String),
}
69
/// Encoding of the object body that a query scans.
#[derive(Debug, Clone)]
pub enum SelectInputFormat {
    /// Delimited text. `has_header` tells the reader whether the first record
    /// carries column names; `delimiter` is the field separator.
    Csv { has_header: bool, delimiter: char },
    /// One JSON object per line.
    JsonLines,
}
79
/// Encoding used for the rows a query returns.
#[derive(Debug, Clone)]
pub enum SelectOutputFormat {
    /// One CSV record per matching row (written with a CRLF terminator).
    Csv,
    /// One JSON object per matching row, each followed by `\n`.
    Json,
}
85
/// A validated SELECT statement, as produced by `parse_select`.
#[derive(Debug, Clone)]
pub struct SelectQuery {
    /// Items of the SELECT list, in source order.
    pub projection: Vec<SelectItem>,
    /// The WHERE predicate, if the statement had one.
    pub where_clause: Option<Expr>,
    /// Name the input is referred to by: the table alias when present,
    /// otherwise the dotted table name, or `"s3object"` when FROM is absent.
    pub from_alias: String,
}
101
102pub fn parse_select(sql: &str) -> Result<SelectQuery, SelectError> {
107 let dialect = GenericDialect {};
108 let mut statements =
109 Parser::parse_sql(&dialect, sql).map_err(|e| SelectError::Parse(e.to_string()))?;
110 if statements.len() != 1 {
111 return Err(SelectError::Parse(format!(
112 "expected exactly one statement, got {}",
113 statements.len()
114 )));
115 }
116 let stmt = statements.pop().expect("len == 1");
117 let query = match stmt {
118 Statement::Query(q) => *q,
119 other => {
120 return Err(SelectError::UnsupportedFeature(format!(
121 "only SELECT statements are supported, got: {other:?}"
122 )));
123 }
124 };
125 let Query {
126 body,
127 order_by,
128 limit,
129 offset,
130 fetch,
131 locks,
132 with,
133 ..
134 } = query;
135 if with.is_some() {
136 return Err(SelectError::UnsupportedFeature("CTE / WITH".into()));
137 }
138 if order_by.is_some() {
139 return Err(SelectError::UnsupportedFeature("ORDER BY".into()));
140 }
141 if limit.is_some() {
142 return Err(SelectError::UnsupportedFeature("LIMIT".into()));
143 }
144 if offset.is_some() {
145 return Err(SelectError::UnsupportedFeature("OFFSET".into()));
146 }
147 if fetch.is_some() {
148 return Err(SelectError::UnsupportedFeature("FETCH".into()));
149 }
150 if !locks.is_empty() {
151 return Err(SelectError::UnsupportedFeature(
152 "FOR UPDATE / lock clauses".into(),
153 ));
154 }
155
156 let select = match *body {
157 SetExpr::Select(s) => *s,
158 SetExpr::Query(_) => {
159 return Err(SelectError::UnsupportedFeature("nested query".into()));
160 }
161 SetExpr::SetOperation { .. } => {
162 return Err(SelectError::UnsupportedFeature(
163 "set operation (UNION/INTERSECT/EXCEPT)".into(),
164 ));
165 }
166 other => {
167 return Err(SelectError::UnsupportedFeature(format!(
168 "unsupported SetExpr: {other:?}"
169 )));
170 }
171 };
172
173 let Select {
174 distinct,
175 top,
176 projection,
177 from,
178 selection,
179 group_by,
180 having,
181 named_window,
182 qualify,
183 cluster_by,
184 distribute_by,
185 sort_by,
186 prewhere,
187 connect_by,
188 ..
189 } = select;
190 if distinct.is_some() {
191 return Err(SelectError::UnsupportedFeature("DISTINCT".into()));
192 }
193 if top.is_some() {
194 return Err(SelectError::UnsupportedFeature("TOP".into()));
195 }
196 if having.is_some() {
197 return Err(SelectError::UnsupportedFeature("HAVING".into()));
198 }
199 if !named_window.is_empty() {
200 return Err(SelectError::UnsupportedFeature("WINDOW".into()));
201 }
202 if qualify.is_some() {
203 return Err(SelectError::UnsupportedFeature("QUALIFY".into()));
204 }
205 if !cluster_by.is_empty() || !distribute_by.is_empty() || !sort_by.is_empty() {
206 return Err(SelectError::UnsupportedFeature(
207 "CLUSTER BY / DISTRIBUTE BY / SORT BY".into(),
208 ));
209 }
210 if prewhere.is_some() {
211 return Err(SelectError::UnsupportedFeature("PREWHERE".into()));
212 }
213 if connect_by.is_some() {
214 return Err(SelectError::UnsupportedFeature("CONNECT BY".into()));
215 }
216 match group_by {
217 GroupByExpr::Expressions(ref exprs, ref mods) if exprs.is_empty() && mods.is_empty() => {}
218 _ => return Err(SelectError::UnsupportedFeature("GROUP BY".into())),
219 }
220
221 for item in &projection {
224 validate_projection_item(item)?;
225 }
226 if let Some(ref where_expr) = selection {
227 validate_where_expr(where_expr)?;
228 }
229
230 let from_alias = match from.as_slice() {
232 [twj] if twj.joins.is_empty() => match &twj.relation {
233 TableFactor::Table { name, alias, .. } => alias
234 .as_ref()
235 .map(|a| a.name.value.clone())
236 .unwrap_or_else(|| object_name_to_string(name)),
237 _ => {
238 return Err(SelectError::UnsupportedFeature(
239 "only `FROM s3object` (or aliased single table) is supported".into(),
240 ));
241 }
242 },
243 [] => "s3object".to_owned(),
244 _ => {
245 return Err(SelectError::UnsupportedFeature(
246 "JOIN / multiple FROM tables".into(),
247 ));
248 }
249 };
250
251 Ok(SelectQuery {
252 projection,
253 where_clause: selection,
254 from_alias,
255 })
256}
257
258fn object_name_to_string(name: &ObjectName) -> String {
259 name.0
260 .iter()
261 .map(|i| i.value.as_str())
262 .collect::<Vec<_>>()
263 .join(".")
264}
265
266fn validate_projection_item(item: &SelectItem) -> Result<(), SelectError> {
267 match item {
268 SelectItem::Wildcard(_) => Ok(()),
269 SelectItem::QualifiedWildcard(_, _) => Ok(()),
270 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
271 validate_simple_column_expr(e)
272 }
273 }
274}
275
276fn validate_simple_column_expr(expr: &Expr) -> Result<(), SelectError> {
277 match expr {
278 Expr::Identifier(_) | Expr::CompoundIdentifier(_) => Ok(()),
279 Expr::Function(_) => Err(SelectError::UnsupportedFeature(
280 "aggregate / scalar function in projection (only bare column references supported)"
281 .into(),
282 )),
283 Expr::Subquery(_) | Expr::Exists { .. } => Err(SelectError::UnsupportedFeature(
284 "subquery in projection".into(),
285 )),
286 _ => Err(SelectError::UnsupportedFeature(format!(
287 "unsupported projection expression: {expr}"
288 ))),
289 }
290}
291
292fn validate_where_expr(expr: &Expr) -> Result<(), SelectError> {
293 match expr {
294 Expr::Identifier(_) | Expr::CompoundIdentifier(_) | Expr::Value(_) => Ok(()),
295 Expr::Nested(inner) => validate_where_expr(inner),
296 Expr::UnaryOp { op, expr } => match op {
297 UnaryOperator::Not | UnaryOperator::Minus | UnaryOperator::Plus => {
298 validate_where_expr(expr)
299 }
300 other => Err(SelectError::UnsupportedFeature(format!(
301 "unsupported unary operator in WHERE: {other:?}"
302 ))),
303 },
304 Expr::BinaryOp { op, left, right } => match op {
305 BinaryOperator::Eq
306 | BinaryOperator::NotEq
307 | BinaryOperator::Lt
308 | BinaryOperator::LtEq
309 | BinaryOperator::Gt
310 | BinaryOperator::GtEq
311 | BinaryOperator::And
312 | BinaryOperator::Or => {
313 validate_where_expr(left)?;
314 validate_where_expr(right)
315 }
316 other => Err(SelectError::UnsupportedFeature(format!(
317 "unsupported binary operator in WHERE: {other:?}"
318 ))),
319 },
320 Expr::Like { expr, pattern, .. } => {
321 validate_where_expr(expr)?;
322 validate_where_expr(pattern)
323 }
324 Expr::IsNull(e) | Expr::IsNotNull(e) => validate_where_expr(e),
325 Expr::Function(_) => Err(SelectError::UnsupportedFeature(
326 "function call in WHERE".into(),
327 )),
328 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
329 Err(SelectError::UnsupportedFeature("subquery in WHERE".into()))
330 }
331 other => Err(SelectError::UnsupportedFeature(format!(
332 "unsupported WHERE expression: {other}"
333 ))),
334 }
335}
336
/// One input record, borrowed from the reader, plus the optional column names.
pub struct CsvRow<'a> {
    /// Field values in input order.
    pub fields: Vec<&'a str>,
    /// Column names, index-aligned with `fields`; `None` when the input has
    /// no header.
    pub headers: Option<&'a [String]>,
}
347
348impl CsvRow<'_> {
349 #[must_use]
353 pub fn get(&self, ident: &str) -> Option<&str> {
354 if let Some(stripped) = ident.strip_prefix('_')
355 && let Ok(n) = stripped.parse::<usize>()
356 && n >= 1
357 {
358 return self.fields.get(n - 1).copied();
359 }
360 if let Some(headers) = self.headers {
363 for (i, h) in headers.iter().enumerate() {
364 if h.eq_ignore_ascii_case(ident) {
365 return self.fields.get(i).copied();
366 }
367 }
368 }
369 None
370 }
371}
372
/// Runtime value produced while evaluating a WHERE expression.
#[derive(Debug, Clone)]
enum Lit<'a> {
    /// SQL NULL literal, or a column that is absent from the row.
    Null,
    /// Boolean literal or the result of a comparison / logical operator.
    Bool(bool),
    /// Numeric literal that fits in an `i64`.
    Int(i64),
    /// Numeric literal that only parses as `f64`.
    Float(f64),
    /// Text: borrowed for row fields, owned for string literals.
    Str(std::borrow::Cow<'a, str>),
}
387
388impl<'a> Lit<'a> {
389 fn from_str_value(s: &'a str) -> Lit<'a> {
390 Lit::Str(std::borrow::Cow::Borrowed(s))
391 }
392
393 fn truthy(&self) -> bool {
394 matches!(self, Lit::Bool(true))
395 }
396}
397
398pub fn evaluate_row(
403 query: &SelectQuery,
404 row: &CsvRow<'_>,
405) -> Result<Option<Vec<String>>, SelectError> {
406 if let Some(ref w) = query.where_clause {
407 let v = eval_expr(w, row)?;
408 if !v.truthy() {
409 return Ok(None);
410 }
411 }
412 let mut out = Vec::with_capacity(query.projection.len());
413 for item in &query.projection {
414 match item {
415 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
416 for f in &row.fields {
417 out.push((*f).to_owned());
418 }
419 }
420 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
421 let ident = expr_as_column(e)?;
422 let v = row
423 .get(&ident)
424 .ok_or_else(|| SelectError::RowEval(format!("column not found: {ident}")))?;
425 out.push(v.to_owned());
426 }
427 }
428 }
429 Ok(Some(out))
430}
431
432fn expr_as_column(expr: &Expr) -> Result<String, SelectError> {
433 match expr {
434 Expr::Identifier(i) => Ok(i.value.clone()),
435 Expr::CompoundIdentifier(parts) => parts
436 .last()
437 .map(|p| p.value.clone())
438 .ok_or_else(|| SelectError::RowEval("empty compound identifier".into())),
439 other => Err(SelectError::UnsupportedFeature(format!(
440 "non-column projection: {other}"
441 ))),
442 }
443}
444
445fn eval_expr<'a>(expr: &Expr, row: &'a CsvRow<'a>) -> Result<Lit<'a>, SelectError> {
446 match expr {
447 Expr::Nested(inner) => eval_expr(inner, row),
448 Expr::Identifier(i) => Ok(row.get(&i.value).map_or(Lit::Null, Lit::from_str_value)),
449 Expr::CompoundIdentifier(parts) => {
450 let last = parts
451 .last()
452 .ok_or_else(|| SelectError::RowEval("empty compound identifier".into()))?;
453 Ok(row.get(&last.value).map_or(Lit::Null, Lit::from_str_value))
454 }
455 Expr::Value(v) => value_to_lit(v),
456 Expr::UnaryOp { op, expr } => {
457 let v = eval_expr(expr, row)?;
458 match op {
459 UnaryOperator::Not => Ok(Lit::Bool(!v.truthy())),
460 UnaryOperator::Minus => match v {
461 Lit::Int(n) => Ok(Lit::Int(-n)),
462 Lit::Float(f) => Ok(Lit::Float(-f)),
463 other => Err(SelectError::RowEval(format!(
464 "cannot negate non-numeric value: {other:?}"
465 ))),
466 },
467 UnaryOperator::Plus => Ok(v),
468 other => Err(SelectError::UnsupportedFeature(format!(
469 "unsupported unary op: {other:?}"
470 ))),
471 }
472 }
473 Expr::BinaryOp { op, left, right } => {
474 let l = eval_expr(left, row)?;
475 let r = eval_expr(right, row)?;
476 eval_binary(op, &l, &r)
477 }
478 Expr::Like {
479 negated,
480 expr,
481 pattern,
482 escape_char,
483 } => {
484 if escape_char.is_some() {
485 return Err(SelectError::UnsupportedFeature("LIKE ESCAPE clause".into()));
486 }
487 let s_val = eval_expr(expr, row)?;
488 let p_val = eval_expr(pattern, row)?;
489 let s = lit_as_str(&s_val);
490 let p = lit_as_str(&p_val);
491 let m = like_match(s.as_ref(), p.as_ref());
492 Ok(Lit::Bool(if *negated { !m } else { m }))
493 }
494 Expr::IsNull(e) => Ok(Lit::Bool(matches!(eval_expr(e, row)?, Lit::Null))),
495 Expr::IsNotNull(e) => Ok(Lit::Bool(!matches!(eval_expr(e, row)?, Lit::Null))),
496 other => Err(SelectError::UnsupportedFeature(format!(
497 "unsupported expression in WHERE: {other}"
498 ))),
499 }
500}
501
502fn value_to_lit<'a>(v: &Value) -> Result<Lit<'a>, SelectError> {
503 match v {
504 Value::Number(s, _) => {
505 if let Ok(n) = s.parse::<i64>() {
506 Ok(Lit::Int(n))
507 } else if let Ok(f) = s.parse::<f64>() {
508 Ok(Lit::Float(f))
509 } else {
510 Err(SelectError::RowEval(format!("invalid number literal: {s}")))
511 }
512 }
513 Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => {
514 Ok(Lit::Str(std::borrow::Cow::Owned(s.clone())))
515 }
516 Value::Boolean(b) => Ok(Lit::Bool(*b)),
517 Value::Null => Ok(Lit::Null),
518 other => Err(SelectError::UnsupportedFeature(format!(
519 "literal kind not supported: {other:?}"
520 ))),
521 }
522}
523
524fn lit_as_str<'a>(v: &Lit<'a>) -> std::borrow::Cow<'a, str> {
525 match v {
526 Lit::Null => std::borrow::Cow::Borrowed(""),
527 Lit::Bool(b) => std::borrow::Cow::Owned(if *b { "true" } else { "false" }.into()),
528 Lit::Int(n) => std::borrow::Cow::Owned(n.to_string()),
529 Lit::Float(f) => std::borrow::Cow::Owned(f.to_string()),
530 Lit::Str(s) => s.clone(),
531 }
532}
533
534fn lit_as_f64(v: &Lit<'_>) -> Option<f64> {
535 match v {
536 Lit::Int(n) => Some(*n as f64),
537 Lit::Float(f) => Some(*f),
538 Lit::Str(s) => s.parse::<f64>().ok(),
539 Lit::Bool(_) | Lit::Null => None,
540 }
541}
542
543fn eval_binary<'a>(op: &BinaryOperator, l: &Lit<'a>, r: &Lit<'a>) -> Result<Lit<'a>, SelectError> {
544 use BinaryOperator::*;
545 match op {
546 And => Ok(Lit::Bool(l.truthy() && r.truthy())),
547 Or => Ok(Lit::Bool(l.truthy() || r.truthy())),
548 Eq | NotEq | Lt | LtEq | Gt | GtEq => {
549 if matches!(l, Lit::Null) || matches!(r, Lit::Null) {
553 return Ok(Lit::Bool(false));
554 }
555 let cmp = if let (Some(a), Some(b)) = (lit_as_f64(l), lit_as_f64(r)) {
558 a.partial_cmp(&b)
559 } else {
560 let a = lit_as_str(l);
561 let b = lit_as_str(r);
562 Some(a.as_ref().cmp(b.as_ref()))
563 };
564 let ord =
565 cmp.ok_or_else(|| SelectError::RowEval("incomparable values (NaN?)".into()))?;
566 let res = match op {
567 Eq => ord == std::cmp::Ordering::Equal,
568 NotEq => ord != std::cmp::Ordering::Equal,
569 Lt => ord == std::cmp::Ordering::Less,
570 LtEq => ord != std::cmp::Ordering::Greater,
571 Gt => ord == std::cmp::Ordering::Greater,
572 GtEq => ord != std::cmp::Ordering::Less,
573 _ => unreachable!("guarded by outer match"),
574 };
575 Ok(Lit::Bool(res))
576 }
577 other => Err(SelectError::UnsupportedFeature(format!(
578 "unsupported binary operator: {other:?}"
579 ))),
580 }
581}
582
/// SQL `LIKE` matching without escape support.
///
/// `%` matches any run of characters (including none) and `_` matches exactly
/// one character; everything else must match literally. Matching works on
/// `char`s and is case-sensitive.
///
/// The wildcard test comes before the literal comparison so that a `%` in
/// the pattern is still a wildcard when the text has a literal `%` at the
/// same position (e.g. `"%abc" LIKE '%c'`).
fn like_match(s: &str, pattern: &str) -> bool {
    let text: Vec<char> = s.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();
    let (mut ti, mut pi) = (0usize, 0usize);
    // (index of the most recent `%` in `pat`, text index it currently ends at)
    let mut backtrack: Option<(usize, usize)> = None;

    while ti < text.len() {
        match pat.get(pi) {
            Some('%') => {
                // Start by letting `%` match nothing; widen on mismatch.
                backtrack = Some((pi, ti));
                pi += 1;
            }
            Some(&c) if c == '_' || c == text[ti] => {
                ti += 1;
                pi += 1;
            }
            _ => match backtrack {
                // Let the last `%` swallow one more character and retry.
                Some((star_pi, star_ti)) => {
                    pi = star_pi + 1;
                    ti = star_ti + 1;
                    backtrack = Some((star_pi, star_ti + 1));
                }
                None => return false,
            },
        }
    }
    // Text is exhausted: only trailing `%` may remain in the pattern.
    pat[pi..].iter().all(|&c| c == '%')
}
612
613pub fn run_select_csv(
621 sql: &str,
622 body: &[u8],
623 input: SelectInputFormat,
624 output: SelectOutputFormat,
625) -> Result<Vec<u8>, SelectError> {
626 if matches!(output, SelectOutputFormat::Csv)
635 && let Some(filtered) = select_gpu(sql, body, &input)
636 {
637 return Ok(filtered);
638 }
639
640 let (has_header, delim) = match input {
641 SelectInputFormat::Csv {
642 has_header,
643 delimiter,
644 } => (has_header, delimiter),
645 SelectInputFormat::JsonLines => {
646 return Err(SelectError::InputFormat(
647 "run_select_csv called with JsonLines input — use run_select_jsonlines".into(),
648 ));
649 }
650 };
651 let query = parse_select(sql)?;
652
653 let mut rdr = csv::ReaderBuilder::new()
654 .has_headers(has_header)
655 .delimiter(delim as u8)
656 .flexible(true)
657 .from_reader(body);
658
659 let headers_owned: Option<Vec<String>> = if has_header {
660 let h = rdr
661 .headers()
662 .map_err(|e| SelectError::InputFormat(format!("CSV headers: {e}")))?
663 .iter()
664 .map(|s| s.to_owned())
665 .collect();
666 Some(h)
667 } else {
668 None
669 };
670 let header_slice: Option<&[String]> = headers_owned.as_deref();
671
672 let mut out = Vec::with_capacity(body.len() / 2);
673 for record in rdr.records() {
674 let record = record.map_err(|e| SelectError::InputFormat(format!("CSV record: {e}")))?;
675 let fields: Vec<&str> = record.iter().collect();
676 let row = CsvRow {
677 fields,
678 headers: header_slice,
679 };
680 if let Some(values) = evaluate_row(&query, &row)? {
681 write_output_row(&query, &values, &output, &mut out)?;
682 }
683 }
684 Ok(out)
685}
686
687pub fn run_select_jsonlines(
692 sql: &str,
693 body: &[u8],
694 output: SelectOutputFormat,
695) -> Result<Vec<u8>, SelectError> {
696 let query = parse_select(sql)?;
697 let text = std::str::from_utf8(body)
698 .map_err(|e| SelectError::InputFormat(format!("body is not valid UTF-8: {e}")))?;
699 let mut out = Vec::with_capacity(body.len() / 2);
700 for (lineno, line) in text.lines().enumerate() {
701 let line = line.trim();
702 if line.is_empty() {
703 continue;
704 }
705 let v: serde_json::Value = serde_json::from_str(line).map_err(|e| {
706 SelectError::InputFormat(format!("JSON parse on line {}: {e}", lineno + 1))
707 })?;
708 let obj = v.as_object().ok_or_else(|| {
709 SelectError::InputFormat(format!(
710 "JSON Lines requires top-level object, line {} was not an object",
711 lineno + 1
712 ))
713 })?;
714 let headers: Vec<String> = obj.keys().cloned().collect();
717 let raw_strs: Vec<String> = obj
718 .values()
719 .map(|jv| match jv {
720 serde_json::Value::String(s) => s.clone(),
721 other => other.to_string(),
722 })
723 .collect();
724 let fields: Vec<&str> = raw_strs.iter().map(|s| s.as_str()).collect();
725 let row = CsvRow {
726 fields,
727 headers: Some(headers.as_slice()),
728 };
729 if let Some(values) = evaluate_row(&query, &row)? {
730 write_jsonlines_row(&query, &headers, &values, &output, &mut out)?;
731 }
732 }
733 Ok(out)
734}
735
736fn write_output_row(
737 query: &SelectQuery,
738 values: &[String],
739 output: &SelectOutputFormat,
740 out: &mut Vec<u8>,
741) -> Result<(), SelectError> {
742 match output {
743 SelectOutputFormat::Csv => {
744 let mut wtr = csv::WriterBuilder::new()
745 .terminator(csv::Terminator::CRLF)
746 .from_writer(Vec::new());
747 wtr.write_record(values.iter().map(String::as_str))
748 .map_err(|e| SelectError::InputFormat(format!("CSV write: {e}")))?;
749 wtr.flush()
750 .map_err(|e| SelectError::InputFormat(format!("CSV flush: {e}")))?;
751 let inner = wtr
752 .into_inner()
753 .map_err(|e| SelectError::InputFormat(format!("CSV finish: {e}")))?;
754 out.extend_from_slice(&inner);
755 }
756 SelectOutputFormat::Json => {
757 let names = projection_names(query, values.len());
758 let mut map = serde_json::Map::with_capacity(values.len());
759 for (n, v) in names.iter().zip(values.iter()) {
760 map.insert(n.clone(), serde_json::Value::String(v.clone()));
761 }
762 let line = serde_json::to_string(&serde_json::Value::Object(map))
763 .map_err(|e| SelectError::InputFormat(format!("JSON serialize: {e}")))?;
764 out.extend_from_slice(line.as_bytes());
765 out.push(b'\n');
766 }
767 }
768 Ok(())
769}
770
771fn write_jsonlines_row(
772 query: &SelectQuery,
773 headers: &[String],
774 values: &[String],
775 output: &SelectOutputFormat,
776 out: &mut Vec<u8>,
777) -> Result<(), SelectError> {
778 match output {
779 SelectOutputFormat::Csv => write_output_row(query, values, output, out)?,
780 SelectOutputFormat::Json => {
781 let names = projection_names_with_headers(query, headers, values.len());
782 let mut map = serde_json::Map::with_capacity(values.len());
783 for (n, v) in names.iter().zip(values.iter()) {
784 map.insert(n.clone(), serde_json::Value::String(v.clone()));
785 }
786 let line = serde_json::to_string(&serde_json::Value::Object(map))
787 .map_err(|e| SelectError::InputFormat(format!("JSON serialize: {e}")))?;
788 out.extend_from_slice(line.as_bytes());
789 out.push(b'\n');
790 }
791 }
792 Ok(())
793}
794
795fn projection_names(query: &SelectQuery, fallback_len: usize) -> Vec<String> {
796 let mut names = Vec::with_capacity(fallback_len);
797 for (i, item) in query.projection.iter().enumerate() {
798 match item {
799 SelectItem::ExprWithAlias { alias, .. } => names.push(alias.value.clone()),
800 SelectItem::UnnamedExpr(e) => match expr_as_column(e) {
801 Ok(s) => names.push(s),
802 Err(_) => names.push(format!("_{}", i + 1)),
803 },
804 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
805 for j in names.len()..fallback_len {
806 names.push(format!("_{}", j + 1));
807 }
808 return names;
809 }
810 }
811 }
812 while names.len() < fallback_len {
813 let n = names.len();
814 names.push(format!("_{}", n + 1));
815 }
816 names
817}
818
819fn projection_names_with_headers(
820 query: &SelectQuery,
821 headers: &[String],
822 fallback_len: usize,
823) -> Vec<String> {
824 let mut names = Vec::with_capacity(fallback_len);
825 for (i, item) in query.projection.iter().enumerate() {
826 match item {
827 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
828 for h in headers {
829 names.push(h.clone());
830 }
831 while names.len() < fallback_len {
832 let n = names.len();
833 names.push(format!("_{}", n + 1));
834 }
835 return names;
836 }
837 SelectItem::ExprWithAlias { alias, .. } => names.push(alias.value.clone()),
838 SelectItem::UnnamedExpr(e) => match expr_as_column(e) {
839 Ok(s) => names.push(s),
840 Err(_) => names.push(format!("_{}", i + 1)),
841 },
842 }
843 }
844 while names.len() < fallback_len {
845 let n = names.len();
846 names.push(format!("_{}", n + 1));
847 }
848 names
849}
850
/// Builds the binary event frames (`Records`, `Stats`, `End`) of a select
/// response. The writer holds no state; each method returns a complete frame.
#[derive(Debug, Default)]
pub struct EventStreamWriter {}
863
864impl EventStreamWriter {
865 #[must_use]
866 pub fn new() -> Self {
867 Self {}
868 }
869
870 pub fn records(&mut self, payload: &[u8]) -> Vec<u8> {
874 build_frame(
875 &[
876 (":event-type", "Records"),
877 (":content-type", "application/octet-stream"),
878 (":message-type", "event"),
879 ],
880 Some(payload),
881 )
882 }
883
884 pub fn stats(&mut self, scanned: u64, processed: u64, returned: u64) -> Vec<u8> {
887 let xml = format!(
888 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\
889<Stats xmlns=\"\">\
890<BytesScanned>{scanned}</BytesScanned>\
891<BytesProcessed>{processed}</BytesProcessed>\
892<BytesReturned>{returned}</BytesReturned>\
893</Stats>"
894 );
895 build_frame(
896 &[
897 (":event-type", "Stats"),
898 (":content-type", "text/xml"),
899 (":message-type", "event"),
900 ],
901 Some(xml.as_bytes()),
902 )
903 }
904
905 pub fn end(&mut self) -> Vec<u8> {
908 build_frame(&[(":event-type", "End"), (":message-type", "event")], None)
909 }
910}
911
912fn build_frame(headers: &[(&str, &str)], payload: Option<&[u8]>) -> Vec<u8> {
913 let mut header_buf: Vec<u8> = Vec::new();
914 for (name, value) in headers {
915 let name_bytes = name.as_bytes();
916 let value_bytes = value.as_bytes();
917 debug_assert!(name_bytes.len() <= u8::MAX as usize, "header name too long");
918 debug_assert!(
919 value_bytes.len() <= u16::MAX as usize,
920 "header value too long"
921 );
922 header_buf.push(name_bytes.len() as u8);
923 header_buf.extend_from_slice(name_bytes);
924 header_buf.push(7); header_buf.extend_from_slice(&(value_bytes.len() as u16).to_be_bytes());
926 header_buf.extend_from_slice(value_bytes);
927 }
928 let payload_bytes = payload.unwrap_or(&[]);
929 let headers_len: u32 = header_buf.len() as u32;
930 let total_len: u32 = 12 + headers_len + payload_bytes.len() as u32 + 4;
931
932 let mut buf: Vec<u8> = Vec::with_capacity(total_len as usize);
933 buf.extend_from_slice(&total_len.to_be_bytes());
934 buf.extend_from_slice(&headers_len.to_be_bytes());
935 let prelude_crc = crc32fast::hash(&buf[..8]);
936 buf.extend_from_slice(&prelude_crc.to_be_bytes());
937 buf.extend_from_slice(&header_buf);
938 buf.extend_from_slice(payload_bytes);
939 let message_crc = crc32fast::hash(&buf[..buf.len()]);
940 buf.extend_from_slice(&message_crc.to_be_bytes());
941 buf
942}
943
/// Optional accelerated path for the simplest query shape:
/// `SELECT * FROM <single table> WHERE <column> <op> <literal>` over
/// comma-delimited CSV with a header row. Every function here answers `None`
/// for "not eligible" so that the caller falls back to the CPU evaluator.
#[cfg(feature = "nvcomp-gpu")]
mod gpu {
    use super::{Expr, SelectInputFormat, Value};
    use s4_codec::gpu_select::{CompareOp, GpuSelectKernel};
    use sqlparser::ast::{BinaryOperator, SelectItem};
    use std::sync::OnceLock;

    /// Process-wide kernel handle. `None` records a failed initialization so
    /// it is not retried.
    static KERNEL: OnceLock<Option<GpuSelectKernel>> = OnceLock::new();

    /// Returns the shared kernel, initializing it on first use.
    fn kernel() -> Option<&'static GpuSelectKernel> {
        KERNEL
            .get_or_init(|| match GpuSelectKernel::new() {
                Ok(k) => Some(k),
                Err(e) => {
                    tracing::debug!(
                        target: "s4_server::select::gpu",
                        ?e,
                        "GpuSelectKernel init failed; falling back to CPU permanently"
                    );
                    None
                }
            })
            .as_ref()
    }

    /// Tries to answer the query on the GPU.
    ///
    /// Returns `None` when the input is not header-bearing comma-delimited
    /// CSV, when the query is not a simple predicate, when the column cannot
    /// be resolved from the header line, when no kernel is available, or when
    /// the scan itself fails.
    pub(super) fn try_select_gpu(
        sql: &str,
        body: &[u8],
        input: &SelectInputFormat,
    ) -> Option<Vec<u8>> {
        let SelectInputFormat::Csv {
            has_header: true,
            delimiter: ',',
        } = input
        else {
            return None;
        };

        let (col_name, op, literal) = parse_simple_predicate(sql)?;

        let col_idx = resolve_header_column(body, &col_name)?;

        let kernel = kernel()?;
        match kernel.scan_csv(body, col_idx, op, literal.as_bytes()) {
            Ok(out) => Some(out),
            Err(e) => {
                tracing::debug!(
                    target: "s4_server::select::gpu",
                    ?e,
                    "GPU scan failed; falling back to CPU"
                );
                None
            }
        }
    }

    /// Recognizes `SELECT * ... WHERE <identifier> (= | <> | > | <) <literal>`
    /// and returns `(column, operator, literal text)`.
    ///
    /// The statement is validated with `parse_select` first, so a query the
    /// CPU path would reject (OFFSET, FETCH, DISTINCT, GROUP BY, HAVING,
    /// joins, ...) is never answered here with those clauses silently
    /// ignored; it falls through and gets the CPU path's error.
    fn parse_simple_predicate(sql: &str) -> Option<(String, CompareOp, String)> {
        let query = super::parse_select(sql).ok()?;

        // Only a lone unqualified `*` projection is eligible.
        if !matches!(query.projection.as_slice(), [SelectItem::Wildcard(_)]) {
            return None;
        }

        let Expr::BinaryOp { op, left, right } = query.where_clause? else {
            return None;
        };
        let Expr::Identifier(column) = *left else {
            return None;
        };
        let Expr::Value(literal) = *right else {
            return None;
        };
        let cmp_op = match op {
            BinaryOperator::Eq => CompareOp::Equal,
            BinaryOperator::NotEq => CompareOp::NotEqual,
            BinaryOperator::Gt => CompareOp::GreaterThan,
            BinaryOperator::Lt => CompareOp::LessThan,
            _ => return None,
        };

        Some((column.value, cmp_op, value_as_str(&literal)?))
    }

    /// Text of a string or number literal; `None` for every other kind.
    fn value_as_str(v: &Value) -> Option<String> {
        match v {
            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Some(s.clone()),
            Value::Number(s, _) => Some(s.clone()),
            _ => None,
        }
    }

    /// Finds the 0-based index of `col_name` in the first line of `body`
    /// (ASCII case-insensitive, first match wins).
    ///
    /// The header is split on raw commas, which is only correct for unquoted
    /// headers. A header line containing a double quote is therefore treated
    /// as not eligible: a quoted name with an embedded comma would otherwise
    /// shift every later column index and make the scan filter on the wrong
    /// column.
    fn resolve_header_column(body: &[u8], col_name: &str) -> Option<usize> {
        let nl = body.iter().position(|&b| b == b'\n').unwrap_or(body.len());
        let mut header = &body[..nl];
        if header.last() == Some(&b'\r') {
            header = &header[..header.len() - 1];
        }
        if header.contains(&b'"') {
            return None;
        }
        let header_str = std::str::from_utf8(header).ok()?;
        header_str
            .split(',')
            .position(|h| h.eq_ignore_ascii_case(col_name))
    }
}
1112
1113#[must_use]
1120pub fn select_gpu(sql: &str, body: &[u8], input: &SelectInputFormat) -> Option<Vec<u8>> {
1121 #[cfg(feature = "nvcomp-gpu")]
1122 {
1123 gpu::try_select_gpu(sql, body, input)
1124 }
1125 #[cfg(not(feature = "nvcomp-gpu"))]
1126 {
1127 let _ = (sql, body, input);
1128 None
1129 }
1130}
1131
// Unit tests for the SELECT parser, the row evaluator, the CSV / JSON Lines
// drivers, the event-stream framing and the accelerated-path fallbacks.
#[cfg(test)]
mod tests {
    use super::*;

    /// Comma-delimited CSV with a header row.
    fn csv_input() -> SelectInputFormat {
        SelectInputFormat::Csv {
            has_header: true,
            delimiter: ',',
        }
    }

    // A plain projection + WHERE parses and keeps both parts.
    #[test]
    fn parse_select_happy_path() {
        let q = parse_select("SELECT name, age FROM s3object WHERE age > 30").unwrap();
        assert_eq!(q.projection.len(), 2);
        assert!(q.where_clause.is_some());
        assert_eq!(q.from_alias.to_lowercase(), "s3object");
    }

    // This query must be rejected as unsupported, not as a parse error.
    #[test]
    fn parse_select_rejects_group_by() {
        let err = parse_select("SELECT name, COUNT(*) FROM s3object GROUP BY name").unwrap_err();
        match err {
            SelectError::UnsupportedFeature(_) => {}
            other => panic!("expected UnsupportedFeature, got {other:?}"),
        }
    }

    // Joins are outside the supported subset.
    #[test]
    fn parse_select_rejects_join() {
        let err =
            parse_select("SELECT a.x FROM s3object a JOIN other b ON a.id = b.id").unwrap_err();
        assert!(matches!(err, SelectError::UnsupportedFeature(_)));
    }

    // ORDER BY is outside the supported subset.
    #[test]
    fn parse_select_rejects_order_by() {
        let err = parse_select("SELECT name FROM s3object ORDER BY name").unwrap_err();
        assert!(matches!(err, SelectError::UnsupportedFeature(_)));
    }

    // String equality keeps the matching row and drops the other one.
    #[test]
    fn evaluate_row_eq_match() {
        let q = parse_select("SELECT name FROM s3object WHERE name = 'alice'").unwrap();
        let headers = vec!["name".to_owned(), "age".to_owned()];
        let row = CsvRow {
            fields: vec!["alice", "30"],
            headers: Some(&headers),
        };
        let r = evaluate_row(&q, &row).unwrap();
        assert_eq!(r, Some(vec!["alice".to_owned()]));

        let row2 = CsvRow {
            fields: vec!["bob", "30"],
            headers: Some(&headers),
        };
        assert_eq!(evaluate_row(&q, &row2).unwrap(), None);
    }

    // Numeric comparison: "50" > 100 is false, which a purely textual
    // comparison of "50" and "100" would get wrong.
    #[test]
    fn evaluate_row_int_compare() {
        let q = parse_select("SELECT age FROM s3object WHERE age > 100").unwrap();
        let headers = vec!["name".to_owned(), "age".to_owned()];
        let big = CsvRow {
            fields: vec!["x", "200"],
            headers: Some(&headers),
        };
        let small = CsvRow {
            fields: vec!["x", "50"],
            headers: Some(&headers),
        };
        assert!(evaluate_row(&q, &big).unwrap().is_some());
        assert!(evaluate_row(&q, &small).unwrap().is_none());
    }

    // A trailing `%` only matches values that start with the prefix.
    #[test]
    fn evaluate_row_like_pattern() {
        let q = parse_select("SELECT name FROM s3object WHERE name LIKE 'foo%'").unwrap();
        let headers = vec!["name".to_owned()];
        let yes = CsvRow {
            fields: vec!["foobar"],
            headers: Some(&headers),
        };
        let no = CsvRow {
            fields: vec!["xfoobar"],
            headers: Some(&headers),
        };
        assert!(evaluate_row(&q, &yes).unwrap().is_some());
        assert!(evaluate_row(&q, &no).unwrap().is_none());
    }

    // Full CSV -> CSV run; output records are CRLF-terminated.
    #[test]
    fn run_select_csv_end_to_end_filters_rows() {
        let body = b"name,age\nalice,30\nbob,40\ncarol,50\n";
        let out = run_select_csv(
            "SELECT name FROM s3object WHERE age > 35",
            body,
            csv_input(),
            SelectOutputFormat::Csv,
        )
        .unwrap();
        let s = std::str::from_utf8(&out).unwrap();
        let lines: Vec<&str> = s.split("\r\n").filter(|l| !l.is_empty()).collect();
        assert_eq!(lines, vec!["bob", "carol"]);
    }

    // Full JSON Lines -> JSON run; one output object per matching line.
    #[test]
    fn run_select_jsonlines_filter() {
        let body = b"{\"name\":\"alice\",\"age\":\"30\"}\n\
            {\"name\":\"bob\",\"age\":\"40\"}\n\
            {\"name\":\"carol\",\"age\":\"50\"}\n";
        let out = run_select_jsonlines(
            "SELECT name FROM s3object WHERE age > 35",
            body,
            SelectOutputFormat::Json,
        )
        .unwrap();
        let s = std::str::from_utf8(&out).unwrap();
        let lines: Vec<&str> = s.lines().filter(|l| !l.is_empty()).collect();
        assert_eq!(lines.len(), 2);
        assert!(lines[0].contains("bob"));
        assert!(lines[1].contains("carol"));
    }

    // Header-less CSV is addressed with 1-based `_N` references.
    #[test]
    fn positional_column_ref() {
        let body = b"alice,30\nbob,40\n";
        let out = run_select_csv(
            "SELECT _1 FROM s3object WHERE _2 > 35",
            body,
            SelectInputFormat::Csv {
                has_header: false,
                delimiter: ',',
            },
            SelectOutputFormat::Csv,
        )
        .unwrap();
        let s = std::str::from_utf8(&out).unwrap();
        let lines: Vec<&str> = s.split("\r\n").filter(|l| !l.is_empty()).collect();
        assert_eq!(lines, vec!["bob"]);
    }

    // Parenthesised AND combined with OR.
    #[test]
    fn and_or_combination() {
        let body = b"name,age,city\n\
            alice,30,nyc\n\
            bob,40,nyc\n\
            carol,50,sf\n\
            dan,25,sf\n";
        let out = run_select_csv(
            "SELECT name FROM s3object WHERE (city = 'nyc' AND age > 35) OR name = 'dan'",
            body,
            csv_input(),
            SelectOutputFormat::Csv,
        )
        .unwrap();
        let s = std::str::from_utf8(&out).unwrap();
        let mut lines: Vec<&str> = s.split("\r\n").filter(|l| !l.is_empty()).collect();
        lines.sort_unstable();
        assert_eq!(lines, vec!["bob", "dan"]);
    }

    // Checks the frame layout: lengths, both CRCs, header text and payload.
    #[test]
    fn event_stream_records_frame_format() {
        let mut w = EventStreamWriter::new();
        let frame = w.records(b"hello,world\r\n");
        let total = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
        assert_eq!(total, frame.len());
        let headers_len = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]) as usize;
        let prelude_crc = u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]);
        assert_eq!(prelude_crc, crc32fast::hash(&frame[..8]));
        let msg_crc = u32::from_be_bytes([
            frame[total - 4],
            frame[total - 3],
            frame[total - 2],
            frame[total - 1],
        ]);
        assert_eq!(msg_crc, crc32fast::hash(&frame[..total - 4]));
        let hdr_region = &frame[12..12 + headers_len];
        let s = String::from_utf8_lossy(hdr_region);
        assert!(s.contains(":event-type"));
        assert!(s.contains("Records"));
        let payload = &frame[12 + headers_len..total - 4];
        assert_eq!(payload, b"hello,world\r\n");
    }

    // The End frame is prelude + headers + message CRC, with no payload.
    #[test]
    fn event_stream_end_frame_no_payload() {
        let mut w = EventStreamWriter::new();
        let frame = w.end();
        let total = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
        let headers_len = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]) as usize;
        assert_eq!(total - 4 - 12 - headers_len, 0);
        let s = String::from_utf8_lossy(&frame[12..12 + headers_len]);
        assert!(s.contains("End"));
    }

    // The Stats payload carries the three counters as XML elements.
    #[test]
    fn event_stream_stats_xml_payload() {
        let mut w = EventStreamWriter::new();
        let frame = w.stats(1024, 800, 64);
        let total = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
        let headers_len = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]) as usize;
        let payload = &frame[12 + headers_len..total - 4];
        let xml = std::str::from_utf8(payload).unwrap();
        assert!(xml.contains("<BytesScanned>1024</BytesScanned>"));
        assert!(xml.contains("<BytesProcessed>800</BytesProcessed>"));
        assert!(xml.contains("<BytesReturned>64</BytesReturned>"));
    }

    // Without a WHERE clause the accelerated path must decline.
    #[test]
    fn gpu_no_where_falls_through() {
        let v = select_gpu(
            "SELECT * FROM s3object",
            b"name,age\nalice,30\n",
            &csv_input(),
        );
        assert!(
            v.is_none(),
            "queries without a WHERE predicate must fall through to CPU"
        );
    }

    // JSON Lines input is never eligible for the accelerated path.
    #[test]
    fn gpu_jsonlines_falls_through() {
        let v = select_gpu(
            "SELECT * FROM s3object WHERE country = 'Japan'",
            b"{\"country\":\"Japan\"}\n",
            &SelectInputFormat::JsonLines,
        );
        assert!(
            v.is_none(),
            "JSON Lines input must always fall through to CPU"
        );
    }

    // Direct checks of the LIKE matcher: `%`, `_` and the empty cases.
    #[test]
    fn like_match_basics() {
        assert!(like_match("foobar", "foo%"));
        assert!(!like_match("xfoobar", "foo%"));
        assert!(like_match("abc", "_b_"));
        assert!(like_match("anything", "%"));
        assert!(like_match("", ""));
        assert!(!like_match("a", ""));
    }
}