1use serde::Deserialize;
11
12#[derive(Debug, thiserror::Error, PartialEq, Eq)]
14pub enum SchemaError {
15 #[error("empty {0} name")]
16 EmptyIdentifier(&'static str),
17 #[error("{kind} name too long: {len} > {max}")]
18 IdentifierTooLong {
19 kind: &'static str,
20 len: usize,
21 max: usize,
22 },
23 #[error("invalid {kind} name {name:?}: must match ^[A-Za-z_][A-Za-z0-9_]*$")]
24 InvalidIdentifier { kind: &'static str, name: String },
25 #[error("a table must declare at least one column")]
26 NoColumns,
27 #[error("too many columns: {count} > {max}")]
28 TooManyColumns { count: usize, max: usize },
29 #[error("column name {0:?} is reserved")]
30 ReservedColumn(String),
31 #[error("duplicate column name {0:?}")]
32 DuplicateColumn(String),
33 #[error("invalid DateTime64 precision: {precision} (must be 0..=9)")]
34 InvalidDateTime64Precision { precision: u8 },
35}
36
/// Upper bounds applied while validating a schema.
#[derive(Debug, Clone, Copy)]
pub struct SchemaLimits {
    /// Largest number of columns a table may declare.
    pub max_columns: usize,
    /// Largest permitted identifier length, measured in bytes (`str::len`).
    pub max_identifier_length: usize,
}
43
44impl Default for SchemaLimits {
45 fn default() -> Self {
46 Self {
47 max_columns: 1024,
48 max_identifier_length: 128,
49 }
50 }
51}
52
/// Column names rejected by default when passed as the `reserved` list of
/// `assert_not_reserved`.
///
/// NOTE(review): presumably these names are used by columns the system
/// creates itself — confirm against the table-creation code.
pub const DEFAULT_RESERVED_COLUMNS: &[&str] = &["attrs", "raw"];
55
/// Returns `true` when `name` matches `^[A-Za-z_][A-Za-z0-9_]*$`.
///
/// The empty string is rejected. The check works on raw bytes: every byte of
/// a multi-byte UTF-8 character is `>= 0x80`, so any non-ASCII character
/// fails the ASCII tests below.
fn is_valid_identifier(name: &str) -> bool {
    match name.as_bytes() {
        [] => false,
        [head, tail @ ..] => {
            let head_ok = *head == b'_' || head.is_ascii_alphabetic();
            head_ok
                && tail
                    .iter()
                    .all(|b| *b == b'_' || b.is_ascii_alphanumeric())
        }
    }
}
64
/// Returns `true` when `tz` is a plausible time-zone name: 1 to 64 bytes,
/// made up only of ASCII letters, ASCII digits, `_`, `+`, `/` and `-`.
///
/// This is a character allowlist only; it does not check that the zone
/// actually exists.
fn is_valid_timezone(tz: &str) -> bool {
    if tz.is_empty() || tz.len() > 64 {
        return false;
    }
    // Byte-wise scan: bytes of non-ASCII characters are >= 0x80 and so never match.
    tz.bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'+' | b'/' | b'-'))
}
76
77pub fn validate_identifier<'a>(
80 name: &'a str,
81 kind: &'static str,
82 limits: &SchemaLimits,
83) -> Result<&'a str, SchemaError> {
84 if name.is_empty() {
85 return Err(SchemaError::EmptyIdentifier(kind));
86 }
87 if name.len() > limits.max_identifier_length {
88 return Err(SchemaError::IdentifierTooLong {
89 kind,
90 len: name.len(),
91 max: limits.max_identifier_length,
92 });
93 }
94 if !is_valid_identifier(name) {
95 return Err(SchemaError::InvalidIdentifier {
96 kind,
97 name: name.to_string(),
98 });
99 }
100 Ok(name)
101}
102
/// Wraps `name` in backticks so it can be embedded in ClickHouse SQL as a
/// quoted identifier.
///
/// Two characters are escaped by doubling:
/// - the backtick, which would otherwise terminate the identifier;
/// - the backslash, which ClickHouse treats as an escape character inside
///   quoted identifiers. Left as-is, a trailing `\` would swallow the closing
///   backtick and let the rest of the statement be read as part of the name.
///
/// Names that contain neither character are returned simply wrapped in
/// backticks. This is defence in depth: names should still be checked with
/// `validate_identifier` before they reach this function.
pub fn quote_identifier(name: &str) -> String {
    // +2 for the surrounding backticks; escapes (rare) may grow it further.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('`');
    for c in name.chars() {
        match c {
            '`' => quoted.push_str("``"),
            '\\' => quoted.push_str("\\\\"),
            other => quoted.push(other),
        }
    }
    quoted.push('`');
    quoted
}
107
108pub fn assert_column_count(count: usize, limits: &SchemaLimits) -> Result<(), SchemaError> {
110 if count < 1 {
111 return Err(SchemaError::NoColumns);
112 }
113 if count > limits.max_columns {
114 return Err(SchemaError::TooManyColumns {
115 count,
116 max: limits.max_columns,
117 });
118 }
119 Ok(())
120}
121
122pub fn assert_not_reserved(name: &str, reserved: &[&str]) -> Result<(), SchemaError> {
124 if reserved.contains(&name) {
125 return Err(SchemaError::ReservedColumn(name.to_string()));
126 }
127 Ok(())
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
136pub enum ScalarType {
137 String,
138 #[serde(rename = "UUID")]
139 Uuid,
140 Bool,
141 Date,
142 DateTime,
143 DateTime64,
144 Int8,
145 Int16,
146 Int32,
147 Int64,
148 UInt8,
149 UInt16,
150 UInt32,
151 UInt64,
152 Float32,
153 Float64,
154 #[serde(rename = "JSON")]
155 Json,
156}
157
158impl ScalarType {
159 fn ch_type(self) -> &'static str {
160 match self {
161 ScalarType::String => "String",
162 ScalarType::Uuid => "UUID",
163 ScalarType::Bool => "Bool",
164 ScalarType::Date => "Date",
165 ScalarType::DateTime => "DateTime",
166 ScalarType::DateTime64 => "DateTime64(3)",
167 ScalarType::Int8 => "Int8",
168 ScalarType::Int16 => "Int16",
169 ScalarType::Int32 => "Int32",
170 ScalarType::Int64 => "Int64",
171 ScalarType::UInt8 => "UInt8",
172 ScalarType::UInt16 => "UInt16",
173 ScalarType::UInt32 => "UInt32",
174 ScalarType::UInt64 => "UInt64",
175 ScalarType::Float32 => "Float32",
176 ScalarType::Float64 => "Float64",
177 ScalarType::Json => "JSON",
178 }
179 }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
184pub enum StringOnly {
185 String,
186}
187
/// Serde default for `DateTime64Spec::precision`: 3 sub-second digits.
fn default_dt64_precision() -> u8 {
    const DEFAULT_PRECISION: u8 = 3;
    DEFAULT_PRECISION
}
191
192#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
201pub struct DateTime64Spec {
202 #[serde(default = "default_dt64_precision")]
203 pub precision: u8,
204 #[serde(default)]
205 pub timezone: Option<String>,
206}
207
208impl DateTime64Spec {
209 pub fn validate(&self) -> Result<(), SchemaError> {
211 if self.precision > 9 {
212 return Err(SchemaError::InvalidDateTime64Precision {
213 precision: self.precision,
214 });
215 }
216 if let Some(tz) = &self.timezone {
217 if !is_valid_timezone(tz) {
218 return Err(SchemaError::InvalidIdentifier {
219 kind: "timezone",
220 name: tz.clone(),
221 });
222 }
223 }
224 Ok(())
225 }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
232#[serde(untagged)]
233pub enum ColumnTypeSpec {
234 Scalar(ScalarType),
235 DateTime64 {
239 datetime64: DateTime64Spec,
240 },
241 Nullable {
242 nullable: Box<ColumnTypeSpec>,
243 },
244 LowCardinality {
245 #[serde(rename = "lowCardinality")]
246 low_cardinality: Box<ColumnTypeSpec>,
247 },
248 Array {
249 array: StringOnly,
250 },
251 Map {
252 map: (StringOnly, StringOnly),
253 },
254}
255
256impl ColumnTypeSpec {
257 pub fn to_ch_type(&self) -> String {
263 match self {
264 ColumnTypeSpec::Scalar(s) => s.ch_type().to_string(),
265 ColumnTypeSpec::DateTime64 { datetime64 } => match &datetime64.timezone {
266 Some(tz) => format!("DateTime64({}, '{}')", datetime64.precision, tz),
267 None => format!("DateTime64({})", datetime64.precision),
268 },
269 ColumnTypeSpec::Nullable { nullable } => format!("Nullable({})", nullable.to_ch_type()),
270 ColumnTypeSpec::LowCardinality { low_cardinality } => {
271 format!("LowCardinality({})", low_cardinality.to_ch_type())
272 }
273 ColumnTypeSpec::Array { .. } => "Array(String)".to_string(),
274 ColumnTypeSpec::Map { .. } => "Map(String, String)".to_string(),
275 }
276 }
277
278 pub fn is_datetime64(&self) -> bool {
283 match self {
284 ColumnTypeSpec::Scalar(ScalarType::DateTime64) => true,
285 ColumnTypeSpec::DateTime64 { .. } => true,
286 ColumnTypeSpec::Nullable { nullable } => nullable.is_datetime64(),
287 ColumnTypeSpec::LowCardinality { low_cardinality } => low_cardinality.is_datetime64(),
288 _ => false,
289 }
290 }
291
292 pub fn validate(&self) -> Result<(), SchemaError> {
297 match self {
298 ColumnTypeSpec::DateTime64 { datetime64 } => datetime64.validate(),
299 ColumnTypeSpec::Nullable { nullable } => nullable.validate(),
300 ColumnTypeSpec::LowCardinality { low_cardinality } => low_cardinality.validate(),
301 _ => Ok(()),
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Limits used throughout these tests: the defaults.
    fn limits() -> SchemaLimits {
        SchemaLimits::default()
    }

    /// Parses a JSON column type spec, panicking at the caller's line if the
    /// input is rejected.
    #[track_caller]
    fn parse(json: &str) -> ColumnTypeSpec {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn accepts_safe_identifiers() {
        let lim = limits();
        for name in ["a", "A", "_x", "org_id", "col1", "X_2_y"] {
            assert_eq!(validate_identifier(name, "column", &lim), Ok(name));
        }
    }

    #[test]
    fn rejects_injection_and_metacharacters() {
        let lim = limits();
        let attacks = [
            "a; DROP TABLE x",
            "a`,`b",
            "a) ENGINE=Memory AS SELECT * FROM secrets --",
            "a' OR '1'='1",
            "a b",
            "a.b",
            "a-b",
            "1col",
            "",
            "a\"b",
            "a\nb",
            "таблица",
            "a/*x*/",
        ];
        for bad in attacks {
            let outcome = validate_identifier(bad, "column", &lim);
            assert!(outcome.is_err(), "should reject {bad:?}");
        }
    }

    #[test]
    fn enforces_length_bound() {
        let lim = limits();
        let max = lim.max_identifier_length;

        // Exactly at the limit is fine; one byte over is not.
        let at_limit = "a".repeat(max);
        assert!(validate_identifier(&at_limit, "column", &lim).is_ok());
        let over_limit = "a".repeat(max + 1);
        assert!(validate_identifier(&over_limit, "column", &lim).is_err());
    }

    #[test]
    fn quotes_and_escapes() {
        for (raw, quoted) in [("org_id", "`org_id`"), ("a`b", "`a``b`")] {
            assert_eq!(quote_identifier(raw), quoted);
        }
    }

    #[test]
    fn bounds_and_reserved() {
        let lim = limits();
        assert!(assert_column_count(0, &lim).is_err());
        assert!(assert_column_count(lim.max_columns + 1, &lim).is_err());
        assert!(assert_column_count(10, &lim).is_ok());

        for reserved in ["attrs", "raw"] {
            assert!(assert_not_reserved(reserved, DEFAULT_RESERVED_COLUMNS).is_err());
        }
        assert!(assert_not_reserved("user_col", DEFAULT_RESERVED_COLUMNS).is_ok());
    }

    #[test]
    fn allowlist_builds_allowed_types() {
        let bare = parse("\"DateTime64\"");
        assert_eq!(bare.to_ch_type(), "DateTime64(3)");
        assert!(bare.is_datetime64());

        let nullable = parse(r#"{"nullable":"String"}"#);
        assert_eq!(nullable.to_ch_type(), "Nullable(String)");

        let low_card = parse(r#"{"lowCardinality":{"nullable":"String"}}"#);
        assert_eq!(low_card.to_ch_type(), "LowCardinality(Nullable(String))");
        let low_card_dt = parse(r#"{"lowCardinality":"DateTime64"}"#);
        assert!(low_card_dt.is_datetime64());

        let array = parse(r#"{"array":"String"}"#);
        assert_eq!(array.to_ch_type(), "Array(String)");
        let map = parse(r#"{"map":["String","String"]}"#);
        assert_eq!(map.to_ch_type(), "Map(String, String)");
    }

    #[test]
    fn allowlist_rejects_disallowed_types() {
        let bad = [
            "\"Decimal(38, 10)\"",
            "\"FixedString(16)\"",
            "\"Enum8\"",
            "\"Tuple\"",
            "\"Nested\"",
            r#"{"map":["String","Int32"]}"#,
            r#"{"array":"Int32"}"#,
            r#"{"array":{"nullable":"String"}}"#,
            r#"{"wat":"String"}"#,
            "42",
        ];
        for json in bad {
            let parsed = serde_json::from_str::<ColumnTypeSpec>(json);
            assert!(parsed.is_err(), "should reject {json}");
        }
    }

    #[test]
    fn parametrised_datetime64_renders_and_validates() {
        // Precision and time zone both given.
        let utc = parse(r#"{"datetime64":{"precision":3,"timezone":"UTC"}}"#);
        assert_eq!(utc.to_ch_type(), "DateTime64(3, 'UTC')");
        assert!(utc.is_datetime64());
        assert!(utc.validate().is_ok());

        // Precision only.
        let micros = parse(r#"{"datetime64":{"precision":6}}"#);
        assert_eq!(micros.to_ch_type(), "DateTime64(6)");
        assert!(micros.validate().is_ok());

        // Everything defaulted.
        let defaulted = parse(r#"{"datetime64":{}}"#);
        assert_eq!(defaulted.to_ch_type(), "DateTime64(3)");
        assert!(defaulted.is_datetime64());
        assert!(defaulted.validate().is_ok());

        // The bare string form still resolves to the scalar variant.
        let bare = parse("\"DateTime64\"");
        assert!(matches!(
            bare,
            ColumnTypeSpec::Scalar(ScalarType::DateTime64)
        ));

        // Time zone containing '/' and '_'.
        let new_york = parse(r#"{"datetime64":{"precision":9,"timezone":"America/New_York"}}"#);
        assert_eq!(new_york.to_ch_type(), "DateTime64(9, 'America/New_York')");
        assert!(new_york.validate().is_ok());
    }

    #[test]
    fn parametrised_datetime64_rejects_bad_params() {
        let bad_tz = parse(r#"{"datetime64":{"precision":3,"timezone":"UTC'; DROP"}}"#);
        let tz_outcome = bad_tz.validate();
        assert!(matches!(
            tz_outcome,
            Err(SchemaError::InvalidIdentifier {
                kind: "timezone",
                ..
            })
        ));

        let bad_precision = parse(r#"{"datetime64":{"precision":12}}"#);
        let precision_outcome = bad_precision.validate();
        assert!(matches!(
            precision_outcome,
            Err(SchemaError::InvalidDateTime64Precision { precision: 12 })
        ));
    }

    #[test]
    fn parametrised_datetime64_is_datetime64_through_nullable() {
        let wrapped = parse(r#"{"nullable":{"datetime64":{"precision":3,"timezone":"UTC"}}}"#);
        assert!(wrapped.is_datetime64());
        assert_eq!(wrapped.to_ch_type(), "Nullable(DateTime64(3, 'UTC'))");
        assert!(wrapped.validate().is_ok());

        // Validation must reach through the wrapper.
        let wrapped_bad = parse(r#"{"nullable":{"datetime64":{"precision":12}}}"#);
        assert!(wrapped_bad.validate().is_err());
    }
}