1use sqlparser::ast::{
48 BinaryOperator, Expr, GroupByExpr, ObjectName, Query, Select, SelectItem, SetExpr,
49 Statement, TableFactor, UnaryOperator, Value,
50};
51use sqlparser::dialect::GenericDialect;
52use sqlparser::parser::Parser;
53
54#[derive(Debug, thiserror::Error)]
59pub enum SelectError {
60 #[error("SQL parse error: {0}")]
61 Parse(String),
62 #[error("unsupported SQL feature: {0}")]
63 UnsupportedFeature(String),
64 #[error("input format error: {0}")]
65 InputFormat(String),
66 #[error("row evaluation error: {0}")]
67 RowEval(String),
68}
69
/// Format of the object body that a query runs against.
#[derive(Clone, Debug)]
pub enum SelectInputFormat {
    /// Delimited text.
    Csv {
        /// Whether the first record is a header row naming the columns.
        has_header: bool,
        /// Field separator character.
        delimiter: char,
    },
    /// One JSON object per line.
    JsonLines,
}
79
/// Format in which matching rows are serialized.
#[derive(Clone, Debug)]
pub enum SelectOutputFormat {
    /// One CSV record per matching row.
    Csv,
    /// One JSON object per matching row, newline-terminated.
    Json,
}
85
86#[derive(Debug, Clone)]
91pub struct SelectQuery {
92 pub projection: Vec<SelectItem>,
95 pub where_clause: Option<Expr>,
96 pub from_alias: String,
100}
101
102pub fn parse_select(sql: &str) -> Result<SelectQuery, SelectError> {
107 let dialect = GenericDialect {};
108 let mut statements = Parser::parse_sql(&dialect, sql)
109 .map_err(|e| SelectError::Parse(e.to_string()))?;
110 if statements.len() != 1 {
111 return Err(SelectError::Parse(format!(
112 "expected exactly one statement, got {}",
113 statements.len()
114 )));
115 }
116 let stmt = statements.pop().expect("len == 1");
117 let query = match stmt {
118 Statement::Query(q) => *q,
119 other => {
120 return Err(SelectError::UnsupportedFeature(format!(
121 "only SELECT statements are supported, got: {other:?}"
122 )));
123 }
124 };
125 let Query {
126 body, order_by, limit, offset, fetch, locks, with, ..
127 } = query;
128 if with.is_some() {
129 return Err(SelectError::UnsupportedFeature("CTE / WITH".into()));
130 }
131 if order_by.is_some() {
132 return Err(SelectError::UnsupportedFeature("ORDER BY".into()));
133 }
134 if limit.is_some() {
135 return Err(SelectError::UnsupportedFeature("LIMIT".into()));
136 }
137 if offset.is_some() {
138 return Err(SelectError::UnsupportedFeature("OFFSET".into()));
139 }
140 if fetch.is_some() {
141 return Err(SelectError::UnsupportedFeature("FETCH".into()));
142 }
143 if !locks.is_empty() {
144 return Err(SelectError::UnsupportedFeature("FOR UPDATE / lock clauses".into()));
145 }
146
147 let select = match *body {
148 SetExpr::Select(s) => *s,
149 SetExpr::Query(_) => {
150 return Err(SelectError::UnsupportedFeature("nested query".into()));
151 }
152 SetExpr::SetOperation { .. } => {
153 return Err(SelectError::UnsupportedFeature("set operation (UNION/INTERSECT/EXCEPT)".into()));
154 }
155 other => {
156 return Err(SelectError::UnsupportedFeature(format!("unsupported SetExpr: {other:?}")));
157 }
158 };
159
160 let Select {
161 distinct,
162 top,
163 projection,
164 from,
165 selection,
166 group_by,
167 having,
168 named_window,
169 qualify,
170 cluster_by,
171 distribute_by,
172 sort_by,
173 prewhere,
174 connect_by,
175 ..
176 } = select;
177 if distinct.is_some() {
178 return Err(SelectError::UnsupportedFeature("DISTINCT".into()));
179 }
180 if top.is_some() {
181 return Err(SelectError::UnsupportedFeature("TOP".into()));
182 }
183 if having.is_some() {
184 return Err(SelectError::UnsupportedFeature("HAVING".into()));
185 }
186 if !named_window.is_empty() {
187 return Err(SelectError::UnsupportedFeature("WINDOW".into()));
188 }
189 if qualify.is_some() {
190 return Err(SelectError::UnsupportedFeature("QUALIFY".into()));
191 }
192 if !cluster_by.is_empty() || !distribute_by.is_empty() || !sort_by.is_empty() {
193 return Err(SelectError::UnsupportedFeature(
194 "CLUSTER BY / DISTRIBUTE BY / SORT BY".into(),
195 ));
196 }
197 if prewhere.is_some() {
198 return Err(SelectError::UnsupportedFeature("PREWHERE".into()));
199 }
200 if connect_by.is_some() {
201 return Err(SelectError::UnsupportedFeature("CONNECT BY".into()));
202 }
203 match group_by {
204 GroupByExpr::Expressions(ref exprs, ref mods) if exprs.is_empty() && mods.is_empty() => {}
205 _ => return Err(SelectError::UnsupportedFeature("GROUP BY".into())),
206 }
207
208 for item in &projection {
211 validate_projection_item(item)?;
212 }
213 if let Some(ref where_expr) = selection {
214 validate_where_expr(where_expr)?;
215 }
216
217 let from_alias = match from.as_slice() {
219 [twj] if twj.joins.is_empty() => match &twj.relation {
220 TableFactor::Table { name, alias, .. } => alias
221 .as_ref()
222 .map(|a| a.name.value.clone())
223 .unwrap_or_else(|| object_name_to_string(name)),
224 _ => {
225 return Err(SelectError::UnsupportedFeature(
226 "only `FROM s3object` (or aliased single table) is supported".into(),
227 ));
228 }
229 },
230 [] => "s3object".to_owned(),
231 _ => return Err(SelectError::UnsupportedFeature("JOIN / multiple FROM tables".into())),
232 };
233
234 Ok(SelectQuery {
235 projection,
236 where_clause: selection,
237 from_alias,
238 })
239}
240
241fn object_name_to_string(name: &ObjectName) -> String {
242 name.0
243 .iter()
244 .map(|i| i.value.as_str())
245 .collect::<Vec<_>>()
246 .join(".")
247}
248
249fn validate_projection_item(item: &SelectItem) -> Result<(), SelectError> {
250 match item {
251 SelectItem::Wildcard(_) => Ok(()),
252 SelectItem::QualifiedWildcard(_, _) => Ok(()),
253 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
254 validate_simple_column_expr(e)
255 }
256 }
257}
258
259fn validate_simple_column_expr(expr: &Expr) -> Result<(), SelectError> {
260 match expr {
261 Expr::Identifier(_) | Expr::CompoundIdentifier(_) => Ok(()),
262 Expr::Function(_) => Err(SelectError::UnsupportedFeature(
263 "aggregate / scalar function in projection (only bare column references supported)".into(),
264 )),
265 Expr::Subquery(_) | Expr::Exists { .. } => {
266 Err(SelectError::UnsupportedFeature("subquery in projection".into()))
267 }
268 _ => Err(SelectError::UnsupportedFeature(format!(
269 "unsupported projection expression: {expr}"
270 ))),
271 }
272}
273
274fn validate_where_expr(expr: &Expr) -> Result<(), SelectError> {
275 match expr {
276 Expr::Identifier(_) | Expr::CompoundIdentifier(_) | Expr::Value(_) => Ok(()),
277 Expr::Nested(inner) => validate_where_expr(inner),
278 Expr::UnaryOp { op, expr } => match op {
279 UnaryOperator::Not | UnaryOperator::Minus | UnaryOperator::Plus => {
280 validate_where_expr(expr)
281 }
282 other => Err(SelectError::UnsupportedFeature(format!(
283 "unsupported unary operator in WHERE: {other:?}"
284 ))),
285 },
286 Expr::BinaryOp { op, left, right } => match op {
287 BinaryOperator::Eq
288 | BinaryOperator::NotEq
289 | BinaryOperator::Lt
290 | BinaryOperator::LtEq
291 | BinaryOperator::Gt
292 | BinaryOperator::GtEq
293 | BinaryOperator::And
294 | BinaryOperator::Or => {
295 validate_where_expr(left)?;
296 validate_where_expr(right)
297 }
298 other => Err(SelectError::UnsupportedFeature(format!(
299 "unsupported binary operator in WHERE: {other:?}"
300 ))),
301 },
302 Expr::Like { expr, pattern, .. } => {
303 validate_where_expr(expr)?;
304 validate_where_expr(pattern)
305 }
306 Expr::IsNull(e) | Expr::IsNotNull(e) => validate_where_expr(e),
307 Expr::Function(_) => Err(SelectError::UnsupportedFeature(
308 "function call in WHERE".into(),
309 )),
310 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
311 Err(SelectError::UnsupportedFeature("subquery in WHERE".into()))
312 }
313 other => Err(SelectError::UnsupportedFeature(format!(
314 "unsupported WHERE expression: {other}"
315 ))),
316 }
317}
318
/// A single input row: its field values plus, optionally, the column names
/// they correspond to.
pub struct CsvRow<'a> {
    /// Field values in column order.
    pub fields: Vec<&'a str>,
    /// Column names in column order, when the input provides them.
    pub headers: Option<&'a [String]>,
}

