Skip to main content

clickhouse_kit/
safety.rs

1//! Safe-by-construction primitives for user-defined / multi-tenant ClickHouse
2//! schemas — the Rust-canonical port of `@smooai/clickhouse-kit`'s safety core.
3//!
4//! When column names + types come from untrusted input (a customer config, a DB
5//! row, JSON), these make SQL injection and unbounded tables impossible on the
6//! happy path. In Rust the type allowlist is even stronger than the TS version:
7//! disallowed types (`Decimal`, `FixedString`, `Tuple`, …) have **no representation**
8//! in [`ColumnTypeSpec`], so untrusted input naming them fails to deserialize.
9
10use serde::Deserialize;
11
/// Raised when untrusted schema input violates a safety rule.
///
/// Every variant carries enough context to render a self-contained message via
/// `thiserror`'s `#[error(...)]` format strings.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum SchemaError {
    /// The name was the empty string. The payload is the `kind` label
    /// (`"table"` / `"column"` / …) passed to `validate_identifier`.
    #[error("empty {0} name")]
    EmptyIdentifier(&'static str),
    /// The name exceeded `SchemaLimits::max_identifier_length`. `len` is the value
    /// of `str::len`, i.e. a byte count.
    #[error("{kind} name too long: {len} > {max}")]
    IdentifierTooLong {
        kind: &'static str,
        len: usize,
        max: usize,
    },
    /// The name contained a character outside the allowlist.
    ///
    /// Also returned by `DateTime64Spec::validate` with `kind: "timezone"`.
    /// NOTE(review): for that case the pattern quoted in the message is the
    /// identifier pattern, not the (wider) timezone charset.
    #[error("invalid {kind} name {name:?}: must match ^[A-Za-z_][A-Za-z0-9_]*$")]
    InvalidIdentifier { kind: &'static str, name: String },
    /// A column count of zero was supplied to `assert_column_count`.
    #[error("a table must declare at least one column")]
    NoColumns,
    /// The column count exceeded `SchemaLimits::max_columns`.
    #[error("too many columns: {count} > {max}")]
    TooManyColumns { count: usize, max: usize },
    /// The column name is in the reserved list given to `assert_not_reserved`.
    #[error("column name {0:?} is reserved")]
    ReservedColumn(String),
    /// The same column name appeared more than once.
    /// (Not constructed in this part of the file — presumably raised by the table
    /// builder; confirm against its implementation.)
    #[error("duplicate column name {0:?}")]
    DuplicateColumn(String),
    /// A parametrised `DateTime64` declared a precision greater than 9.
    #[error("invalid DateTime64 precision: {precision} (must be 0..=9)")]
    InvalidDateTime64Precision { precision: u8 },
}
36
/// Upper bounds applied to a schema built from untrusted input.
#[derive(Debug, Clone, Copy)]
pub struct SchemaLimits {
    /// Largest number of columns a single table may declare.
    pub max_columns: usize,
    /// Longest permitted table/column identifier, compared against `str::len`.
    pub max_identifier_length: usize,
}

impl Default for SchemaLimits {
    /// The stock limits: 1024 columns, 128-byte identifiers.
    fn default() -> Self {
        const DEFAULT_MAX_COLUMNS: usize = 1024;
        const DEFAULT_MAX_IDENTIFIER_LENGTH: usize = 128;

        SchemaLimits {
            max_columns: DEFAULT_MAX_COLUMNS,
            max_identifier_length: DEFAULT_MAX_IDENTIFIER_LENGTH,
        }
    }
}
52
/// Columns reserved for the flexible/hybrid table shape (catch-all + raw payload).
///
/// Intended to be passed as the `reserved` list of `assert_not_reserved`, which
/// rejects any user-supplied column whose name equals one of these entries.
pub const DEFAULT_RESERVED_COLUMNS: &[&str] = &["attrs", "raw"];
55
/// Whether `name` matches `^[A-Za-z_][A-Za-z0-9_]*$`.
///
/// Works on raw bytes: every byte of a non-ASCII character is >= 0x80 and so
/// fails the ASCII class checks, which rejects such names outright.
fn is_valid_identifier(name: &str) -> bool {
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    match name.as_bytes().split_first() {
        // The leading byte may be a letter or underscore, never a digit.
        Some((&head, tail)) => {
            (head.is_ascii_alphabetic() || head == b'_') && tail.iter().all(|&b| is_word_byte(b))
        }
        // The empty string is not an identifier.
        None => false,
    }
}
64
/// Whether `tz` looks like an IANA timezone name: between 1 and 64 bytes, all
/// drawn from `[A-Za-z0-9_+/-]` (e.g. `UTC`, `America/New_York`, `Etc/GMT+5`).
///
/// Quotes, semicolons, whitespace and any non-ASCII character fall outside that
/// set and are rejected, so an accepted value cannot terminate the SQL string
/// literal it is later rendered into.
fn is_valid_timezone(tz: &str) -> bool {
    let in_charset = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'+' | b'/' | b'-');

    (1..=64).contains(&tz.len()) && tz.bytes().all(in_charset)
}
76
77/// Validate a table/column identifier against the strict ASCII allowlist + length
78/// bound. `kind` is `"table"` / `"column"` / `"identifier"` for error messages.
79pub fn validate_identifier<'a>(
80    name: &'a str,
81    kind: &'static str,
82    limits: &SchemaLimits,
83) -> Result<&'a str, SchemaError> {
84    if name.is_empty() {
85        return Err(SchemaError::EmptyIdentifier(kind));
86    }
87    if name.len() > limits.max_identifier_length {
88        return Err(SchemaError::IdentifierTooLong {
89            kind,
90            len: name.len(),
91            max: limits.max_identifier_length,
92        });
93    }
94    if !is_valid_identifier(name) {
95        return Err(SchemaError::InvalidIdentifier {
96            kind,
97            name: name.to_string(),
98        });
99    }
100    Ok(name)
101}
102
/// Backtick-quote an identifier for interpolation into ClickHouse SQL
/// (defense-in-depth; callers should still run `validate_identifier` first).
///
/// Two characters are escaped so the result is always exactly one quoted token:
/// - a backtick is doubled (`` ` `` → ``` `` ```);
/// - a backslash is doubled (`\` → `\\`). ClickHouse honours backslash escapes
///   inside backtick-quoted identifiers, so an unescaped trailing `\` would
///   escape the closing backtick and let the identifier swallow the SQL after it.
///
/// Names that pass `validate_identifier` contain neither character and are only
/// wrapped in backticks.
pub fn quote_identifier(name: &str) -> String {
    // +2 for the surrounding backticks; escapes are rare, so this is usually exact.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('`');
    for c in name.chars() {
        match c {
            '\\' => quoted.push_str("\\\\"),
            '`' => quoted.push_str("``"),
            other => quoted.push(other),
        }
    }
    quoted.push('`');
    quoted
}
107
108/// Error unless `count` is within the column-count bound.
109pub fn assert_column_count(count: usize, limits: &SchemaLimits) -> Result<(), SchemaError> {
110    if count < 1 {
111        return Err(SchemaError::NoColumns);
112    }
113    if count > limits.max_columns {
114        return Err(SchemaError::TooManyColumns {
115            count,
116            max: limits.max_columns,
117        });
118    }
119    Ok(())
120}
121
122/// Error if `name` is reserved.
123pub fn assert_not_reserved(name: &str, reserved: &[&str]) -> Result<(), SchemaError> {
124    if reserved.contains(&name) {
125        return Err(SchemaError::ReservedColumn(name.to_string()));
126    }
127    Ok(())
128}
129
130// ── Type allowlist ───────────────────────────────────────────────────────────
131
/// The allowlisted scalar column types. Anything else (Decimal, FixedString,
/// Tuple, Enum, Nested, …) has no variant, so it cannot be constructed and
/// untrusted input naming it fails to deserialize.
///
/// Each variant deserializes from its own name as a bare string, except where a
/// `#[serde(rename = ...)]` gives the all-caps wire spelling (`"UUID"`, `"JSON"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum ScalarType {
    String,
    // Wire name is the all-caps "UUID"; the Rust variant follows acronym casing.
    #[serde(rename = "UUID")]
    Uuid,
    Bool,
    Date,
    DateTime,
    // Unparametrised form; `ch_type` renders it as `DateTime64(3)`.
    DateTime64,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float32,
    Float64,
    // Wire name is the all-caps "JSON".
    #[serde(rename = "JSON")]
    Json,
}
157
158impl ScalarType {
159    fn ch_type(self) -> &'static str {
160        match self {
161            ScalarType::String => "String",
162            ScalarType::Uuid => "UUID",
163            ScalarType::Bool => "Bool",
164            ScalarType::Date => "Date",
165            ScalarType::DateTime => "DateTime",
166            ScalarType::DateTime64 => "DateTime64(3)",
167            ScalarType::Int8 => "Int8",
168            ScalarType::Int16 => "Int16",
169            ScalarType::Int32 => "Int32",
170            ScalarType::Int64 => "Int64",
171            ScalarType::UInt8 => "UInt8",
172            ScalarType::UInt16 => "UInt16",
173            ScalarType::UInt32 => "UInt32",
174            ScalarType::UInt64 => "UInt64",
175            ScalarType::Float32 => "Float32",
176            ScalarType::Float64 => "Float64",
177            ScalarType::Json => "JSON",
178        }
179    }
180}
181
/// `String` is the only allowed `Array`/`Map` element type.
///
/// Modelled as a single-variant enum so that deserialization accepts exactly the
/// bare string `"String"`; any other element type (e.g. `"Int32"` or a nested
/// wrapper object) has no variant to land in and is rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum StringOnly {
    String,
}
187
/// Serde default for `DateTime64Spec::precision`: millisecond precision, the same
/// value the bare `ScalarType::DateTime64` renders with.
fn default_dt64_precision() -> u8 {
    const MILLISECOND_PRECISION: u8 = 3;
    MILLISECOND_PRECISION
}
191
/// A parametrised `DateTime64(precision[, 'timezone'])` column type.
///
/// **Safety posture:** `precision` and `timezone` may come from untrusted JSON, so
/// they are **validated before rendering** (via [`DateTime64Spec::validate`], called
/// from the table builder's per-column loop): `precision` must be `0..=9` and
/// `timezone` must match the IANA charset `^[A-Za-z0-9_+/-]{1,64}$`. The default
/// (bare `{"datetime64":{}}`) is `DateTime64(3)`, matching the legacy
/// [`ScalarType::DateTime64`] rendering.
///
/// Deserialization itself performs no range/charset checks: an out-of-range
/// precision or a hostile timezone string deserializes successfully and is only
/// caught by `validate`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct DateTime64Spec {
    /// Sub-second digits; falls back to `default_dt64_precision()` (3) when the
    /// key is absent.
    #[serde(default = "default_dt64_precision")]
    pub precision: u8,
    /// Optional timezone name; `None` when the key is absent, in which case the
    /// rendered type has no timezone argument.
    #[serde(default)]
    pub timezone: Option<String>,
}
207
208impl DateTime64Spec {
209    /// Validate the (possibly untrusted) precision + timezone before they reach SQL.
210    pub fn validate(&self) -> Result<(), SchemaError> {
211        if self.precision > 9 {
212            return Err(SchemaError::InvalidDateTime64Precision {
213                precision: self.precision,
214            });
215        }
216        if let Some(tz) = &self.timezone {
217            if !is_valid_timezone(tz) {
218                return Err(SchemaError::InvalidIdentifier {
219                    kind: "timezone",
220                    name: tz.clone(),
221                });
222            }
223        }
224        Ok(())
225    }
226}
227
/// A column type as supplied by untrusted input — the allowlisted recursive shape.
/// Mirrors the TS `ColumnTypeSpec`: a bare scalar string, or a single-key wrapper
/// object (`datetime64` / `nullable` / `lowCardinality` / `array` / `map`).
///
/// The enum is `#[serde(untagged)]`: there is no discriminator field, the JSON
/// shape alone selects the variant, and input matching none of them (an unknown
/// key, a number, a disallowed type name) fails to deserialize.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum ColumnTypeSpec {
    /// A bare allowlisted scalar name. JSON: `"String"`, `"UUID"`, `"DateTime64"`, ….
    Scalar(ScalarType),
    /// Parametrised `DateTime64(precision[, 'timezone'])`. JSON:
    /// `{"datetime64": {"precision": 3, "timezone": "UTC"}}`. The single `datetime64`
    /// key keeps the untagged match unambiguous against the other wrappers.
    DateTime64 {
        datetime64: DateTime64Spec,
    },
    /// `Nullable(inner)`. JSON: `{"nullable": <ColumnTypeSpec>}`. Boxed because the
    /// type is recursive.
    Nullable {
        nullable: Box<ColumnTypeSpec>,
    },
    /// `LowCardinality(inner)`. JSON: `{"lowCardinality": <ColumnTypeSpec>}` — the
    /// key is camelCase on the wire, hence the rename.
    LowCardinality {
        #[serde(rename = "lowCardinality")]
        low_cardinality: Box<ColumnTypeSpec>,
    },
    /// `Array(String)`. JSON: `{"array": "String"}`; no other element type parses.
    Array {
        array: StringOnly,
    },
    /// `Map(String, String)`. JSON: `{"map": ["String", "String"]}`; both key and
    /// value must be `"String"`.
    Map {
        map: (StringOnly, StringOnly),
    },
}
255
256impl ColumnTypeSpec {
257    /// The ClickHouse type string for this spec.
258    ///
259    /// For [`ColumnTypeSpec::DateTime64`] this trusts the spec to be valid; untrusted
260    /// precision/timezone must be checked first via [`ColumnTypeSpec::validate`] (the
261    /// table builder does this in its per-column loop).
262    pub fn to_ch_type(&self) -> String {
263        match self {
264            ColumnTypeSpec::Scalar(s) => s.ch_type().to_string(),
265            ColumnTypeSpec::DateTime64 { datetime64 } => match &datetime64.timezone {
266                Some(tz) => format!("DateTime64({}, '{}')", datetime64.precision, tz),
267                None => format!("DateTime64({})", datetime64.precision),
268            },
269            ColumnTypeSpec::Nullable { nullable } => format!("Nullable({})", nullable.to_ch_type()),
270            ColumnTypeSpec::LowCardinality { low_cardinality } => {
271                format!("LowCardinality({})", low_cardinality.to_ch_type())
272            }
273            ColumnTypeSpec::Array { .. } => "Array(String)".to_string(),
274            ColumnTypeSpec::Map { .. } => "Map(String, String)".to_string(),
275        }
276    }
277
278    /// Whether a `DateTime64` is at the core (so a TTL move expression must wrap it
279    /// in `toDateTime(...)`). Propagates through `Nullable`/`LowCardinality` and covers
280    /// both the bare [`ScalarType::DateTime64`] and the parametrised
281    /// [`ColumnTypeSpec::DateTime64`] variant.
282    pub fn is_datetime64(&self) -> bool {
283        match self {
284            ColumnTypeSpec::Scalar(ScalarType::DateTime64) => true,
285            ColumnTypeSpec::DateTime64 { .. } => true,
286            ColumnTypeSpec::Nullable { nullable } => nullable.is_datetime64(),
287            ColumnTypeSpec::LowCardinality { low_cardinality } => low_cardinality.is_datetime64(),
288            _ => false,
289        }
290    }
291
292    /// Validate any embedded untrusted parameters (currently the parametrised
293    /// `DateTime64` precision + timezone) before this type is rendered to SQL.
294    /// Recurses through `Nullable`/`LowCardinality`. Identifier-shaped scalars/arrays/
295    /// maps have nothing to validate here.
296    pub fn validate(&self) -> Result<(), SchemaError> {
297        match self {
298            ColumnTypeSpec::DateTime64 { datetime64 } => datetime64.validate(),
299            ColumnTypeSpec::Nullable { nullable } => nullable.validate(),
300            ColumnTypeSpec::LowCardinality { low_cardinality } => low_cardinality.validate(),
301            _ => Ok(()),
302        }
303    }
304}
305
// Unit tests for the safety primitives: identifier validation, quoting, size and
// reserved-name bounds, and (de)serialization + rendering of `ColumnTypeSpec`.
#[cfg(test)]
mod tests {
    use super::*;

