Skip to main content

sqlrite_sdk_core/
lib.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use thiserror::Error;
4
/// Value used for `QueryRequest::top_k` when the request leaves it unset.
pub const DEFAULT_TOP_K: usize = 5;

/// Value used for `QueryRequest::alpha` when the request leaves it unset.
/// `QueryRequest::validate` requires the effective alpha to lie in `0.0..=1.0`.
pub const DEFAULT_ALPHA: f32 = 0.65_f32;

/// Value used for `QueryRequest::candidate_limit` when the request leaves it unset.
/// `QueryRequest::validate` requires the effective limit to be at least the effective `top_k`.
pub const DEFAULT_CANDIDATE_LIMIT: usize = 1_000;
8
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
10#[serde(rename_all = "snake_case")]
11pub enum QueryProfile {
12    Latency,
13    #[default]
14    Balanced,
15    Recall,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
19pub struct SqlRequest {
20    pub statement: String,
21}
22
23impl SqlRequest {
24    pub fn validate(&self) -> Result<(), ValidationError> {
25        if self.statement.trim().is_empty() {
26            return Err(ValidationError::EmptyStatement);
27        }
28        Ok(())
29    }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
33pub struct QueryRequest {
34    pub query_text: Option<String>,
35    pub query_embedding: Option<Vec<f32>>,
36    pub top_k: Option<usize>,
37    pub alpha: Option<f32>,
38    pub candidate_limit: Option<usize>,
39    pub include_payloads: Option<bool>,
40    pub query_profile: Option<String>,
41    pub metadata_filters: Option<HashMap<String, String>>,
42    pub doc_id: Option<String>,
43}
44
45impl QueryRequest {
46    pub fn top_k_or_default(&self) -> usize {
47        self.top_k.unwrap_or(DEFAULT_TOP_K)
48    }
49
50    pub fn alpha_or_default(&self) -> f32 {
51        self.alpha.unwrap_or(DEFAULT_ALPHA)
52    }
53
54    pub fn candidate_limit_or_default(&self) -> usize {
55        self.candidate_limit.unwrap_or(DEFAULT_CANDIDATE_LIMIT)
56    }
57
58    pub fn include_payloads_or_default(&self) -> bool {
59        self.include_payloads.unwrap_or(true)
60    }
61
62    pub fn query_profile_or_default(&self) -> Result<QueryProfile, ValidationError> {
63        match self
64            .query_profile
65            .as_deref()
66            .map(str::trim)
67            .filter(|value| !value.is_empty())
68        {
69            None => Ok(QueryProfile::Balanced),
70            Some("balanced") => Ok(QueryProfile::Balanced),
71            Some("latency") => Ok(QueryProfile::Latency),
72            Some("recall") => Ok(QueryProfile::Recall),
73            Some(other) => Err(ValidationError::InvalidQueryProfile(other.to_string())),
74        }
75    }
76
77    pub fn normalized_query_text(&self) -> Option<String> {
78        self.query_text
79            .as_ref()
80            .map(|value| value.trim())
81            .filter(|value| !value.is_empty())
82            .map(str::to_string)
83    }
84
85    pub fn normalized_doc_id(&self) -> Option<String> {
86        self.doc_id
87            .as_ref()
88            .map(|value| value.trim())
89            .filter(|value| !value.is_empty())
90            .map(str::to_string)
91    }
92
93    pub fn normalized_query_embedding(&self) -> Option<Vec<f32>> {
94        self.query_embedding
95            .as_ref()
96            .filter(|value| !value.is_empty())
97            .cloned()
98    }
99
100    pub fn normalized_metadata_filters(&self) -> HashMap<String, String> {
101        self.metadata_filters.clone().unwrap_or_default()
102    }
103
104    pub fn validate(&self) -> Result<(), ValidationError> {
105        let has_query_text = self
106            .query_text
107            .as_ref()
108            .map(|value| !value.trim().is_empty())
109            .unwrap_or(false);
110        let has_query_embedding = self
111            .query_embedding
112            .as_ref()
113            .map(|value| !value.is_empty())
114            .unwrap_or(false);
115
116        if !has_query_text && !has_query_embedding {
117            return Err(ValidationError::MissingQuery);
118        }
119
120        let top_k = self.top_k_or_default();
121        if top_k == 0 {
122            return Err(ValidationError::InvalidTopK);
123        }
124
125        let candidate_limit = self.candidate_limit_or_default();
126        if candidate_limit == 0 {
127            return Err(ValidationError::InvalidCandidateLimit);
128        }
129        if candidate_limit < top_k {
130            return Err(ValidationError::CandidateLimitTooSmall {
131                top_k,
132                candidate_limit,
133            });
134        }
135
136        let alpha = self.alpha_or_default();
137        if !(0.0..=1.0).contains(&alpha) {
138            return Err(ValidationError::InvalidAlpha(alpha));
139        }
140
141        self.query_profile_or_default()?;
142
143        Ok(())
144    }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
148pub struct QueryEnvelope<T> {
149    pub kind: String,
150    pub row_count: usize,
151    pub rows: Vec<T>,
152}
153
154impl<T> QueryEnvelope<T> {
155    pub fn from_rows(rows: Vec<T>) -> Self {
156        Self {
157            row_count: rows.len(),
158            rows,
159            kind: "query".to_string(),
160        }
161    }
162}
163
164#[derive(Debug, Clone, Error, PartialEq)]
165pub enum ValidationError {
166    #[error("statement cannot be empty")]
167    EmptyStatement,
168    #[error("query_text or query_embedding is required")]
169    MissingQuery,
170    #[error("top_k must be >= 1")]
171    InvalidTopK,
172    #[error("candidate_limit must be >= 1")]
173    InvalidCandidateLimit,
174    #[error("candidate_limit ({candidate_limit}) must be >= top_k ({top_k})")]
175    CandidateLimitTooSmall {
176        top_k: usize,
177        candidate_limit: usize,
178    },
179    #[error("alpha must be between 0.0 and 1.0 (received {0})")]
180    InvalidAlpha(f32),
181    #[error("query_profile must be one of balanced|latency|recall (received {0})")]
182    InvalidQueryProfile(String),
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request whose only populated field is `query_text`.
    fn text_request(text: &str) -> QueryRequest {
        QueryRequest {
            query_text: Some(text.to_string()),
            ..QueryRequest::default()
        }
    }

    #[test]
    fn query_request_requires_text_or_embedding() {
        let request = QueryRequest::default();
        assert_eq!(request.validate(), Err(ValidationError::MissingQuery));
    }

    #[test]
    fn query_request_treats_blank_text_and_empty_embedding_as_missing() {
        let request = QueryRequest {
            query_text: Some(" \t ".to_string()),
            query_embedding: Some(Vec::new()),
            ..QueryRequest::default()
        };
        assert_eq!(request.validate(), Err(ValidationError::MissingQuery));
    }

    #[test]
    fn query_request_accepts_embedding_without_text() {
        let request = QueryRequest {
            query_embedding: Some(vec![0.5, 0.25]),
            ..QueryRequest::default()
        };
        assert_eq!(request.validate(), Ok(()));
        assert_eq!(request.normalized_query_embedding(), Some(vec![0.5, 0.25]));
    }

    #[test]
    fn query_request_rejects_zero_top_k() {
        let request = QueryRequest {
            top_k: Some(0),
            ..text_request("agent")
        };
        assert_eq!(request.validate(), Err(ValidationError::InvalidTopK));
    }

    #[test]
    fn query_request_rejects_zero_candidate_limit() {
        // The zero check runs before the comparison against top_k.
        let request = QueryRequest {
            candidate_limit: Some(0),
            ..text_request("agent")
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::InvalidCandidateLimit)
        );
    }

    #[test]
    fn query_request_rejects_candidate_limit_smaller_than_top_k() {
        let request = QueryRequest {
            query_text: Some("agent".to_string()),
            top_k: Some(10),
            candidate_limit: Some(2),
            ..QueryRequest::default()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::CandidateLimitTooSmall {
                top_k: 10,
                candidate_limit: 2,
            })
        );
    }

    #[test]
    fn query_request_checks_alpha_bounds() {
        for alpha in [0.0, 1.0] {
            let request = QueryRequest {
                alpha: Some(alpha),
                ..text_request("agent")
            };
            assert_eq!(request.validate(), Ok(()));
        }
        for alpha in [-0.1, 1.5] {
            let request = QueryRequest {
                alpha: Some(alpha),
                ..text_request("agent")
            };
            assert_eq!(request.validate(), Err(ValidationError::InvalidAlpha(alpha)));
        }
        let nan = QueryRequest {
            alpha: Some(f32::NAN),
            ..text_request("agent")
        };
        assert!(matches!(
            nan.validate(),
            Err(ValidationError::InvalidAlpha(value)) if value.is_nan()
        ));
    }

    #[test]
    fn query_request_accepts_defaulted_values() {
        let request = QueryRequest {
            query_text: Some("agent".to_string()),
            ..QueryRequest::default()
        };
        assert_eq!(request.top_k_or_default(), DEFAULT_TOP_K);
        assert_eq!(
            request.candidate_limit_or_default(),
            DEFAULT_CANDIDATE_LIMIT
        );
        assert!((request.alpha_or_default() - DEFAULT_ALPHA).abs() < f32::EPSILON);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn query_request_parses_trimmed_query_profiles() {
        let cases = [
            (None, QueryProfile::Balanced),
            (Some("   "), QueryProfile::Balanced),
            (Some("balanced"), QueryProfile::Balanced),
            (Some(" latency "), QueryProfile::Latency),
            (Some("recall"), QueryProfile::Recall),
        ];
        for (raw, expected) in cases {
            let request = QueryRequest {
                query_profile: raw.map(str::to_string),
                ..text_request("agent")
            };
            assert_eq!(request.query_profile_or_default(), Ok(expected));
        }
    }

    #[test]
    fn query_request_rejects_unknown_query_profile() {
        let request = QueryRequest {
            query_text: Some("agent".to_string()),
            query_profile: Some("speed".to_string()),
            ..QueryRequest::default()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::InvalidQueryProfile("speed".to_string()))
        );
    }

    #[test]
    fn query_request_normalizes_optional_fields() {
        let request = QueryRequest {
            query_text: Some("  agent  ".to_string()),
            query_embedding: Some(Vec::new()),
            doc_id: Some(" doc-1 ".to_string()),
            ..QueryRequest::default()
        };
        assert_eq!(request.normalized_query_text(), Some("agent".to_string()));
        assert_eq!(request.normalized_query_embedding(), None);
        assert_eq!(request.normalized_doc_id(), Some("doc-1".to_string()));
        assert!(request.normalized_metadata_filters().is_empty());
        assert!(request.include_payloads_or_default());

        let mut filters = std::collections::HashMap::new();
        filters.insert("lang".to_string(), "rust".to_string());
        let other = QueryRequest {
            doc_id: Some("   ".to_string()),
            include_payloads: Some(false),
            metadata_filters: Some(filters.clone()),
            ..text_request("agent")
        };
        assert_eq!(other.normalized_doc_id(), None);
        assert!(!other.include_payloads_or_default());
        assert_eq!(other.normalized_metadata_filters(), filters);
    }

    #[test]
    fn query_envelope_records_row_count_and_kind() {
        let envelope = QueryEnvelope::from_rows(vec![1, 2, 3]);
        assert_eq!(envelope.kind, "query");
        assert_eq!(envelope.row_count, 3);
        assert_eq!(envelope.rows, vec![1, 2, 3]);
    }

    #[test]
    fn sql_request_rejects_blank_statement() {
        let request = SqlRequest {
            statement: "   ".to_string(),
        };
        assert_eq!(request.validate(), Err(ValidationError::EmptyStatement));
    }

    #[test]
    fn sql_request_accepts_non_blank_statement() {
        let request = SqlRequest {
            statement: "SELECT 1".to_string(),
        };
        assert_eq!(request.validate(), Ok(()));
    }
}