impl CsvRow<'_> {
    /// Looks up a field by positional reference (`_1`, `_2`, …, 1-based) or
    /// by header name (ASCII case-insensitive).
    ///
    /// An identifier of the form `_N` with `N >= 1` is always treated as
    /// positional and never falls back to a header lookup. Returns `None`
    /// when the position is out of range, no header matches, or the row is
    /// shorter than the matching header's index.
    #[must_use]
    pub fn get(&self, ident: &str) -> Option<&str> {
        let position = ident
            .strip_prefix('_')
            .and_then(|digits| digits.parse::<usize>().ok())
            .filter(|&n| n >= 1);
        if let Some(n) = position {
            return self.fields.get(n - 1).copied();
        }

        // Only the first header that matches is considered.
        let index = self
            .headers?
            .iter()
            .position(|header| header.eq_ignore_ascii_case(ident))?;
        self.fields.get(index).copied()
    }
}
354
/// A value produced while evaluating a WHERE expression.
#[derive(Clone, Debug)]
enum Lit<'a> {
    /// SQL NULL, or a column that is absent from the row.
    Null,
    /// Result of a comparison or logical operator, or a boolean literal.
    Bool(bool),
    /// Integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// Text: borrowed from the row for column values, owned for SQL literals.
    Str(std::borrow::Cow<'a, str>),
}

