1use std::env;
75
76pub mod prompt;
82pub mod provider;
83
84#[cfg(feature = "http")]
85pub use provider::anthropic::AnthropicProvider;
86pub use provider::{Provider, Request, Response, Usage};
87
88use prompt::{CacheControl, UserMessage, build_system};
89use provider::Request as ProviderRequest;
90
/// Model identifier used by [`AskConfig::default`]; `SQLRITE_LLM_MODEL`
/// (via [`AskConfig::from_env`]) or a direct assignment overrides it.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";

/// Default value of `AskConfig::max_tokens`, which is sent as the request's
/// `max_tokens` limit.
pub const DEFAULT_MAX_TOKENS: u32 = 1_024;
101
102#[derive(Debug, Clone)]
113pub struct AskResponse {
114 pub sql: String,
115 pub explanation: String,
116 pub usage: Usage,
117}
118
119#[derive(Debug, Clone, Copy, PartialEq, Eq)]
130pub enum CacheTtl {
131 FiveMinutes,
132 OneHour,
133 Off,
138}
139
140impl CacheTtl {
141 fn into_marker(self) -> Option<CacheControl> {
142 match self {
143 CacheTtl::FiveMinutes => Some(CacheControl::ephemeral()),
144 CacheTtl::OneHour => Some(CacheControl::ephemeral_1h()),
145 CacheTtl::Off => None,
146 }
147 }
148}
149
150#[derive(Debug, Clone, Copy, PartialEq, Eq)]
154pub enum ProviderKind {
155 Anthropic,
156}
157
158impl ProviderKind {
159 fn parse(s: &str) -> Result<Self, AskError> {
160 match s.to_ascii_lowercase().as_str() {
161 "anthropic" => Ok(ProviderKind::Anthropic),
162 other => Err(AskError::UnknownProvider(other.to_string())),
163 }
164 }
165}
166
167#[derive(Debug, Clone)]
170pub struct AskConfig {
171 pub provider: ProviderKind,
172 pub api_key: Option<String>,
173 pub model: String,
174 pub max_tokens: u32,
175 pub cache_ttl: CacheTtl,
176 pub base_url: Option<String>,
179}
180
181impl Default for AskConfig {
182 fn default() -> Self {
183 Self {
184 provider: ProviderKind::Anthropic,
185 api_key: None,
186 model: DEFAULT_MODEL.to_string(),
187 max_tokens: DEFAULT_MAX_TOKENS,
188 cache_ttl: CacheTtl::FiveMinutes,
189 base_url: None,
190 }
191 }
192}
193
194impl AskConfig {
195 pub fn from_env() -> Result<Self, AskError> {
207 let mut cfg = AskConfig::default();
208 if let Ok(p) = env::var("SQLRITE_LLM_PROVIDER") {
209 cfg.provider = ProviderKind::parse(&p)?;
210 }
211 if let Ok(k) = env::var("SQLRITE_LLM_API_KEY") {
212 if !k.is_empty() {
213 cfg.api_key = Some(k);
214 }
215 }
216 if let Ok(m) = env::var("SQLRITE_LLM_MODEL") {
217 if !m.is_empty() {
218 cfg.model = m;
219 }
220 }
221 if let Ok(t) = env::var("SQLRITE_LLM_MAX_TOKENS") {
222 cfg.max_tokens = t
223 .parse()
224 .map_err(|_| AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}")))?;
225 }
226 if let Ok(c) = env::var("SQLRITE_LLM_CACHE_TTL") {
227 cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
228 "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
229 "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
230 "off" | "none" | "disabled" => CacheTtl::Off,
231 other => {
232 return Err(AskError::Config(format!(
233 "SQLRITE_LLM_CACHE_TTL: unknown value '{other}'"
234 )));
235 }
236 };
237 }
238 Ok(cfg)
239 }
240}
241
242#[derive(Debug, thiserror::Error)]
245pub enum AskError {
246 #[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
247 MissingApiKey,
248
249 #[error("config error: {0}")]
250 Config(String),
251
252 #[error("unknown provider: {0} (supported: anthropic)")]
253 UnknownProvider(String),
254
255 #[error("HTTP transport error: {0}")]
256 Http(String),
257
258 #[error("API returned status {status}: {detail}")]
259 ApiStatus { status: u16, detail: String },
260
261 #[error("API returned no text content")]
262 EmptyResponse,
263
264 #[error("model output not valid JSON: {0}")]
265 OutputNotJson(String),
266
267 #[error("model output JSON missing required field '{0}'")]
268 OutputMissingField(&'static str),
269
270 #[error("JSON serialization error: {0}")]
271 Json(#[from] serde_json::Error),
272}
273
/// Translates `question` into SQL using the provider selected by `config`.
///
/// `schema_dump` is the schema text placed in the system prompt. This builds
/// an `AnthropicProvider` (honoring `config.base_url`) and delegates to
/// [`ask_with_schema_and_provider`].
///
/// # Errors
///
/// Returns [`AskError::MissingApiKey`] when `config.api_key` is `None`, empty,
/// or whitespace-only. This check happens before the provider is constructed,
/// so an obviously unusable key never reaches the API. All other errors come
/// from the provider call or from parsing the model output.
#[cfg(feature = "http")]
pub fn ask_with_schema(
    schema_dump: &str,
    question: &str,
    config: &AskConfig,
) -> Result<AskResponse, AskError> {
    let api_key = config
        .api_key
        .as_deref()
        .filter(|key| !key.trim().is_empty())
        .ok_or(AskError::MissingApiKey)?
        .to_string();

    let provider = match config.provider {
        ProviderKind::Anthropic => match &config.base_url {
            Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
            None => AnthropicProvider::new(api_key),
        },
    };

    ask_with_schema_and_provider(schema_dump, question, config, &provider)
}
307
308pub fn ask_with_schema_and_provider<P: Provider>(
318 schema_dump: &str,
319 question: &str,
320 config: &AskConfig,
321 provider: &P,
322) -> Result<AskResponse, AskError> {
323 let system = build_system(schema_dump, config.cache_ttl.into_marker());
324 let messages = [UserMessage::new(question)];
325
326 let req = ProviderRequest {
327 model: &config.model,
328 max_tokens: config.max_tokens,
329 system: &system,
330 messages: &messages,
331 };
332
333 let resp = provider.complete(req)?;
334 parse_response(&resp.text, resp.usage)
335}
336
337pub fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
350 let trimmed = raw.trim();
352 let body = strip_markdown_fence(trimmed).unwrap_or(trimmed);
353
354 if let Ok(value) = serde_json::from_str::<serde_json::Value>(body) {
356 return extract_fields(&value, usage);
357 }
358
359 if let Some(json_block) = extract_first_json_object(body) {
364 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&json_block) {
365 return extract_fields(&value, usage);
366 }
367 }
368
369 Err(AskError::OutputNotJson(raw.to_string()))
370}
371
372fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
373 let sql = value
374 .get("sql")
375 .and_then(|v| v.as_str())
376 .ok_or(AskError::OutputMissingField("sql"))?
377 .trim()
378 .trim_end_matches(';')
379 .to_string();
380 let explanation = value
381 .get("explanation")
382 .and_then(|v| v.as_str())
383 .unwrap_or("")
384 .to_string();
385 Ok(AskResponse {
386 sql,
387 explanation,
388 usage,
389 })
390}
391
/// Returns the contents of a fenced markdown code block, if `s` is one.
///
/// After trimming, `s` must start with three backticks, optionally followed by
/// a single-word info string, and then a line break (`\n` or `\r\n`). The
/// info string may be `json`, `JSON`, `Json`, `sql`, or any similar tag in any
/// ASCII case. A closing fence of three backticks is removed if present, and
/// the result is trimmed.
///
/// Returns `None` when `s` does not open with such a fence line. This
/// includes the case where something other than a language tag (for example
/// an inline JSON object) follows the opening backticks on the same line. In
/// that case the caller's fallback scan still sees the whole text.
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let after_ticks = s.trim().strip_prefix("```")?;
    let (info, rest) = after_ticks.split_once('\n')?;

    // `trim` also drops a trailing `\r`, which handles CRLF fence lines.
    let info = info.trim();
    let is_language_tag = info
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '+' | '.'));
    if !is_language_tag {
        return None;
    }

    let body = rest.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