    // Default limits, used by every identifier/column-count test below.
    fn limits() -> SchemaLimits {
        SchemaLimits::default()
    }

    // Well-formed names pass and are returned unchanged.
    #[test]
    fn accepts_safe_identifiers() {
        for ok in ["a", "A", "_x", "org_id", "col1", "X_2_y"] {
            assert_eq!(validate_identifier(ok, "column", &limits()).unwrap(), ok);
        }
    }

    // SQL metacharacters, whitespace, a leading digit, the empty string and
    // non-ASCII letters are all rejected.
    #[test]
    fn rejects_injection_and_metacharacters() {
        let attacks = [
            "a; DROP TABLE x",
            "a`,`b",
            "a) ENGINE=Memory AS SELECT * FROM secrets --",
            "a' OR '1'='1",
            "a b",
            "a.b",
            "a-b",
            "1col",
            "",
            "a\"b",
            "a\nb",
            "таблица",
            "a/*x*/",
        ];
        for bad in attacks {
            assert!(
                validate_identifier(bad, "column", &limits()).is_err(),
                "should reject {bad:?}"
            );
        }
    }

    // The length bound is inclusive: exactly `max` passes, `max + 1` fails.
    #[test]
    fn enforces_length_bound() {
        let lim = limits();
        let too_long = "a".repeat(lim.max_identifier_length + 1);
        assert!(validate_identifier(&too_long, "column", &lim).is_err());
        let ok = "a".repeat(lim.max_identifier_length);
        assert!(validate_identifier(&ok, "column", &lim).is_ok());
    }