impl<'a> Lit<'a> {
    /// Wraps a borrowed string without copying it.
    fn from_str_value(s: &'a str) -> Lit<'a> {
        Self::Str(s.into())
    }

    /// Only `Bool(true)` counts as true; every other value is false.
    fn truthy(&self) -> bool {
        match self {
            Lit::Bool(flag) => *flag,
            _ => false,
        }
    }
}
379
380pub fn evaluate_row(
385 query: &SelectQuery,
386 row: &CsvRow<'_>,
387) -> Result<Option<Vec<String>>, SelectError> {
388 if let Some(ref w) = query.where_clause {
389 let v = eval_expr(w, row)?;
390 if !v.truthy() {
391 return Ok(None);
392 }
393 }
394 let mut out = Vec::with_capacity(query.projection.len());
395 for item in &query.projection {
396 match item {
397 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
398 for f in &row.fields {
399 out.push((*f).to_owned());
400 }
401 }
402 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
403 let ident = expr_as_column(e)?;
404 let v = row.get(&ident).ok_or_else(|| {
405 SelectError::RowEval(format!("column not found: {ident}"))
406 })?;
407 out.push(v.to_owned());
408 }
409 }
410 }
411 Ok(Some(out))
412}
413
414fn expr_as_column(expr: &Expr) -> Result<String, SelectError> {
415 match expr {
416 Expr::Identifier(i) => Ok(i.value.clone()),
417 Expr::CompoundIdentifier(parts) => parts
418 .last()
419 .map(|p| p.value.clone())
420 .ok_or_else(|| SelectError::RowEval("empty compound identifier".into())),
421 other => Err(SelectError::UnsupportedFeature(format!(
422 "non-column projection: {other}"
423 ))),
424 }
425}
426
427fn eval_expr<'a>(expr: &Expr, row: &'a CsvRow<'a>) -> Result<Lit<'a>, SelectError> {
428 match expr {
429 Expr::Nested(inner) => eval_expr(inner, row),
430 Expr::Identifier(i) => Ok(row
431 .get(&i.value)
432 .map_or(Lit::Null, Lit::from_str_value)),
433 Expr::CompoundIdentifier(parts) => {
434 let last = parts
435 .last()
436 .ok_or_else(|| SelectError::RowEval("empty compound identifier".into()))?;
437 Ok(row
438 .get(&last.value)
439 .map_or(Lit::Null, Lit::from_str_value))
440 }
441 Expr::Value(v) => value_to_lit(v),
442 Expr::UnaryOp { op, expr } => {
443 let v = eval_expr(expr, row)?;
444 match op {
445 UnaryOperator::Not => Ok(Lit::Bool(!v.truthy())),
446 UnaryOperator::Minus => match v {
447 Lit::Int(n) => Ok(Lit::Int(-n)),
448 Lit::Float(f) => Ok(Lit::Float(-f)),
449 other => Err(SelectError::RowEval(format!(
450 "cannot negate non-numeric value: {other:?}"
451 ))),
452 },
453 UnaryOperator::Plus => Ok(v),
454 other => Err(SelectError::UnsupportedFeature(format!(
455 "unsupported unary op: {other:?}"
456 ))),
457 }
458 }
459 Expr::BinaryOp { op, left, right } => {
460 let l = eval_expr(left, row)?;
461 let r = eval_expr(right, row)?;
462 eval_binary(op, &l, &r)
463 }
464 Expr::Like { negated, expr, pattern, escape_char } => {
465 if escape_char.is_some() {
466 return Err(SelectError::UnsupportedFeature(
467 "LIKE ESCAPE clause".into(),
468 ));
469 }
470 let s_val = eval_expr(expr, row)?;
471 let p_val = eval_expr(pattern, row)?;
472 let s = lit_as_str(&s_val);
473 let p = lit_as_str(&p_val);
474 let m = like_match(s.as_ref(), p.as_ref());
475 Ok(Lit::Bool(if *negated { !m } else { m }))
476 }
477 Expr::IsNull(e) => Ok(Lit::Bool(matches!(eval_expr(e, row)?, Lit::Null))),
478 Expr::IsNotNull(e) => Ok(Lit::Bool(!matches!(eval_expr(e, row)?, Lit::Null))),
479 other => Err(SelectError::UnsupportedFeature(format!(
480 "unsupported expression in WHERE: {other}"
481 ))),
482 }
483}
484
485fn value_to_lit<'a>(v: &Value) -> Result<Lit<'a>, SelectError> {
486 match v {
487 Value::Number(s, _) => {
488 if let Ok(n) = s.parse::<i64>() {
489 Ok(Lit::Int(n))
490 } else if let Ok(f) = s.parse::<f64>() {
491 Ok(Lit::Float(f))
492 } else {
493 Err(SelectError::RowEval(format!("invalid number literal: {s}")))
494 }
495 }
496 Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => {
497 Ok(Lit::Str(std::borrow::Cow::Owned(s.clone())))
498 }
499 Value::Boolean(b) => Ok(Lit::Bool(*b)),
500 Value::Null => Ok(Lit::Null),
501 other => Err(SelectError::UnsupportedFeature(format!(
502 "literal kind not supported: {other:?}"
503 ))),
504 }
505}
506
507fn lit_as_str<'a>(v: &Lit<'a>) -> std::borrow::Cow<'a, str> {
508 match v {
509 Lit::Null => std::borrow::Cow::Borrowed(""),
510 Lit::Bool(b) => std::borrow::Cow::Owned(if *b { "true" } else { "false" }.into()),
511 Lit::Int(n) => std::borrow::Cow::Owned(n.to_string()),
512 Lit::Float(f) => std::borrow::Cow::Owned(f.to_string()),
513 Lit::Str(s) => s.clone(),
514 }
515}
516
517fn lit_as_f64(v: &Lit<'_>) -> Option<f64> {
518 match v {
519 Lit::Int(n) => Some(*n as f64),
520 Lit::Float(f) => Some(*f),
521 Lit::Str(s) => s.parse::<f64>().ok(),
522 Lit::Bool(_) | Lit::Null => None,
523 }
524}
525
526fn eval_binary<'a>(
527 op: &BinaryOperator,
528 l: &Lit<'a>,
529 r: &Lit<'a>,
530) -> Result<Lit<'a>, SelectError> {
531 use BinaryOperator::*;
532 match op {
533 And => Ok(Lit::Bool(l.truthy() && r.truthy())),
534 Or => Ok(Lit::Bool(l.truthy() || r.truthy())),
535 Eq | NotEq | Lt | LtEq | Gt | GtEq => {
536 if matches!(l, Lit::Null) || matches!(r, Lit::Null) {
540 return Ok(Lit::Bool(false));
541 }
542 let cmp = if let (Some(a), Some(b)) = (lit_as_f64(l), lit_as_f64(r)) {
545 a.partial_cmp(&b)
546 } else {
547 let a = lit_as_str(l);
548 let b = lit_as_str(r);
549 Some(a.as_ref().cmp(b.as_ref()))
550 };
551 let ord =
552 cmp.ok_or_else(|| SelectError::RowEval("incomparable values (NaN?)".into()))?;
553 let res = match op {
554 Eq => ord == std::cmp::Ordering::Equal,
555 NotEq => ord != std::cmp::Ordering::Equal,
556 Lt => ord == std::cmp::Ordering::Less,
557 LtEq => ord != std::cmp::Ordering::Greater,
558 Gt => ord == std::cmp::Ordering::Greater,
559 GtEq => ord != std::cmp::Ordering::Less,
560 _ => unreachable!("guarded by outer match"),
561 };
562 Ok(Lit::Bool(res))
563 }
564 other => Err(SelectError::UnsupportedFeature(format!(
565 "unsupported binary operator: {other:?}"
566 ))),
567 }
568}
569
/// SQL `LIKE` matching without an escape character.
///
/// `%` matches any run of characters (including none) and `_` matches exactly
/// one character; everything else must match literally. Matching is done on
/// `char`s and is case-sensitive.
///
/// `%` in the pattern is always a wildcard, even when the text has a literal
/// `%` at the same position. Testing literal equality first would consume the
/// wildcard as a single character and make `like_match("%abc", "%")` fail.
fn like_match(s: &str, pattern: &str) -> bool {
    let text: Vec<char> = s.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();
    let (mut ti, mut pi) = (0usize, 0usize);
    // After seeing a `%`: (pattern index just past it, text index where the
    // characters it absorbs currently end). Used to retry with the `%`
    // absorbing one more character whenever a later match fails.
    let mut backtrack: Option<(usize, usize)> = None;

    while ti < text.len() {
        match pat.get(pi) {
            // Wildcard first, so a literal '%' in the text cannot shadow it.
            Some('%') => {
                backtrack = Some((pi + 1, ti));
                pi += 1;
            }
            Some(&c) if c == '_' || c == text[ti] => {
                ti += 1;
                pi += 1;
            }
            _ => match backtrack {
                Some((resume_pi, absorbed_to)) => {
                    let next = absorbed_to + 1;
                    backtrack = Some((resume_pi, next));
                    pi = resume_pi;
                    ti = next;
                }
                None => return false,
            },
        }
    }

    // Text exhausted: whatever pattern remains may only be `%`.
    pat[pi..].iter().all(|&c| c == '%')
}
599
600pub fn run_select_csv(
608 sql: &str,
609 body: &[u8],
610 input: SelectInputFormat,
611 output: SelectOutputFormat,
612) -> Result<Vec<u8>, SelectError> {
613 if matches!(output, SelectOutputFormat::Csv)
622 && let Some(filtered) = select_gpu(sql, body, &input)
623 {
624 return Ok(filtered);
625 }
626
627 let (has_header, delim) = match input {
628 SelectInputFormat::Csv { has_header, delimiter } => (has_header, delimiter),
629 SelectInputFormat::JsonLines => {
630 return Err(SelectError::InputFormat(
631 "run_select_csv called with JsonLines input — use run_select_jsonlines".into(),
632 ));
633 }
634 };
635 let query = parse_select(sql)?;
636
637 let mut rdr = csv::ReaderBuilder::new()
638 .has_headers(has_header)
639 .delimiter(delim as u8)
640 .flexible(true)
641 .from_reader(body);
642
643 let headers_owned: Option<Vec<String>> = if has_header {
644 let h = rdr
645 .headers()
646 .map_err(|e| SelectError::InputFormat(format!("CSV headers: {e}")))?
647 .iter()
648 .map(|s| s.to_owned())
649 .collect();
650 Some(h)
651 } else {
652 None
653 };
654 let header_slice: Option<&[String]> = headers_owned.as_deref();
655
656 let mut out = Vec::with_capacity(body.len() / 2);
657 for record in rdr.records() {
658 let record = record
659 .map_err(|e| SelectError::InputFormat(format!("CSV record: {e}")))?;
660 let fields: Vec<&str> = record.iter().collect();
661 let row = CsvRow {
662 fields,
663 headers: header_slice,
664 };
665 if let Some(values) = evaluate_row(&query, &row)? {
666 write_output_row(&query, &values, &output, &mut out)?;
667 }
668 }
669 Ok(out)
670}
671
672pub fn run_select_jsonlines(
677 sql: &str,
678 body: &[u8],
679 output: SelectOutputFormat,
680) -> Result<Vec<u8>, SelectError> {
681 let query = parse_select(sql)?;
682 let text = std::str::from_utf8(body)
683 .map_err(|e| SelectError::InputFormat(format!("body is not valid UTF-8: {e}")))?;
684 let mut out = Vec::with_capacity(body.len() / 2);
685 for (lineno, line) in text.lines().enumerate() {
686 let line = line.trim();
687 if line.is_empty() {
688 continue;
689 }
690 let v: serde_json::Value = serde_json::from_str(line).map_err(|e| {
691 SelectError::InputFormat(format!("JSON parse on line {}: {e}", lineno + 1))
692 })?;
693 let obj = v.as_object().ok_or_else(|| {
694 SelectError::InputFormat(format!(
695 "JSON Lines requires top-level object, line {} was not an object",
696 lineno + 1
697 ))
698 })?;
699 let headers: Vec<String> = obj.keys().cloned().collect();
702 let raw_strs: Vec<String> = obj
703 .values()
704 .map(|jv| match jv {
705 serde_json::Value::String(s) => s.clone(),
706 other => other.to_string(),
707 })
708 .collect();
709 let fields: Vec<&str> = raw_strs.iter().map(|s| s.as_str()).collect();
710 let row = CsvRow {
711 fields,
712 headers: Some(headers.as_slice()),
713 };
714 if let Some(values) = evaluate_row(&query, &row)? {
715 write_jsonlines_row(&query, &headers, &values, &output, &mut out)?;
716 }
717 }
718 Ok(out)
719}
720
721fn write_output_row(
722 query: &SelectQuery,
723 values: &[String],
724 output: &SelectOutputFormat,
725 out: &mut Vec<u8>,
726) -> Result<(), SelectError> {
727 match output {
728 SelectOutputFormat::Csv => {
729 let mut wtr = csv::WriterBuilder::new()
730 .terminator(csv::Terminator::CRLF)
731 .from_writer(Vec::new());
732 wtr.write_record(values.iter().map(String::as_str))
733 .map_err(|e| SelectError::InputFormat(format!("CSV write: {e}")))?;
734 wtr.flush()
735 .map_err(|e| SelectError::InputFormat(format!("CSV flush: {e}")))?;
736 let inner = wtr
737 .into_inner()
738 .map_err(|e| SelectError::InputFormat(format!("CSV finish: {e}")))?;
739 out.extend_from_slice(&inner);
740 }
741 SelectOutputFormat::Json => {
742 let names = projection_names(query, values.len());
743 let mut map = serde_json::Map::with_capacity(values.len());
744 for (n, v) in names.iter().zip(values.iter()) {
745 map.insert(n.clone(), serde_json::Value::String(v.clone()));
746 }
747 let line = serde_json::to_string(&serde_json::Value::Object(map))
748 .map_err(|e| SelectError::InputFormat(format!("JSON serialize: {e}")))?;
749 out.extend_from_slice(line.as_bytes());
750 out.push(b'\n');
751 }
752 }
753 Ok(())
754}
755
756fn write_jsonlines_row(
757 query: &SelectQuery,
758 headers: &[String],
759 values: &[String],
760 output: &SelectOutputFormat,
761 out: &mut Vec<u8>,
762) -> Result<(), SelectError> {
763 match output {
764 SelectOutputFormat::Csv => write_output_row(query, values, output, out)?,
765 SelectOutputFormat::Json => {
766 let names = projection_names_with_headers(query, headers, values.len());
767 let mut map = serde_json::Map::with_capacity(values.len());
768 for (n, v) in names.iter().zip(values.iter()) {
769 map.insert(n.clone(), serde_json::Value::String(v.clone()));
770 }
771 let line = serde_json::to_string(&serde_json::Value::Object(map))
772 .map_err(|e| SelectError::InputFormat(format!("JSON serialize: {e}")))?;
773 out.extend_from_slice(line.as_bytes());
774 out.push(b'\n');
775 }
776 }
777 Ok(())
778}
779
780fn projection_names(query: &SelectQuery, fallback_len: usize) -> Vec<String> {
781 let mut names = Vec::with_capacity(fallback_len);
782 for (i, item) in query.projection.iter().enumerate() {
783 match item {
784 SelectItem::ExprWithAlias { alias, .. } => names.push(alias.value.clone()),
785 SelectItem::UnnamedExpr(e) => match expr_as_column(e) {
786 Ok(s) => names.push(s),
787 Err(_) => names.push(format!("_{}", i + 1)),
788 },
789 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
790 for j in names.len()..fallback_len {
791 names.push(format!("_{}", j + 1));
792 }
793 return names;
794 }
795 }
796 }
797 while names.len() < fallback_len {
798 let n = names.len();
799 names.push(format!("_{}", n + 1));
800 }
801 names
802}
803
804fn projection_names_with_headers(
805 query: &SelectQuery,
806 headers: &[String],
807 fallback_len: usize,
808) -> Vec<String> {
809 let mut names = Vec::with_capacity(fallback_len);
810 for (i, item) in query.projection.iter().enumerate() {
811 match item {
812 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
813 for h in headers {
814 names.push(h.clone());
815 }
816 while names.len() < fallback_len {
817 let n = names.len();
818 names.push(format!("_{}", n + 1));
819 }
820 return names;
821 }
822 SelectItem::ExprWithAlias { alias, .. } => names.push(alias.value.clone()),
823 SelectItem::UnnamedExpr(e) => match expr_as_column(e) {
824 Ok(s) => names.push(s),
825 Err(_) => names.push(format!("_{}", i + 1)),
826 },
827 }
828 }
829 while names.len() < fallback_len {
830 let n = names.len();
831 names.push(format!("_{}", n + 1));
832 }
833 names
834}
835
836#[derive(Debug, Default)]
847pub struct EventStreamWriter {}
848
849impl EventStreamWriter {
850 #[must_use]
851 pub fn new() -> Self {
852 Self {}
853 }
854
855 pub fn records(&mut self, payload: &[u8]) -> Vec<u8> {
859 build_frame(
860 &[
861 (":event-type", "Records"),
862 (":content-type", "application/octet-stream"),
863 (":message-type", "event"),
864 ],
865 Some(payload),
866 )
867 }
868
869 pub fn stats(&mut self, scanned: u64, processed: u64, returned: u64) -> Vec<u8> {
872 let xml = format!(
873 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\
874<Stats xmlns=\"\">\
875<BytesScanned>{scanned}</BytesScanned>\
876<BytesProcessed>{processed}</BytesProcessed>\
877<BytesReturned>{returned}</BytesReturned>\
878</Stats>"
879 );
880 build_frame(
881 &[
882 (":event-type", "Stats"),
883 (":content-type", "text/xml"),
884 (":message-type", "event"),
885 ],
886 Some(xml.as_bytes()),
887 )
888 }
889
890 pub fn end(&mut self) -> Vec<u8> {
893 build_frame(
894 &[
895 (":event-type", "End"),
896 (":message-type", "event"),
897 ],
898 None,
899 )
900 }
901}
902
903fn build_frame(headers: &[(&str, &str)], payload: Option<&[u8]>) -> Vec<u8> {
904 let mut header_buf: Vec<u8> = Vec::new();
905 for (name, value) in headers {
906 let name_bytes = name.as_bytes();
907 let value_bytes = value.as_bytes();
908 debug_assert!(name_bytes.len() <= u8::MAX as usize, "header name too long");
909 debug_assert!(value_bytes.len() <= u16::MAX as usize, "header value too long");
910 header_buf.push(name_bytes.len() as u8);
911 header_buf.extend_from_slice(name_bytes);
912 header_buf.push(7); header_buf.extend_from_slice(&(value_bytes.len() as u16).to_be_bytes());
914 header_buf.extend_from_slice(value_bytes);
915 }
916 let payload_bytes = payload.unwrap_or(&[]);
917 let headers_len: u32 = header_buf.len() as u32;
918 let total_len: u32 = 12 + headers_len + payload_bytes.len() as u32 + 4;
919
920 let mut buf: Vec<u8> = Vec::with_capacity(total_len as usize);
921 buf.extend_from_slice(&total_len.to_be_bytes());
922 buf.extend_from_slice(&headers_len.to_be_bytes());
923 let prelude_crc = crc32fast::hash(&buf[..8]);
924 buf.extend_from_slice(&prelude_crc.to_be_bytes());
925 buf.extend_from_slice(&header_buf);
926 buf.extend_from_slice(payload_bytes);
927 let message_crc = crc32fast::hash(&buf[..buf.len()]);
928 buf.extend_from_slice(&message_crc.to_be_bytes());
929 buf
930}
931
/// GPU fast path for the simplest query shape,
/// `SELECT * FROM <t> WHERE <column> <op> <literal>`, over comma-delimited CSV
/// with a header row. Every function returns `None` to mean "not handled here,
/// use the CPU path".
#[cfg(feature = "nvcomp-gpu")]
mod gpu {
    use super::{parse_select, Expr, SelectInputFormat, SelectQuery, Value};
    use s4_codec::gpu_select::{CompareOp, GpuSelectKernel};
    use sqlparser::ast::{BinaryOperator, SelectItem};
    use std::sync::OnceLock;