405
/// Returns the balanced `{...}` span that begins at the first `{` in `s`.
///
/// Braces inside JSON string literals are ignored, and a backslash inside a
/// string skips the following character, so escaped quotes do not end the
/// string. Returns `None` if `s` has no `{` or if the object opened by the
/// first `{` never closes. The span is only balanced, not validated as JSON.
fn extract_first_json_object(s: &str) -> Option<String> {
    let open = s.find('{')?;
    let tail = &s[open..];

    let mut nesting = 0_i32;
    let mut inside_str = false;
    let mut skip_next = false;

    for (offset, ch) in tail.char_indices() {
        if skip_next {
            skip_next = false;
            continue;
        }
        if inside_str {
            match ch {
                '\\' => skip_next = true,
                '"' => inside_str = false,
                _ => {}
            }
            continue;
        }
        match ch {
            '"' => inside_str = true,
            '{' => nesting += 1,
            '}' => {
                nesting -= 1;
                if nesting == 0 {
                    // `}` is a single byte, so `..=offset` ends on a char boundary.
                    return Some(tail[..=offset].to_owned());
                }
            }
            _ => {}
        }
    }
    None
}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    /// Small schema for tests that do not inspect the prompt text.
    const FIXTURE_SCHEMA: &str = "\
CREATE TABLE users (
    id INTEGER PRIMARY KEY,
    name TEXT
);
";

    /// Default config plus a dummy API key.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );
        let resp =
            ask_with_schema_and_provider(FIXTURE_SCHEMA, "how many users?", &cfg(), &provider)
                .unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    #[test]
    fn schema_dump_appears_in_system_block() {
        let schema = "CREATE TABLE widgets (\n    id INTEGER PRIMARY KEY,\n    name TEXT\n);\n";
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(schema, "anything", &cfg(), &provider).unwrap();

        let captured = provider.last_request.borrow().clone().unwrap();
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    // `ask_with_schema` only exists with the `http` feature. Without this gate
    // the whole test module fails to compile under `--no-default-features`.
    #[cfg(feature = "http")]
    #[test]
    fn missing_api_key_errors_clearly() {
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask_with_schema(FIXTURE_SCHEMA, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        // `usage` is not used afterwards, so it can be moved in directly.
        let resp = parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}
590}