Skip to main content

sqlite_vector_rs/vtab/
config.rs

1use std::fmt;
2
3use crate::distance::DistanceMetric;
4use crate::index::HnswParams;
5use crate::types::VectorType;
6
/// Error returned when the arguments of a `CREATE VIRTUAL TABLE` statement
/// cannot be turned into a [`VectorTableConfig`].
///
/// The wrapped string is the bare reason; the `Display` form prefixes it with
/// `config error: `.
#[derive(Debug)]
pub struct ConfigError(pub String);

impl std::error::Error for ConfigError {}

impl fmt::Display for ConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConfigError(reason) = self;
        f.write_str("config error: ")?;
        f.write_str(reason)
    }
}
17
18/// Parsed configuration from CREATE VIRTUAL TABLE arguments.
19#[derive(Debug, Clone)]
20pub struct VectorTableConfig {
21    pub db_name: String,
22    pub table_name: String,
23    pub dim: usize,
24    pub vtype: VectorType,
25    pub metric: DistanceMetric,
26    pub hnsw_params: HnswParams,
27    pub metadata_columns: Vec<(String, String)>,
28}
29
30impl VectorTableConfig {
31    pub fn parse(args: &[&str]) -> Result<Self, ConfigError> {
32        if args.len() < 3 {
33            return Err(ConfigError(
34                "expected at least module, db, and table name".into(),
35            ));
36        }
37
38        let db_name = args[1].to_string();
39        let table_name = args[2].to_string();
40
41        let mut dim: Option<usize> = None;
42        let mut vtype = VectorType::Float4;
43        let mut metric = DistanceMetric::L2;
44        let mut hnsw_params = HnswParams::default();
45        let mut metadata_columns = Vec::new();
46
47        for &arg in &args[3..] {
48            let (key, value) = arg
49                .split_once('=')
50                .ok_or_else(|| ConfigError(format!("invalid argument: {arg}")))?;
51            let key = key.trim();
52            let value = value.trim().trim_matches('"');
53
54            match key {
55                "dim" => {
56                    let d: i64 = value
57                        .parse()
58                        .map_err(|_| ConfigError(format!("invalid dim: {value}")))?;
59                    if d <= 0 {
60                        return Err(ConfigError(format!("dim must be positive, got {d}")));
61                    }
62                    dim = Some(d as usize);
63                }
64                "type" => {
65                    vtype = VectorType::from_name(value).map_err(|e| ConfigError(e.to_string()))?;
66                }
67                "metric" => {
68                    metric =
69                        DistanceMetric::from_name(value).map_err(|e| ConfigError(e.to_string()))?;
70                }
71                "m" => {
72                    hnsw_params.m = value
73                        .parse()
74                        .map_err(|_| ConfigError(format!("invalid m: {value}")))?;
75                }
76                "ef_construction" => {
77                    hnsw_params.ef_construction = value
78                        .parse()
79                        .map_err(|_| ConfigError(format!("invalid ef_construction: {value}")))?;
80                }
81                "ef_search" => {
82                    hnsw_params.ef_search = value
83                        .parse()
84                        .map_err(|_| ConfigError(format!("invalid ef_search: {value}")))?;
85                }
86                "metadata" => {
87                    metadata_columns = parse_metadata_columns(value)?;
88                }
89                other => {
90                    return Err(ConfigError(format!("unknown parameter: {other}")));
91                }
92            }
93        }
94
95        let dim = dim.ok_or_else(|| ConfigError("dim is required".into()))?;
96
97        Ok(Self {
98            db_name,
99            table_name,
100            dim,
101            vtype,
102            metric,
103            hnsw_params,
104            metadata_columns,
105        })
106    }
107
108    pub fn vtab_schema(&self) -> String {
109        let mut cols = vec![
110            "id INTEGER PRIMARY KEY".to_string(),
111            "vector BLOB".to_string(),
112        ];
113        for (name, sql_type) in &self.metadata_columns {
114            cols.push(format!("{name} {sql_type}"));
115        }
116        cols.push("distance REAL HIDDEN".to_string());
117        format!("CREATE TABLE x({})", cols.join(", "))
118    }
119}
120
121fn parse_metadata_columns(spec: &str) -> Result<Vec<(String, String)>, ConfigError> {
122    let mut columns = Vec::new();
123    for part in spec.split(',') {
124        let part = part.trim();
125        if part.is_empty() {
126            continue;
127        }
128        let mut tokens = part.split_whitespace();
129        let name = tokens
130            .next()
131            .ok_or_else(|| ConfigError("empty metadata column definition".to_string()))?
132            .to_string();
133        let sql_type = tokens
134            .next()
135            .ok_or_else(|| ConfigError(format!("missing type for metadata column {name}")))?
136            .to_string();
137        columns.push((name, sql_type));
138    }
139    Ok(columns)
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::types::VectorType;

    // ----------------------------------------------------------------
    // parse — minimal args (dim only, defaults for type/metric)
    // ----------------------------------------------------------------

    #[test]
    fn parse_minimal_args_uses_defaults() {
        let cfg = VectorTableConfig::parse(&["vector", "main", "test", "dim=3"]).unwrap();
        assert_eq!(cfg.db_name, "main");
        assert_eq!(cfg.table_name, "test");
        assert_eq!(cfg.dim, 3);
        assert_eq!(cfg.vtype, VectorType::Float4);
        assert_eq!(cfg.metric, DistanceMetric::L2);
        // HNSW defaults
        assert_eq!(cfg.hnsw_params.m, 16);
        assert_eq!(cfg.hnsw_params.ef_construction, 200);
        assert_eq!(cfg.hnsw_params.ef_search, 64);
        assert!(cfg.metadata_columns.is_empty());
    }

    // ----------------------------------------------------------------
    // parse — all parameters specified
    // ----------------------------------------------------------------

    #[test]
    fn parse_all_params_specified() {
        // Non-default type/metric, so the test fails if either key were ignored.
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "embeddings",
            "dim=128",
            "type=float8",
            "metric=cosine",
            "m=32",
            "ef_construction=400",
            "ef_search=128",
            "metadata=label TEXT",
        ])
        .unwrap();
        assert_eq!(cfg.db_name, "main");
        assert_eq!(cfg.table_name, "embeddings");
        assert_eq!(cfg.dim, 128);
        assert_eq!(cfg.vtype, VectorType::Float8);
        assert_eq!(cfg.metric, DistanceMetric::Cosine);
        assert_eq!(cfg.hnsw_params.m, 32);
        assert_eq!(cfg.hnsw_params.ef_construction, 400);
        assert_eq!(cfg.hnsw_params.ef_search, 128);
        assert_eq!(
            cfg.metadata_columns,
            vec![("label".to_string(), "TEXT".to_string())]
        );
    }

    // ----------------------------------------------------------------
    // parse — type=float8, metric=cosine
    // ----------------------------------------------------------------

    #[test]
    fn parse_float8_cosine() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "vecs",
            "dim=64",
            "type=float8",
            "metric=cosine",
        ])
        .unwrap();
        assert_eq!(cfg.vtype, VectorType::Float8);
        assert_eq!(cfg.metric, DistanceMetric::Cosine);
    }

    // ----------------------------------------------------------------
    // parse — custom HNSW params (m, ef_construction, ef_search)
    // ----------------------------------------------------------------

    #[test]
    fn parse_custom_hnsw_params() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "idx",
            "dim=16",
            "m=8",
            "ef_construction=100",
            "ef_search=50",
        ])
        .unwrap();
        assert_eq!(cfg.hnsw_params.m, 8);
        assert_eq!(cfg.hnsw_params.ef_construction, 100);
        assert_eq!(cfg.hnsw_params.ef_search, 50);
    }

    // ----------------------------------------------------------------
    // parse — metadata columns
    // ----------------------------------------------------------------

    #[test]
    fn parse_metadata_columns_parsed_correctly() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "docs",
            "dim=4",
            "metadata=label TEXT,score REAL",
        ])
        .unwrap();
        assert_eq!(cfg.metadata_columns.len(), 2);
        assert_eq!(
            cfg.metadata_columns[0],
            ("label".to_string(), "TEXT".to_string())
        );
        assert_eq!(
            cfg.metadata_columns[1],
            ("score".to_string(), "REAL".to_string())
        );
    }

    #[test]
    fn parse_metadata_skips_empty_entries() {
        let cfg =
            VectorTableConfig::parse(&["vector", "main", "docs", "dim=4", "metadata=label TEXT,"])
                .unwrap();
        assert_eq!(
            cfg.metadata_columns,
            vec![("label".to_string(), "TEXT".to_string())]
        );
    }

    // ----------------------------------------------------------------
    // parse — whitespace and double-quote handling
    // ----------------------------------------------------------------

    #[test]
    fn parse_strips_whitespace_and_double_quotes() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "docs",
            " dim = \"8\" ",
            "metadata=\"label TEXT,score REAL\"",
        ])
        .unwrap();
        assert_eq!(cfg.dim, 8);
        assert_eq!(cfg.metadata_columns.len(), 2);
        assert_eq!(
            cfg.metadata_columns[1],
            ("score".to_string(), "REAL".to_string())
        );
    }

    // ----------------------------------------------------------------
    // parse errors — too few args (< 3)
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_too_few_args_zero() {
        let err = VectorTableConfig::parse(&[]).unwrap_err();
        assert!(
            err.0.contains("at least"),
            "unexpected error message: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_too_few_args_two() {
        // Only module + db name; table name is absent.
        let err = VectorTableConfig::parse(&["vector", "main"]).unwrap_err();
        assert!(
            err.0.contains("at least"),
            "unexpected error message: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — missing dim
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_missing_dim() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "type=float4"]).unwrap_err();
        assert!(
            err.0.contains("dim") && err.0.contains("required"),
            "expected a 'dim is required' error, got: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — invalid dim values
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_dim_zero() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=0"]).unwrap_err();
        // Must come from the positivity check, not from the numeric parse.
        assert!(
            err.0.contains("positive"),
            "unexpected error: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_dim_negative() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=-5"]).unwrap_err();
        assert!(
            err.0.contains("positive") && err.0.contains("-5"),
            "unexpected error: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_dim_non_numeric() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=abc"]).unwrap_err();
        assert!(
            err.0.contains("invalid dim") && err.0.contains("abc"),
            "expected an 'invalid dim' error, got: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — invalid HNSW / metadata values
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_invalid_hnsw_param() {
        let err = VectorTableConfig::parse(&["vector", "main", "idx", "dim=4", "m=abc"]).unwrap_err();
        assert!(
            err.0.contains("invalid m") && err.0.contains("abc"),
            "unexpected error: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_metadata_missing_type() {
        let err = VectorTableConfig::parse(&["vector", "main", "docs", "dim=4", "metadata=label"])
            .unwrap_err();
        assert!(
            err.0.contains("missing type") && err.0.contains("label"),
            "unexpected error: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — unknown parameter
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_unknown_parameter() {
        let err =
            VectorTableConfig::parse(&["vector", "main", "tbl", "dim=4", "foo=bar"]).unwrap_err();
        assert!(
            err.0.contains("unknown") && err.0.contains("foo"),
            "unexpected error: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — arg without '='
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_arg_without_equals() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=4", "invalidarg"])
            .unwrap_err();
        assert!(
            err.0.contains("invalid argument") && err.0.contains("invalidarg"),
            "unexpected error: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // vtab_schema — no metadata
    // ----------------------------------------------------------------

    #[test]
    fn vtab_schema_no_metadata() {
        let cfg = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=3"]).unwrap();
        let schema = cfg.vtab_schema();
        assert_eq!(
            schema,
            "CREATE TABLE x(id INTEGER PRIMARY KEY, vector BLOB, distance REAL HIDDEN)"
        );
    }

    // ----------------------------------------------------------------
    // vtab_schema — with metadata columns
    // ----------------------------------------------------------------

    #[test]
    fn vtab_schema_with_metadata_columns_before_distance() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "tbl",
            "dim=4",
            "metadata=label TEXT,score REAL",
        ])
        .unwrap();
        let schema = cfg.vtab_schema();
        assert_eq!(
            schema,
            "CREATE TABLE x(id INTEGER PRIMARY KEY, vector BLOB, label TEXT, score REAL, distance REAL HIDDEN)"
        );
        // Verify ordering: metadata must appear before distance.
        let label_pos = schema.find("label").unwrap();
        let distance_pos = schema.find("distance").unwrap();
        assert!(
            label_pos < distance_pos,
            "metadata columns must precede distance in schema"
        );
    }
}