    /// Kernel handle, initialised at most once. A failed initialisation is
    /// cached as `None`, so it is never retried.
    static KERNEL: OnceLock<Option<GpuSelectKernel>> = OnceLock::new();

    /// Returns the shared kernel, creating it on first use.
    fn kernel() -> Option<&'static GpuSelectKernel> {
        KERNEL
            .get_or_init(|| match GpuSelectKernel::new() {
                Ok(k) => Some(k),
                Err(e) => {
                    tracing::debug!(
                        target: "s4_server::select::gpu",
                        ?e,
                        "GpuSelectKernel init failed; falling back to CPU permanently"
                    );
                    None
                }
            })
            .as_ref()
    }

    /// Attempts to answer the query on the GPU.
    ///
    /// Returns `None` when the input is not header-bearing comma-delimited
    /// CSV, the query is not a single-predicate `SELECT *`, the column cannot
    /// be located in the header line, the kernel is unavailable, or the scan
    /// itself fails.
    pub(super) fn try_select_gpu(
        sql: &str,
        body: &[u8],
        input: &SelectInputFormat,
    ) -> Option<Vec<u8>> {
        let SelectInputFormat::Csv {
            has_header: true,
            delimiter: ',',
        } = input
        else {
            return None;
        };

        let (col_name, op, literal) = parse_simple_predicate(sql)?;
        let col_idx = resolve_header_column(body, &col_name)?;

        let kernel = kernel()?;
        match kernel.scan_csv(body, col_idx, op, literal.as_bytes()) {
            Ok(out) => Some(out),
            Err(e) => {
                tracing::debug!(
                    target: "s4_server::select::gpu",
                    ?e,
                    "GPU scan failed; falling back to CPU"
                );
                None
            }
        }
    }

    /// Extracts `(column, operator, literal)` from a query of the form
    /// `SELECT * … WHERE column <op> literal`, with `<op>` one of `=`, `<>`,
    /// `>`, `<`.
    ///
    /// The SQL goes through [`parse_select`], so this path applies exactly the
    /// same validation as the CPU path. A query the CPU path would reject
    /// (DISTINCT, GROUP BY, JOIN, OFFSET, …) yields `None` here and the caller
    /// falls through to the CPU path, which reports the error.
    fn parse_simple_predicate(sql: &str) -> Option<(String, CompareOp, String)> {
        let SelectQuery {
            projection,
            where_clause,
            ..
        } = parse_select(sql).ok()?;

        if !matches!(projection.as_slice(), [SelectItem::Wildcard(_)]) {
            return None;
        }

        let Expr::BinaryOp { op, left, right } = where_clause? else {
            return None;
        };
        let Expr::Identifier(column) = *left else {
            return None;
        };
        let Expr::Value(literal) = *right else {
            return None;
        };
        let cmp_op = match op {
            BinaryOperator::Eq => CompareOp::Equal,
            BinaryOperator::NotEq => CompareOp::NotEqual,
            BinaryOperator::Gt => CompareOp::GreaterThan,
            BinaryOperator::Lt => CompareOp::LessThan,
            _ => return None,
        };

        Some((column.value, cmp_op, value_as_str(&literal)?))
    }

    /// Returns the text of a string or number literal; `None` for any other kind.
    fn value_as_str(v: &Value) -> Option<String> {
        match v {
            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Some(s.clone()),
            Value::Number(s, _) => Some(s.clone()),
            _ => None,
        }
    }

    /// Finds the zero-based index of `col_name` (ASCII case-insensitive) in
    /// the first line of `body`.
    ///
    /// The header is split on bare commas. A header line containing a `"` may
    /// hold quoted fields with embedded commas, for which that split would
    /// give wrong indices, so such headers return `None` and are left to the
    /// CPU path's real CSV parser.
    fn resolve_header_column(body: &[u8], col_name: &str) -> Option<usize> {
        let line_end = body.iter().position(|&b| b == b'\n').unwrap_or(body.len());
        let line = &body[..line_end];
        let header = line.strip_suffix(b"\r").unwrap_or(line);
        if header.contains(&b'"') {
            return None;
        }
        let header_str = std::str::from_utf8(header).ok()?;
        header_str
            .split(',')
            .position(|name| name.eq_ignore_ascii_case(col_name))
    }
}
1106
1107#[must_use]
1114pub fn select_gpu(
1115 sql: &str,
1116 body: &[u8],
1117 input: &SelectInputFormat,
1118) -> Option<Vec<u8>> {
1119 #[cfg(feature = "nvcomp-gpu")]
1120 {
1121 gpu::try_select_gpu(sql, body, input)
1122 }
1123 #[cfg(not(feature = "nvcomp-gpu"))]
1124 {
1125 let _ = (sql, body, input);
1126 None
1127 }
1128}
1129
/// Unit tests for parsing, row evaluation, the end-to-end runners, event
/// frame encoding and the GPU fall-through rules.
#[cfg(test)]
mod tests {
    use super::*;

