1use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use thiserror::Error;
30
/// Value used for `top_k` when a `QueryRequest` leaves it unset; `validate` requires the
/// effective value to be at least 1.
pub const DEFAULT_TOP_K: usize = 5;

/// Value used for `candidate_limit` when a `QueryRequest` leaves it unset; `validate`
/// requires the effective value to be at least 1 and no smaller than the effective `top_k`.
pub const DEFAULT_CANDIDATE_LIMIT: usize = 1_000;

/// Value used for `alpha` when a `QueryRequest` leaves it unset; `validate` requires the
/// effective value to lie within `0.0..=1.0`.
///
/// NOTE(review): presumably a blend weight between text and embedding scoring — confirm
/// against the ranking code.
pub const DEFAULT_ALPHA: f32 = 0.65;
34
/// Query profile parsed from the `query_profile` string of a `QueryRequest`
/// (see `QueryRequest::query_profile_or_default`).
///
/// Serde (de)serializes the variants as `latency`, `balanced`, and `recall`
/// (`rename_all = "snake_case"`). `Balanced` is the `Default` variant.
///
/// NOTE(review): how each profile changes query execution is not visible here; the names
/// suggest a latency/recall trade-off — confirm against the executor.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum QueryProfile {
    /// Serialized as `latency`.
    Latency,
    /// Serialized as `balanced`; also the `Default` variant.
    #[default]
    Balanced,
    /// Serialized as `recall`.
    Recall,
}
43
/// Request carrying a raw SQL statement.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct SqlRequest {
    /// Statement text; `SqlRequest::validate` rejects it when it is empty or whitespace-only.
    pub statement: String,
}
48
49impl SqlRequest {
50 pub fn validate(&self) -> Result<(), ValidationError> {
51 if self.statement.trim().is_empty() {
52 return Err(ValidationError::EmptyStatement);
53 }
54 Ok(())
55 }
56}
57
/// Parameters for a query. Every field is optional; defaults and normalization are applied
/// by the `*_or_default` / `normalized_*` accessors, and `QueryRequest::validate` checks
/// the effective values.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct QueryRequest {
    /// Free-text query. Trimmed before use; whitespace-only is treated as absent. At least one
    /// of this or a non-empty `query_embedding` must be present for `validate` to pass.
    pub query_text: Option<String>,
    /// Query vector. An empty vector is treated as absent.
    pub query_embedding: Option<Vec<f32>>,
    /// Falls back to `DEFAULT_TOP_K` when unset; the effective value must be at least 1.
    pub top_k: Option<usize>,
    /// Falls back to `DEFAULT_ALPHA` when unset; the effective value must lie within
    /// `0.0..=1.0` (NaN is rejected).
    /// NOTE(review): presumably a text-vs-embedding score weight — confirm.
    pub alpha: Option<f32>,
    /// Falls back to `DEFAULT_CANDIDATE_LIMIT` when unset; the effective value must be at
    /// least 1 and no smaller than the effective `top_k`.
    pub candidate_limit: Option<usize>,
    /// Treated as `true` when unset.
    pub include_payloads: Option<bool>,
    /// One of `balanced`, `latency`, or `recall` (surrounding whitespace ignored, otherwise
    /// matched exactly); absent or blank means `balanced`.
    pub query_profile: Option<String>,
    /// Key/value filters; an absent map is treated as empty.
    /// NOTE(review): how keys and values are matched is decided elsewhere — confirm.
    pub metadata_filters: Option<HashMap<String, String>>,
    /// Trimmed before use; whitespace-only is treated as absent.
    /// NOTE(review): presumably restricts the query to a single document — confirm.
    pub doc_id: Option<String>,
}
70
71impl QueryRequest {
72 pub fn top_k_or_default(&self) -> usize {
73 self.top_k.unwrap_or(DEFAULT_TOP_K)
74 }
75
76 pub fn alpha_or_default(&self) -> f32 {
77 self.alpha.unwrap_or(DEFAULT_ALPHA)
78 }
79
80 pub fn candidate_limit_or_default(&self) -> usize {
81 self.candidate_limit.unwrap_or(DEFAULT_CANDIDATE_LIMIT)
82 }
83
84 pub fn include_payloads_or_default(&self) -> bool {
85 self.include_payloads.unwrap_or(true)
86 }
87
88 pub fn query_profile_or_default(&self) -> Result<QueryProfile, ValidationError> {
89 match self
90 .query_profile
91 .as_deref()
92 .map(str::trim)
93 .filter(|value| !value.is_empty())
94 {
95 None => Ok(QueryProfile::Balanced),
96 Some("balanced") => Ok(QueryProfile::Balanced),
97 Some("latency") => Ok(QueryProfile::Latency),
98 Some("recall") => Ok(QueryProfile::Recall),
99 Some(other) => Err(ValidationError::InvalidQueryProfile(other.to_string())),
100 }
101 }
102
103 pub fn normalized_query_text(&self) -> Option<String> {
104 self.query_text
105 .as_ref()
106 .map(|value| value.trim())
107 .filter(|value| !value.is_empty())
108 .map(str::to_string)
109 }
110
111 pub fn normalized_doc_id(&self) -> Option<String> {
112 self.doc_id
113 .as_ref()
114 .map(|value| value.trim())
115 .filter(|value| !value.is_empty())
116 .map(str::to_string)
117 }
118
119 pub fn normalized_query_embedding(&self) -> Option<Vec<f32>> {
120 self.query_embedding
121 .as_ref()
122 .filter(|value| !value.is_empty())
123 .cloned()
124 }
125
126 pub fn normalized_metadata_filters(&self) -> HashMap<String, String> {
127 self.metadata_filters.clone().unwrap_or_default()
128 }
129
130 pub fn validate(&self) -> Result<(), ValidationError> {
131 let has_query_text = self
132 .query_text
133 .as_ref()
134 .map(|value| !value.trim().is_empty())
135 .unwrap_or(false);
136 let has_query_embedding = self
137 .query_embedding
138 .as_ref()
139 .map(|value| !value.is_empty())
140 .unwrap_or(false);
141
142 if !has_query_text && !has_query_embedding {
143 return Err(ValidationError::MissingQuery);
144 }
145
146 let top_k = self.top_k_or_default();
147 if top_k == 0 {
148 return Err(ValidationError::InvalidTopK);
149 }
150
151 let candidate_limit = self.candidate_limit_or_default();
152 if candidate_limit == 0 {
153 return Err(ValidationError::InvalidCandidateLimit);
154 }
155 if candidate_limit < top_k {
156 return Err(ValidationError::CandidateLimitTooSmall {
157 top_k,
158 candidate_limit,
159 });
160 }
161
162 let alpha = self.alpha_or_default();
163 if !(0.0..=1.0).contains(&alpha) {
164 return Err(ValidationError::InvalidAlpha(alpha));
165 }
166
167 self.query_profile_or_default()?;
168
169 Ok(())
170 }
171}
172
/// Wraps result rows together with a row count and a `kind` tag.
///
/// `QueryEnvelope::from_rows` sets `kind` to `"query"` and `row_count` to `rows.len()`.
/// All fields are public, and nothing in this type re-derives `row_count`. It can therefore
/// disagree with `rows` if the envelope is modified or deserialized from external data.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QueryEnvelope<T> {
    /// Tag describing the envelope; `"query"` when built with `from_rows`.
    pub kind: String,
    /// Number of rows; equals `rows.len()` when built with `from_rows`.
    pub row_count: usize,
    /// The result rows.
    pub rows: Vec<T>,
}
179
180impl<T> QueryEnvelope<T> {
181 pub fn from_rows(rows: Vec<T>) -> Self {
182 Self {
183 row_count: rows.len(),
184 rows,
185 kind: "query".to_string(),
186 }
187 }
188}
189
/// Errors returned by `SqlRequest::validate` and `QueryRequest::validate`
/// (and by `QueryRequest::query_profile_or_default`).
///
/// The `#[error]` strings are the `Display` output generated by `thiserror`. `Eq` cannot be
/// derived here because `InvalidAlpha` carries an `f32`.
#[derive(Debug, Clone, Error, PartialEq)]
pub enum ValidationError {
    /// The SQL statement was empty or whitespace-only.
    #[error("statement cannot be empty")]
    EmptyStatement,
    /// Neither a non-blank `query_text` nor a non-empty `query_embedding` was supplied.
    #[error("query_text or query_embedding is required")]
    MissingQuery,
    /// The effective `top_k` was zero.
    #[error("top_k must be >= 1")]
    InvalidTopK,
    /// The effective `candidate_limit` was zero.
    #[error("candidate_limit must be >= 1")]
    InvalidCandidateLimit,
    /// The effective `candidate_limit` was smaller than the effective `top_k`.
    #[error("candidate_limit ({candidate_limit}) must be >= top_k ({top_k})")]
    CandidateLimitTooSmall {
        /// Effective `top_k` (after defaults) that was checked.
        top_k: usize,
        /// Effective `candidate_limit` (after defaults) that was checked.
        candidate_limit: usize,
    },
    /// The effective `alpha` was outside `0.0..=1.0` (NaN included); carries the rejected value.
    #[error("alpha must be between 0.0 and 1.0 (received {0})")]
    InvalidAlpha(f32),
    /// `query_profile` was not a recognized profile name; carries the trimmed value.
    #[error("query_profile must be one of balanced|latency|recall (received {0})")]
    InvalidQueryProfile(String),
}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Request with non-blank query text and every other field left unset.
    fn text_request() -> QueryRequest {
        QueryRequest {
            query_text: Some("agent".to_string()),
            ..QueryRequest::default()
        }
    }

    #[test]
    fn query_request_requires_text_or_embedding() {
        let request = QueryRequest::default();
        assert_eq!(request.validate(), Err(ValidationError::MissingQuery));
    }

    #[test]
    fn query_request_treats_blank_text_and_empty_embedding_as_missing() {
        let request = QueryRequest {
            query_text: Some("   ".to_string()),
            query_embedding: Some(Vec::new()),
            ..QueryRequest::default()
        };
        assert_eq!(request.validate(), Err(ValidationError::MissingQuery));
    }

    #[test]
    fn query_request_accepts_embedding_without_text() {
        let request = QueryRequest {
            query_embedding: Some(vec![0.1, 0.2]),
            ..QueryRequest::default()
        };
        assert!(request.validate().is_ok());
    }

    #[test]
    fn query_request_rejects_zero_top_k() {
        let request = QueryRequest {
            top_k: Some(0),
            ..text_request()
        };
        assert_eq!(request.validate(), Err(ValidationError::InvalidTopK));
    }

    #[test]
    fn query_request_rejects_zero_candidate_limit() {
        let request = QueryRequest {
            candidate_limit: Some(0),
            ..text_request()
        };
        assert_eq!(request.validate(), Err(ValidationError::InvalidCandidateLimit));
    }

    #[test]
    fn query_request_rejects_candidate_limit_smaller_than_top_k() {
        let request = QueryRequest {
            query_text: Some("agent".to_string()),
            top_k: Some(10),
            candidate_limit: Some(2),
            ..QueryRequest::default()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::CandidateLimitTooSmall {
                top_k: 10,
                candidate_limit: 2,
            })
        );
    }

    #[test]
    fn query_request_rejects_alpha_outside_unit_interval() {
        for alpha in [-0.1_f32, 1.5] {
            let request = QueryRequest {
                alpha: Some(alpha),
                ..text_request()
            };
            assert_eq!(request.validate(), Err(ValidationError::InvalidAlpha(alpha)));
        }
    }

    #[test]
    fn query_request_accepts_alpha_bounds() {
        for alpha in [0.0_f32, 1.0] {
            let request = QueryRequest {
                alpha: Some(alpha),
                ..text_request()
            };
            assert!(request.validate().is_ok());
        }
    }

    #[test]
    fn query_request_rejects_nan_alpha() {
        let request = QueryRequest {
            alpha: Some(f32::NAN),
            ..text_request()
        };
        // NaN != NaN, so match on the variant instead of comparing with assert_eq!.
        assert!(matches!(
            request.validate(),
            Err(ValidationError::InvalidAlpha(value)) if value.is_nan()
        ));
    }

    #[test]
    fn query_request_accepts_defaulted_values() {
        let request = QueryRequest {
            query_text: Some("agent".to_string()),
            ..QueryRequest::default()
        };
        assert_eq!(request.top_k_or_default(), DEFAULT_TOP_K);
        assert_eq!(
            request.candidate_limit_or_default(),
            DEFAULT_CANDIDATE_LIMIT
        );
        assert!((request.alpha_or_default() - DEFAULT_ALPHA).abs() < f32::EPSILON);
        assert!(request.include_payloads_or_default());
        assert!(request.validate().is_ok());
    }

    #[test]
    fn query_request_accepts_known_and_blank_query_profiles() {
        for profile in [" latency ", "recall", "balanced", "   "] {
            let request = QueryRequest {
                query_profile: Some(profile.to_string()),
                ..text_request()
            };
            assert!(
                request.validate().is_ok(),
                "profile {:?} should be accepted",
                profile
            );
        }
    }

    #[test]
    fn query_request_rejects_unknown_query_profile() {
        let request = QueryRequest {
            query_text: Some("agent".to_string()),
            query_profile: Some("speed".to_string()),
            ..QueryRequest::default()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::InvalidQueryProfile("speed".to_string()))
        );
    }

    #[test]
    fn query_request_normalizers_trim_and_drop_blank_values() {
        let request = QueryRequest {
            query_text: Some("  agent  ".to_string()),
            doc_id: Some("   ".to_string()),
            query_embedding: Some(Vec::new()),
            ..QueryRequest::default()
        };
        assert_eq!(request.normalized_query_text(), Some("agent".to_string()));
        assert_eq!(request.normalized_doc_id(), None);
        assert_eq!(request.normalized_query_embedding(), None);
        assert!(request.normalized_metadata_filters().is_empty());
    }

    #[test]
    fn sql_request_rejects_blank_statement() {
        let request = SqlRequest {
            statement: " ".to_string(),
        };
        assert_eq!(request.validate(), Err(ValidationError::EmptyStatement));
    }

    #[test]
    fn sql_request_accepts_non_blank_statement() {
        let request = SqlRequest {
            statement: "SELECT 1".to_string(),
        };
        assert!(request.validate().is_ok());
    }
}