1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use thiserror::Error;
4
/// `top_k` used when a [`QueryRequest`] leaves it unset.
pub const DEFAULT_TOP_K: usize = 5;
/// `alpha` used when a [`QueryRequest`] leaves it unset; validation requires `0.0..=1.0`.
/// NOTE(review): presumably a text/vector blend weight — confirm against the ranking code.
pub const DEFAULT_ALPHA: f32 = 0.65;
/// `candidate_limit` used when a [`QueryRequest`] leaves it unset; validation requires it to be
/// at least `top_k`.
pub const DEFAULT_CANDIDATE_LIMIT: usize = 1_000;
8
9#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
10#[serde(rename_all = "snake_case")]
11pub enum QueryProfile {
12 Latency,
13 #[default]
14 Balanced,
15 Recall,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
19pub struct SqlRequest {
20 pub statement: String,
21}
22
23impl SqlRequest {
24 pub fn validate(&self) -> Result<(), ValidationError> {
25 if self.statement.trim().is_empty() {
26 return Err(ValidationError::EmptyStatement);
27 }
28 Ok(())
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
33pub struct QueryRequest {
34 pub query_text: Option<String>,
35 pub query_embedding: Option<Vec<f32>>,
36 pub top_k: Option<usize>,
37 pub alpha: Option<f32>,
38 pub candidate_limit: Option<usize>,
39 pub include_payloads: Option<bool>,
40 pub query_profile: Option<String>,
41 pub metadata_filters: Option<HashMap<String, String>>,
42 pub doc_id: Option<String>,
43}
44
45impl QueryRequest {
46 pub fn top_k_or_default(&self) -> usize {
47 self.top_k.unwrap_or(DEFAULT_TOP_K)
48 }
49
50 pub fn alpha_or_default(&self) -> f32 {
51 self.alpha.unwrap_or(DEFAULT_ALPHA)
52 }
53
54 pub fn candidate_limit_or_default(&self) -> usize {
55 self.candidate_limit.unwrap_or(DEFAULT_CANDIDATE_LIMIT)
56 }
57
58 pub fn include_payloads_or_default(&self) -> bool {
59 self.include_payloads.unwrap_or(true)
60 }
61
62 pub fn query_profile_or_default(&self) -> Result<QueryProfile, ValidationError> {
63 match self
64 .query_profile
65 .as_deref()
66 .map(str::trim)
67 .filter(|value| !value.is_empty())
68 {
69 None => Ok(QueryProfile::Balanced),
70 Some("balanced") => Ok(QueryProfile::Balanced),
71 Some("latency") => Ok(QueryProfile::Latency),
72 Some("recall") => Ok(QueryProfile::Recall),
73 Some(other) => Err(ValidationError::InvalidQueryProfile(other.to_string())),
74 }
75 }
76
77 pub fn normalized_query_text(&self) -> Option<String> {
78 self.query_text
79 .as_ref()
80 .map(|value| value.trim())
81 .filter(|value| !value.is_empty())
82 .map(str::to_string)
83 }
84
85 pub fn normalized_doc_id(&self) -> Option<String> {
86 self.doc_id
87 .as_ref()
88 .map(|value| value.trim())
89 .filter(|value| !value.is_empty())
90 .map(str::to_string)
91 }
92
93 pub fn normalized_query_embedding(&self) -> Option<Vec<f32>> {
94 self.query_embedding
95 .as_ref()
96 .filter(|value| !value.is_empty())
97 .cloned()
98 }
99
100 pub fn normalized_metadata_filters(&self) -> HashMap<String, String> {
101 self.metadata_filters.clone().unwrap_or_default()
102 }
103
104 pub fn validate(&self) -> Result<(), ValidationError> {
105 let has_query_text = self
106 .query_text
107 .as_ref()
108 .map(|value| !value.trim().is_empty())
109 .unwrap_or(false);
110 let has_query_embedding = self
111 .query_embedding
112 .as_ref()
113 .map(|value| !value.is_empty())
114 .unwrap_or(false);
115
116 if !has_query_text && !has_query_embedding {
117 return Err(ValidationError::MissingQuery);
118 }
119
120 let top_k = self.top_k_or_default();
121 if top_k == 0 {
122 return Err(ValidationError::InvalidTopK);
123 }
124
125 let candidate_limit = self.candidate_limit_or_default();
126 if candidate_limit == 0 {
127 return Err(ValidationError::InvalidCandidateLimit);
128 }
129 if candidate_limit < top_k {
130 return Err(ValidationError::CandidateLimitTooSmall {
131 top_k,
132 candidate_limit,
133 });
134 }
135
136 let alpha = self.alpha_or_default();
137 if !(0.0..=1.0).contains(&alpha) {
138 return Err(ValidationError::InvalidAlpha(alpha));
139 }
140
141 self.query_profile_or_default()?;
142
143 Ok(())
144 }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
148pub struct QueryEnvelope<T> {
149 pub kind: String,
150 pub row_count: usize,
151 pub rows: Vec<T>,
152}
153
154impl<T> QueryEnvelope<T> {
155 pub fn from_rows(rows: Vec<T>) -> Self {
156 Self {
157 row_count: rows.len(),
158 rows,
159 kind: "query".to_string(),
160 }
161 }
162}
163
164#[derive(Debug, Clone, Error, PartialEq)]
165pub enum ValidationError {
166 #[error("statement cannot be empty")]
167 EmptyStatement,
168 #[error("query_text or query_embedding is required")]
169 MissingQuery,
170 #[error("top_k must be >= 1")]
171 InvalidTopK,
172 #[error("candidate_limit must be >= 1")]
173 InvalidCandidateLimit,
174 #[error("candidate_limit ({candidate_limit}) must be >= top_k ({top_k})")]
175 CandidateLimitTooSmall {
176 top_k: usize,
177 candidate_limit: usize,
178 },
179 #[error("alpha must be between 0.0 and 1.0 (received {0})")]
180 InvalidAlpha(f32),
181 #[error("query_profile must be one of balanced|latency|recall (received {0})")]
182 InvalidQueryProfile(String),
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid request: a text query with every other field defaulted.
    fn text_request() -> QueryRequest {
        QueryRequest {
            query_text: Some("agent".to_string()),
            ..QueryRequest::default()
        }
    }

    #[test]
    fn query_request_requires_text_or_embedding() {
        let request = QueryRequest::default();
        assert_eq!(request.validate(), Err(ValidationError::MissingQuery));
    }

    #[test]
    fn query_request_treats_blank_text_and_empty_embedding_as_missing() {
        let request = QueryRequest {
            query_text: Some("   ".to_string()),
            query_embedding: Some(Vec::new()),
            ..QueryRequest::default()
        };
        assert_eq!(request.validate(), Err(ValidationError::MissingQuery));
        assert_eq!(request.normalized_query_text(), None);
        assert_eq!(request.normalized_query_embedding(), None);
    }

    #[test]
    fn query_request_accepts_embedding_only_query() {
        let request = QueryRequest {
            query_embedding: Some(vec![0.1, 0.2, 0.3]),
            ..QueryRequest::default()
        };
        assert!(request.validate().is_ok());
        assert_eq!(request.normalized_query_text(), None);
        assert_eq!(
            request.normalized_query_embedding(),
            Some(vec![0.1, 0.2, 0.3])
        );
    }

    #[test]
    fn query_request_rejects_zero_top_k_and_zero_candidate_limit() {
        let zero_top_k = QueryRequest {
            top_k: Some(0),
            ..text_request()
        };
        assert_eq!(zero_top_k.validate(), Err(ValidationError::InvalidTopK));

        let zero_candidates = QueryRequest {
            candidate_limit: Some(0),
            ..text_request()
        };
        assert_eq!(
            zero_candidates.validate(),
            Err(ValidationError::InvalidCandidateLimit)
        );
    }

    #[test]
    fn query_request_rejects_candidate_limit_smaller_than_top_k() {
        let request = QueryRequest {
            top_k: Some(10),
            candidate_limit: Some(2),
            ..text_request()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::CandidateLimitTooSmall {
                top_k: 10,
                candidate_limit: 2,
            })
        );
    }

    #[test]
    fn query_request_checks_defaulted_candidate_limit_against_top_k() {
        let request = QueryRequest {
            top_k: Some(DEFAULT_CANDIDATE_LIMIT + 1),
            ..text_request()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::CandidateLimitTooSmall {
                top_k: DEFAULT_CANDIDATE_LIMIT + 1,
                candidate_limit: DEFAULT_CANDIDATE_LIMIT,
            })
        );
    }

    #[test]
    fn query_request_rejects_alpha_outside_unit_interval() {
        for alpha in [-0.1_f32, 1.5] {
            let request = QueryRequest {
                alpha: Some(alpha),
                ..text_request()
            };
            assert_eq!(request.validate(), Err(ValidationError::InvalidAlpha(alpha)));
        }

        let nan = QueryRequest {
            alpha: Some(f32::NAN),
            ..text_request()
        };
        assert!(matches!(
            nan.validate(),
            Err(ValidationError::InvalidAlpha(value)) if value.is_nan()
        ));

        for alpha in [0.0_f32, 1.0] {
            let request = QueryRequest {
                alpha: Some(alpha),
                ..text_request()
            };
            assert!(request.validate().is_ok());
        }
    }

    #[test]
    fn query_request_accepts_defaulted_values() {
        let request = text_request();
        assert_eq!(request.top_k_or_default(), DEFAULT_TOP_K);
        assert_eq!(
            request.candidate_limit_or_default(),
            DEFAULT_CANDIDATE_LIMIT
        );
        assert!((request.alpha_or_default() - DEFAULT_ALPHA).abs() < f32::EPSILON);
        assert!(request.include_payloads_or_default());
        assert!(request.normalized_metadata_filters().is_empty());
        assert_eq!(request.query_profile_or_default(), Ok(QueryProfile::Balanced));
        assert!(request.validate().is_ok());
    }

    #[test]
    fn query_request_trims_profile_text_and_doc_id() {
        let request = QueryRequest {
            query_text: Some("  agent  ".to_string()),
            query_profile: Some(" recall ".to_string()),
            doc_id: Some("   ".to_string()),
            ..QueryRequest::default()
        };
        assert_eq!(request.query_profile_or_default(), Ok(QueryProfile::Recall));
        assert_eq!(request.normalized_query_text().as_deref(), Some("agent"));
        assert_eq!(request.normalized_doc_id(), None);

        let with_doc = QueryRequest {
            doc_id: Some(" doc-1 ".to_string()),
            ..text_request()
        };
        assert_eq!(with_doc.normalized_doc_id().as_deref(), Some("doc-1"));
    }

    #[test]
    fn query_request_rejects_unknown_query_profile() {
        let request = QueryRequest {
            query_profile: Some("speed".to_string()),
            ..text_request()
        };
        assert_eq!(
            request.validate(),
            Err(ValidationError::InvalidQueryProfile("speed".to_string()))
        );
    }

    #[test]
    fn sql_request_rejects_blank_statement() {
        let request = SqlRequest {
            statement: " ".to_string(),
        };
        assert_eq!(request.validate(), Err(ValidationError::EmptyStatement));
    }

    #[test]
    fn sql_request_accepts_non_blank_statement() {
        let request = SqlRequest {
            statement: "  SELECT 1  ".to_string(),
        };
        assert_eq!(request.validate(), Ok(()));
    }
}