    /// Comma-delimited CSV with a header row, the input shape most tests use.
    fn csv_input() -> SelectInputFormat {
        SelectInputFormat::Csv {
            has_header: true,
            delimiter: ',',
        }
    }

    /// Splits CRLF-terminated CSV output into its non-empty lines.
    fn crlf_lines(out: &[u8]) -> Vec<&str> {
        std::str::from_utf8(out)
            .unwrap()
            .split("\r\n")
            .filter(|line| !line.is_empty())
            .collect()
    }

    /// Reads the big-endian `u32` starting at byte offset `at`.
    fn be_u32(frame: &[u8], at: usize) -> u32 {
        u32::from_be_bytes(frame[at..at + 4].try_into().unwrap())
    }

    /// Returns `(total_len, header_bytes, payload_bytes)` as declared by the
    /// frame's own prelude.
    fn frame_sections(frame: &[u8]) -> (usize, &[u8], &[u8]) {
        let total = be_u32(frame, 0) as usize;
        let headers_len = be_u32(frame, 4) as usize;
        let headers = &frame[12..12 + headers_len];
        let payload = &frame[12 + headers_len..total - 4];
        (total, headers, payload)
    }

    #[test]
    fn parse_select_happy_path() {
        let parsed = parse_select("SELECT name, age FROM s3object WHERE age > 30").unwrap();
        assert_eq!(parsed.projection.len(), 2);
        assert!(parsed.where_clause.is_some());
        assert_eq!(parsed.from_alias.to_lowercase(), "s3object");
    }

