1use std::fmt;
2
3use crate::distance::DistanceMetric;
4use crate::index::HnswParams;
5use crate::types::VectorType;
6
/// Error returned when a vector table declaration cannot be turned into a
/// valid configuration.
///
/// The wrapped string is the human-readable description of what was wrong.
#[derive(Debug)]
pub struct ConfigError(pub String);

impl fmt::Display for ConfigError {
    /// Renders the error as `config error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        write!(f, "config error: {message}")
    }
}

// No underlying cause is stored, so the trait's default methods are sufficient.
impl std::error::Error for ConfigError {}
17
/// Configuration of a vector table, as produced by [`VectorTableConfig::parse`].
#[derive(Debug, Clone)]
pub struct VectorTableConfig {
    /// Database name, copied from the second declaration argument.
    pub db_name: String,
    /// Table name, copied from the third declaration argument.
    pub table_name: String,
    /// Vector dimension from the required `dim=` parameter; always positive.
    pub dim: usize,
    /// Vector element type from `type=`; `VectorType::Float4` when omitted.
    pub vtype: VectorType,
    /// Distance metric from `metric=`; `DistanceMetric::L2` when omitted.
    pub metric: DistanceMetric,
    /// HNSW settings; starts from `HnswParams::default()` and is overridden by
    /// the `m=`, `ef_construction=` and `ef_search=` parameters.
    pub hnsw_params: HnswParams,
    /// Extra `(column name, SQL type)` pairs from `metadata=`, in declaration
    /// order. `vtab_schema` emits them between `vector` and `distance`.
    pub metadata_columns: Vec<(String, String)>,
}
29
30impl VectorTableConfig {
31 pub fn parse(args: &[&str]) -> Result<Self, ConfigError> {
32 if args.len() < 3 {
33 return Err(ConfigError(
34 "expected at least module, db, and table name".into(),
35 ));
36 }
37
38 let db_name = args[1].to_string();
39 let table_name = args[2].to_string();
40
41 let mut dim: Option<usize> = None;
42 let mut vtype = VectorType::Float4;
43 let mut metric = DistanceMetric::L2;
44 let mut hnsw_params = HnswParams::default();
45 let mut metadata_columns = Vec::new();
46
47 for &arg in &args[3..] {
48 let (key, value) = arg
49 .split_once('=')
50 .ok_or_else(|| ConfigError(format!("invalid argument: {arg}")))?;
51 let key = key.trim();
52 let value = value.trim().trim_matches('"');
53
54 match key {
55 "dim" => {
56 let d: i64 = value
57 .parse()
58 .map_err(|_| ConfigError(format!("invalid dim: {value}")))?;
59 if d <= 0 {
60 return Err(ConfigError(format!("dim must be positive, got {d}")));
61 }
62 dim = Some(d as usize);
63 }
64 "type" => {
65 vtype = VectorType::from_name(value).map_err(|e| ConfigError(e.to_string()))?;
66 }
67 "metric" => {
68 metric =
69 DistanceMetric::from_name(value).map_err(|e| ConfigError(e.to_string()))?;
70 }
71 "m" => {
72 hnsw_params.m = value
73 .parse()
74 .map_err(|_| ConfigError(format!("invalid m: {value}")))?;
75 }
76 "ef_construction" => {
77 hnsw_params.ef_construction = value
78 .parse()
79 .map_err(|_| ConfigError(format!("invalid ef_construction: {value}")))?;
80 }
81 "ef_search" => {
82 hnsw_params.ef_search = value
83 .parse()
84 .map_err(|_| ConfigError(format!("invalid ef_search: {value}")))?;
85 }
86 "metadata" => {
87 metadata_columns = parse_metadata_columns(value)?;
88 }
89 other => {
90 return Err(ConfigError(format!("unknown parameter: {other}")));
91 }
92 }
93 }
94
95 let dim = dim.ok_or_else(|| ConfigError("dim is required".into()))?;
96
97 Ok(Self {
98 db_name,
99 table_name,
100 dim,
101 vtype,
102 metric,
103 hnsw_params,
104 metadata_columns,
105 })
106 }
107
108 pub fn vtab_schema(&self) -> String {
109 let mut cols = vec![
110 "id INTEGER PRIMARY KEY".to_string(),
111 "vector BLOB".to_string(),
112 ];
113 for (name, sql_type) in &self.metadata_columns {
114 cols.push(format!("{name} {sql_type}"));
115 }
116 cols.push("distance REAL HIDDEN".to_string());
117 format!("CREATE TABLE x({})", cols.join(", "))
118 }
119}
120
121fn parse_metadata_columns(spec: &str) -> Result<Vec<(String, String)>, ConfigError> {
122 let mut columns = Vec::new();
123 for part in spec.split(',') {
124 let part = part.trim();
125 if part.is_empty() {
126 continue;
127 }
128 let mut tokens = part.split_whitespace();
129 let name = tokens
130 .next()
131 .ok_or_else(|| ConfigError("empty metadata column definition".to_string()))?
132 .to_string();
133 let sql_type = tokens
134 .next()
135 .ok_or_else(|| ConfigError(format!("missing type for metadata column {name}")))?
136 .to_string();
137 columns.push((name, sql_type));
138 }
139 Ok(columns)
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::types::VectorType;

    /// Parses `args` and returns the configuration, panicking if it is rejected.
    #[track_caller]
    fn parsed(args: &[&str]) -> VectorTableConfig {
        VectorTableConfig::parse(args).unwrap()
    }

    /// Parses `args` and returns the message of the error that must result.
    #[track_caller]
    fn rejection(args: &[&str]) -> String {
        VectorTableConfig::parse(args).unwrap_err().0
    }

    // ---- successful parsing -------------------------------------------------

    /// Only `dim` is supplied, so every other setting keeps its default.
    #[test]
    fn parse_minimal_args_uses_defaults() {
        let config = parsed(&["vector", "main", "test", "dim=3"]);
        assert_eq!(config.db_name, "main");
        assert_eq!(config.table_name, "test");
        assert_eq!(config.dim, 3);
        assert_eq!(config.vtype, VectorType::Float4);
        assert_eq!(config.metric, DistanceMetric::L2);

        let hnsw = &config.hnsw_params;
        assert_eq!(hnsw.m, 16);
        assert_eq!(hnsw.ef_construction, 200);
        assert_eq!(hnsw.ef_search, 64);

        assert!(config.metadata_columns.is_empty());
    }

    /// Every scalar parameter is given explicitly and must be reflected.
    #[test]
    fn parse_all_params_specified() {
        let config = parsed(&[
            "vector",
            "main",
            "embeddings",
            "dim=128",
            "type=float4",
            "metric=l2",
            "m=32",
            "ef_construction=400",
            "ef_search=128",
        ]);
        assert_eq!(config.dim, 128);
        assert_eq!(config.vtype, VectorType::Float4);
        assert_eq!(config.metric, DistanceMetric::L2);

        let hnsw = &config.hnsw_params;
        assert_eq!(hnsw.m, 32);
        assert_eq!(hnsw.ef_construction, 400);
        assert_eq!(hnsw.ef_search, 128);
    }

    /// Non-default element type and metric names are resolved.
    #[test]
    fn parse_float8_cosine() {
        let config = parsed(&[
            "vector",
            "main",
            "vecs",
            "dim=64",
            "type=float8",
            "metric=cosine",
        ]);
        assert_eq!(config.vtype, VectorType::Float8);
        assert_eq!(config.metric, DistanceMetric::Cosine);
    }

    /// The three HNSW settings override their defaults.
    #[test]
    fn parse_custom_hnsw_params() {
        let config = parsed(&[
            "vector",
            "main",
            "idx",
            "dim=16",
            "m=8",
            "ef_construction=100",
            "ef_search=50",
        ]);
        let hnsw = &config.hnsw_params;
        assert_eq!(hnsw.m, 8);
        assert_eq!(hnsw.ef_construction, 100);
        assert_eq!(hnsw.ef_search, 50);
    }

    /// A two-column metadata specification yields two `(name, type)` pairs in order.
    #[test]
    fn parse_metadata_columns_parsed_correctly() {
        let config = parsed(&[
            "vector",
            "main",
            "docs",
            "dim=4",
            "metadata=label TEXT,score REAL",
        ]);
        let expected = vec![
            ("label".to_string(), "TEXT".to_string()),
            ("score".to_string(), "REAL".to_string()),
        ];
        assert_eq!(config.metadata_columns, expected);
    }

    // ---- rejected input -----------------------------------------------------

    /// An empty argument list is rejected.
    #[test]
    fn parse_error_too_few_args_zero() {
        let message = rejection(&[]);
        assert!(
            message.contains("at least"),
            "unexpected error message: {}",
            message
        );
    }

    /// Module and database name without a table name are rejected.
    #[test]
    fn parse_error_too_few_args_two() {
        let message = rejection(&["vector", "main"]);
        assert!(
            message.contains("at least"),
            "unexpected error message: {}",
            message
        );
    }

    /// Omitting `dim` is an error even when other parameters are valid.
    #[test]
    fn parse_error_missing_dim() {
        let message = rejection(&["vector", "main", "tbl", "type=float4"]);
        assert!(
            message.contains("dim"),
            "expected error mentioning 'dim', got: {}",
            message
        );
    }

    /// A zero dimension is rejected.
    #[test]
    fn parse_error_dim_zero() {
        let message = rejection(&["vector", "main", "tbl", "dim=0"]);
        assert!(
            ["positive", "dim"]
                .iter()
                .any(|needle| message.contains(*needle)),
            "unexpected error: {}",
            message
        );
    }

    /// A negative dimension is rejected.
    #[test]
    fn parse_error_dim_negative() {
        let message = rejection(&["vector", "main", "tbl", "dim=-5"]);
        assert!(
            ["positive", "dim"]
                .iter()
                .any(|needle| message.contains(*needle)),
            "unexpected error: {}",
            message
        );
    }

    /// A dimension that is not a number is rejected.
    #[test]
    fn parse_error_dim_non_numeric() {
        let message = rejection(&["vector", "main", "tbl", "dim=abc"]);
        assert!(
            message.contains("dim"),
            "expected error mentioning 'dim', got: {}",
            message
        );
    }

    /// Unrecognised keys are reported by name.
    #[test]
    fn parse_error_unknown_parameter() {
        let message = rejection(&["vector", "main", "tbl", "dim=4", "foo=bar"]);
        assert!(
            ["unknown", "foo"]
                .iter()
                .all(|needle| message.contains(*needle)),
            "unexpected error: {}",
            message
        );
    }

    /// An argument without `=` is rejected.
    #[test]
    fn parse_error_arg_without_equals() {
        let message = rejection(&["vector", "main", "tbl", "dim=4", "invalidarg"]);
        assert!(
            ["invalid argument", "invalidarg"]
                .iter()
                .any(|needle| message.contains(*needle)),
            "unexpected error: {}",
            message
        );
    }

    // ---- schema generation --------------------------------------------------

    /// Without metadata the schema holds only the three fixed columns.
    #[test]
    fn vtab_schema_no_metadata() {
        let schema = parsed(&["vector", "main", "tbl", "dim=3"]).vtab_schema();
        assert_eq!(
            schema,
            "CREATE TABLE x(id INTEGER PRIMARY KEY, vector BLOB, distance REAL HIDDEN)"
        );
    }

    /// Metadata columns are placed after `vector` and before the hidden `distance` column.
    #[test]
    fn vtab_schema_with_metadata_columns_before_distance() {
        let schema = parsed(&[
            "vector",
            "main",
            "tbl",
            "dim=4",
            "metadata=label TEXT,score REAL",
        ])
        .vtab_schema();
        assert_eq!(
            schema,
            "CREATE TABLE x(id INTEGER PRIMARY KEY, vector BLOB, label TEXT, score REAL, distance REAL HIDDEN)"
        );

        let position_of = |column: &str| schema.find(column).unwrap();
        assert!(
            position_of("label") < position_of("distance"),
            "metadata columns must precede distance in schema"
        );
    }
}