    // Quoting wraps in backticks and doubles any embedded backtick.
    #[test]
    fn quotes_and_escapes() {
        assert_eq!(quote_identifier("org_id"), "`org_id`");
        assert_eq!(quote_identifier("a`b"), "`a``b`");
    }

    // Column-count bounds (zero and over-limit fail) and the default reserved names.
    #[test]
    fn bounds_and_reserved() {
        assert!(assert_column_count(0, &limits()).is_err());
        assert!(assert_column_count(limits().max_columns + 1, &limits()).is_err());
        assert!(assert_column_count(10, &limits()).is_ok());
        assert!(assert_not_reserved("attrs", DEFAULT_RESERVED_COLUMNS).is_err());
        assert!(assert_not_reserved("raw", DEFAULT_RESERVED_COLUMNS).is_err());
        assert!(assert_not_reserved("user_col", DEFAULT_RESERVED_COLUMNS).is_ok());
    }

    // Every allowlisted JSON shape deserializes and renders to the expected type.
    #[test]
    fn allowlist_builds_allowed_types() {
        let s: ColumnTypeSpec = serde_json::from_str("\"DateTime64\"").unwrap();
        assert_eq!(s.to_ch_type(), "DateTime64(3)");
        assert!(s.is_datetime64());

        let n: ColumnTypeSpec = serde_json::from_str(r#"{"nullable":"String"}"#).unwrap();
        assert_eq!(n.to_ch_type(), "Nullable(String)");

        let lc: ColumnTypeSpec =
            serde_json::from_str(r#"{"lowCardinality":{"nullable":"String"}}"#).unwrap();
        assert_eq!(lc.to_ch_type(), "LowCardinality(Nullable(String))");
        let lcd: ColumnTypeSpec =
            serde_json::from_str(r#"{"lowCardinality":"DateTime64"}"#).unwrap();
        assert!(lcd.is_datetime64());

        let a: ColumnTypeSpec = serde_json::from_str(r#"{"array":"String"}"#).unwrap();
        assert_eq!(a.to_ch_type(), "Array(String)");
        let m: ColumnTypeSpec = serde_json::from_str(r#"{"map":["String","String"]}"#).unwrap();
        assert_eq!(m.to_ch_type(), "Map(String, String)");
    }

    // Types outside the allowlist have no representation, so parsing itself fails.
    #[test]
    fn allowlist_rejects_disallowed_types() {
        let bad = [
            "\"Decimal(38, 10)\"",
            "\"FixedString(16)\"",
            "\"Enum8\"",
            "\"Tuple\"",
            "\"Nested\"",
            r#"{"map":["String","Int32"]}"#,
            r#"{"array":"Int32"}"#,
            r#"{"array":{"nullable":"String"}}"#,
            r#"{"wat":"String"}"#,
            "42",
        ];
        for b in bad {
            assert!(
                serde_json::from_str::<ColumnTypeSpec>(b).is_err(),
                "should reject {b}"
            );
        }
    }

    // The `{"datetime64": {...}}` wrapper: rendering, defaults and validation of
    // well-formed parameters.
    #[test]
    fn parametrised_datetime64_renders_and_validates() {
        // Full precision + timezone.
        let utc: ColumnTypeSpec =
            serde_json::from_str(r#"{"datetime64":{"precision":3,"timezone":"UTC"}}"#).unwrap();
        assert_eq!(utc.to_ch_type(), "DateTime64(3, 'UTC')");
        assert!(utc.is_datetime64());
        assert!(utc.validate().is_ok());

        // Precision only, no timezone.
        let p6: ColumnTypeSpec = serde_json::from_str(r#"{"datetime64":{"precision":6}}"#).unwrap();
        assert_eq!(p6.to_ch_type(), "DateTime64(6)");
        assert!(p6.validate().is_ok());

        // Empty object → defaults to DateTime64(3), matching the legacy scalar.
        let def: ColumnTypeSpec = serde_json::from_str(r#"{"datetime64":{}}"#).unwrap();
        assert_eq!(def.to_ch_type(), "DateTime64(3)");
        assert!(def.is_datetime64());
        assert!(def.validate().is_ok());

        // The bare string still deserializes to the legacy scalar variant.
        let bare: ColumnTypeSpec = serde_json::from_str("\"DateTime64\"").unwrap();
        assert!(matches!(
            bare,
            ColumnTypeSpec::Scalar(ScalarType::DateTime64)
        ));

        // A real IANA name containing a slash and an underscore is accepted.
        let tz: ColumnTypeSpec =
            serde_json::from_str(r#"{"datetime64":{"precision":9,"timezone":"America/New_York"}}"#)
                .unwrap();
        assert_eq!(tz.to_ch_type(), "DateTime64(9, 'America/New_York')");
        assert!(tz.validate().is_ok());
    }

    // Hostile or out-of-range parameters deserialize but are caught by `validate`.
    #[test]
    fn parametrised_datetime64_rejects_bad_params() {
        // Injection attempt in the timezone string.
        let bad_tz: ColumnTypeSpec =
            serde_json::from_str(r#"{"datetime64":{"precision":3,"timezone":"UTC'; DROP"}}"#)
                .unwrap();
        assert!(matches!(
            bad_tz.validate(),
            Err(SchemaError::InvalidIdentifier {
                kind: "timezone",
                ..
            })
        ));

        // Out-of-range precision.
        let bad_p: ColumnTypeSpec =
            serde_json::from_str(r#"{"datetime64":{"precision":12}}"#).unwrap();
        assert!(matches!(
            bad_p.validate(),
            Err(SchemaError::InvalidDateTime64Precision { precision: 12 })
        ));
    }

    // `is_datetime64`, rendering and validation all see through a `Nullable` wrapper.
    #[test]
    fn parametrised_datetime64_is_datetime64_through_nullable() {
        let n: ColumnTypeSpec =
            serde_json::from_str(r#"{"nullable":{"datetime64":{"precision":3,"timezone":"UTC"}}}"#)
                .unwrap();
        assert!(n.is_datetime64());
        assert_eq!(n.to_ch_type(), "Nullable(DateTime64(3, 'UTC'))");
        assert!(n.validate().is_ok());

        // Validation propagates through the wrapper too.
        let bad: ColumnTypeSpec =
            serde_json::from_str(r#"{"nullable":{"datetime64":{"precision":12}}}"#).unwrap();
        assert!(bad.validate().is_err());
    }
}