    #[test]
    fn parse_select_rejects_group_by() {
        let outcome = parse_select("SELECT name, COUNT(*) FROM s3object GROUP BY name");
        match outcome.unwrap_err() {
            SelectError::UnsupportedFeature(_) => {}
            other => panic!("expected UnsupportedFeature, got {other:?}"),
        }
    }

    #[test]
    fn parse_select_rejects_join() {
        let outcome = parse_select("SELECT a.x FROM s3object a JOIN other b ON a.id = b.id");
        assert!(matches!(
            outcome.unwrap_err(),
            SelectError::UnsupportedFeature(_)
        ));
    }

    #[test]
    fn parse_select_rejects_order_by() {
        let outcome = parse_select("SELECT name FROM s3object ORDER BY name");
        assert!(matches!(
            outcome.unwrap_err(),
            SelectError::UnsupportedFeature(_)
        ));
    }

    #[test]
    fn evaluate_row_eq_match() {
        let query = parse_select("SELECT name FROM s3object WHERE name = 'alice'").unwrap();
        let headers = vec!["name".to_owned(), "age".to_owned()];

        let matching = CsvRow {
            fields: vec!["alice", "30"],
            headers: Some(&headers),
        };
        assert_eq!(
            evaluate_row(&query, &matching).unwrap(),
            Some(vec!["alice".to_owned()])
        );

        let other = CsvRow {
            fields: vec!["bob", "30"],
            headers: Some(&headers),
        };
        assert_eq!(evaluate_row(&query, &other).unwrap(), None);
    }

    #[test]
    fn evaluate_row_int_compare() {
        let query = parse_select("SELECT age FROM s3object WHERE age > 100").unwrap();
        let headers = vec!["name".to_owned(), "age".to_owned()];
        let above = CsvRow {
            fields: vec!["x", "200"],
            headers: Some(&headers),
        };
        let below = CsvRow {
            fields: vec!["x", "50"],
            headers: Some(&headers),
        };
        assert!(evaluate_row(&query, &above).unwrap().is_some());
        assert!(evaluate_row(&query, &below).unwrap().is_none());
    }

    #[test]
    fn evaluate_row_like_pattern() {
        let query = parse_select("SELECT name FROM s3object WHERE name LIKE 'foo%'").unwrap();
        let headers = vec!["name".to_owned()];
        let prefixed = CsvRow {
            fields: vec!["foobar"],
            headers: Some(&headers),
        };
        let not_prefixed = CsvRow {
            fields: vec!["xfoobar"],
            headers: Some(&headers),
        };
        assert!(evaluate_row(&query, &prefixed).unwrap().is_some());
        assert!(evaluate_row(&query, &not_prefixed).unwrap().is_none());
    }

    #[test]
    fn run_select_csv_end_to_end_filters_rows() {
        let out = run_select_csv(
            "SELECT name FROM s3object WHERE age > 35",
            b"name,age\nalice,30\nbob,40\ncarol,50\n",
            csv_input(),
            SelectOutputFormat::Csv,
        )
        .unwrap();
        assert_eq!(crlf_lines(&out), vec!["bob", "carol"]);
    }

    #[test]
    fn run_select_jsonlines_filter() {
        let body = b"{\"name\":\"alice\",\"age\":\"30\"}\n\
                     {\"name\":\"bob\",\"age\":\"40\"}\n\
                     {\"name\":\"carol\",\"age\":\"50\"}\n";
        let out = run_select_jsonlines(
            "SELECT name FROM s3object WHERE age > 35",
            body,
            SelectOutputFormat::Json,
        )
        .unwrap();
        let text = std::str::from_utf8(&out).unwrap();
        let lines: Vec<&str> = text.lines().filter(|line| !line.is_empty()).collect();
        assert_eq!(lines.len(), 2);
        assert!(lines[0].contains("bob"));
        assert!(lines[1].contains("carol"));
    }

    #[test]
    fn positional_column_ref() {
        let headerless = SelectInputFormat::Csv {
            has_header: false,
            delimiter: ',',
        };
        let out = run_select_csv(
            "SELECT _1 FROM s3object WHERE _2 > 35",
            b"alice,30\nbob,40\n",
            headerless,
            SelectOutputFormat::Csv,
        )
        .unwrap();
        assert_eq!(crlf_lines(&out), vec!["bob"]);
    }

    #[test]
    fn and_or_combination() {
        let body = b"name,age,city\n\
                     alice,30,nyc\n\
                     bob,40,nyc\n\
                     carol,50,sf\n\
                     dan,25,sf\n";
        let out = run_select_csv(
            "SELECT name FROM s3object WHERE (city = 'nyc' AND age > 35) OR name = 'dan'",
            body,
            csv_input(),
            SelectOutputFormat::Csv,
        )
        .unwrap();
        let mut lines = crlf_lines(&out);
        lines.sort_unstable();
        assert_eq!(lines, vec!["bob", "dan"]);
    }

    #[test]
    fn event_stream_records_frame_format() {
        let frame = EventStreamWriter::new().records(b"hello,world\r\n");
        let (total, headers, payload) = frame_sections(&frame);

        assert_eq!(total, frame.len());
        assert_eq!(be_u32(&frame, 8), crc32fast::hash(&frame[..8]));
        assert_eq!(
            be_u32(&frame, total - 4),
            crc32fast::hash(&frame[..total - 4])
        );

        let header_text = String::from_utf8_lossy(headers);
        assert!(header_text.contains(":event-type"));
        assert!(header_text.contains("Records"));
        assert_eq!(payload, b"hello,world\r\n");
    }

    #[test]
    fn event_stream_end_frame_no_payload() {
        let frame = EventStreamWriter::new().end();
        let (_, headers, payload) = frame_sections(&frame);
        assert!(payload.is_empty());
        assert!(String::from_utf8_lossy(headers).contains("End"));
    }

    #[test]
    fn event_stream_stats_xml_payload() {
        let frame = EventStreamWriter::new().stats(1024, 800, 64);
        let (_, _, payload) = frame_sections(&frame);
        let xml = std::str::from_utf8(payload).unwrap();
        assert!(xml.contains("<BytesScanned>1024</BytesScanned>"));
        assert!(xml.contains("<BytesProcessed>800</BytesProcessed>"));
        assert!(xml.contains("<BytesReturned>64</BytesReturned>"));
    }

    #[test]
    fn gpu_no_where_falls_through() {
        let handled = select_gpu(
            "SELECT * FROM s3object",
            b"name,age\nalice,30\n",
            &csv_input(),
        );
        assert!(
            handled.is_none(),
            "queries without a WHERE predicate must fall through to CPU"
        );
    }

    #[test]
    fn gpu_jsonlines_falls_through() {
        let handled = select_gpu(
            "SELECT * FROM s3object WHERE country = 'Japan'",
            b"{\"country\":\"Japan\"}\n",
            &SelectInputFormat::JsonLines,
        );
        assert!(
            handled.is_none(),
            "JSON Lines input must always fall through to CPU"
        );
    }

    #[test]
    fn like_match_basics() {
        let cases = [
            ("foobar", "foo%", true),
            ("xfoobar", "foo%", false),
            ("abc", "_b_", true),
            ("anything", "%", true),
            ("", "", true),
            ("a", "", false),
        ];
        for (text, pattern, expected) in cases {
            assert_eq!(
                like_match(text, pattern),
                expected,
                "like_match({text:?}, {pattern:?})"
            );
        }
    }
}