Skip to main content

sqlrite_ask/
lib.rs

1//! `sqlrite-ask` — natural-language → SQL adapter for SQLRite.
2//!
3//! **Phase 7g.2 made this crate pure** — it no longer depends on
4//! `sqlrite-engine`. The canonical API takes a `&str` schema dump
5//! (built however you like — see `sqlrite::ConnectionAskExt::ask`
6//! for the engine-side helper that wraps it):
7//!
8//! ```no_run
9//! use sqlrite_ask::{ask_with_schema, AskConfig};
10//!
11//! let schema = "\
12//! CREATE TABLE users (\n  id INTEGER PRIMARY KEY,\n  name TEXT NOT NULL\n);\n";
13//! let cfg  = AskConfig::from_env()?;          // reads SQLRITE_LLM_API_KEY etc.
14//! let resp = ask_with_schema(schema, "How many users?", &cfg)?;
15//! println!("{}", resp.sql);
16//! # Ok::<(), sqlrite_ask::AskError>(())
17//! ```
18//!
19//! For the engine-integrated form (`conn.ask("...", &cfg)`), enable
20//! the `sqlrite-engine` crate's `ask` feature and bring its
21//! `ConnectionAskExt` trait into scope:
22//!
23//! ```ignore
24//! use sqlrite::{Connection, ConnectionAskExt};
25//! use sqlrite_ask::AskConfig;
26//!
27//! let conn = Connection::open("foo.sqlrite")?;
28//! let cfg  = AskConfig::from_env()?;
29//! let resp = conn.ask("How many users?", &cfg)?;
30//! ```
31//!
32//! ## Why the split (Phase 7g.2 retro)
33//!
34//! Wiring the REPL's `.ask` meta-command would have required the
35//! engine binary to depend on `sqlrite-ask`, but `sqlrite-ask`
36//! already depended on `sqlrite-engine` (for `Connection`,
37//! `Database`, `Table`). Cargo's static cycle detection rejects that
38//! shape even with `optional = true`. Solution: keep this crate pure
39//! over `&str` inputs, move the engine integration (schema dump +
40//! `ConnectionAskExt`) into `sqlrite-engine` itself behind an `ask`
41//! feature. Now there's one direction of dep flow: engine →
42//! sqlrite-ask, never the other way.
43//!
44//! ## What this crate is
45//!
46//! - Provider adapters (Anthropic now; OpenAI / Ollama later) that
47//!   POST one HTTP request to a chat-completion endpoint per `ask()`
48//!   call.
49//! - Prompt construction with a `cache_control: ephemeral`
50//!   breakpoint on the schema block, so repeat calls against the
51//!   same schema are served from Anthropic's prompt cache.
52//! - Output parsing tolerant to fenced JSON / leading prose / strict
53//!   JSON (model output drifts even with strict instructions).
54//! - `AskConfig` (env vars + explicit overrides), `AskResponse {
55//!   sql, explanation, usage }`, `AskError`.
56//!
57//! ## What this crate is NOT
58//!
59//! - **Not an executor.** The caller decides whether to run the
60//!   generated SQL. SDK convenience wrappers (`Python.Connection
61//!   .ask_run`, `Node.db.askRun`, etc.) layer that on top.
62//! - **Not multi-turn.** Stateless — every call is a fresh prompt.
63//! - **Not engine-coupled.** Schema introspection lives on the
64//!   engine side as of v0.1.19 — see `sqlrite::ConnectionAskExt`.
65//!
66//! ## Configuration
67//!
68//! [`AskConfig`] resolves in this priority order:
69//! 1. Explicit values on the struct.
70//! 2. Environment variables (`SQLRITE_LLM_*`).
71//! 3. Built-in defaults (model = `claude-sonnet-4-6`, max_tokens = 1024,
72//!    cache TTL = 5 min).
73
74use std::env;
75
76mod prompt;
77mod provider;
78
79pub use provider::anthropic::AnthropicProvider;
80pub use provider::{Provider, Request, Response, Usage};
81
82use prompt::{CacheControl, UserMessage, build_system};
83use provider::Request as ProviderRequest;
84
/// Default model — Sonnet 4.6 hits the cost-quality sweet spot for
/// NL→SQL. This is the value [`AskConfig::default`] puts in
/// `AskConfig::model`; override it by setting that field or the
/// `SQLRITE_LLM_MODEL` env var. See `docs/phase-7-plan.md` for the
/// model-choice rationale.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";

/// Default `max_tokens`. SQL generation rarely needs more than ~500
/// output tokens (single-statement queries + a one-sentence
/// explanation). 1024 leaves headroom; under the SDK timeout cap so
/// we don't have to stream. This is the value [`AskConfig::default`]
/// puts in `AskConfig::max_tokens`; override it by setting that field
/// or the `SQLRITE_LLM_MAX_TOKENS` env var.
pub const DEFAULT_MAX_TOKENS: u32 = 1024;
95
/// Result returned from a successful [`ask_with_schema`] (or
/// [`ask_with_schema_and_provider`]) call.
///
/// `sql` is the generated query text — empty string if the model
/// determined the question can't be answered against the schema.
/// `explanation` is the model's one-sentence rationale; useful in
/// REPL "confirm before run" UIs.
///
/// `usage` surfaces token counts (input/output/cache hit/cache write).
/// Inspect it to verify prompt-caching is actually working — see
/// `docs/phase-7-plan.md` Q3-adjacent for the audit checklist.
#[derive(Debug, Clone)]
pub struct AskResponse {
    /// Generated SQL taken from the reply's `"sql"` field, with
    /// surrounding whitespace and any trailing `;` removed.
    pub sql: String,
    /// Text of the reply's `"explanation"` field; empty when the
    /// model omitted it or it was not a string.
    pub explanation: String,
    /// Token accounting passed through unchanged from the provider's
    /// response.
    pub usage: Usage,
}
112
/// Cache-TTL knob exposed on [`AskConfig`].
///
/// Anthropic's `ephemeral` cache supports two TTLs:
/// - **5 minutes** (default) — break-even at 2 calls per cached
///   prefix; right for interactive REPL use where users ask a few
///   questions in a session.
/// - **1 hour** — costs 2× write premium instead of 1.25×; needs
///   3+ calls per prefix to break even. Worth it for long-running
///   editor / desktop sessions where the same DB is queried
///   sporadically over an hour.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTtl {
    /// Mark the schema block with the standard ephemeral cache
    /// marker. This is what [`AskConfig::default`] selects; the env
    /// spellings are `5m`, `5min`, `5minutes`.
    FiveMinutes,
    /// Mark the schema block with the one-hour ephemeral cache
    /// marker. Env spellings: `1h`, `1hr`, `1hour`.
    OneHour,
    /// Disables caching — schema block is sent without a
    /// `cache_control` marker. Useful when the schema is below the
    /// model's minimum cacheable prefix size (~2K tokens for Sonnet,
    /// ~4K for Haiku/Opus); marking it would be a no-op.
    /// Env spellings: `off`, `none`, `disabled`.
    Off,
}
133
134impl CacheTtl {
135    fn into_marker(self) -> Option<CacheControl> {
136        match self {
137            CacheTtl::FiveMinutes => Some(CacheControl::ephemeral()),
138            CacheTtl::OneHour => Some(CacheControl::ephemeral_1h()),
139            CacheTtl::Off => None,
140        }
141    }
142}
143
/// Which LLM provider [`ask_with_schema`] talks to. Anthropic-only in
/// 7g.1; the enum is here so adding OpenAI/Ollama later doesn't break
/// the `AskConfig` shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Route requests through [`AnthropicProvider`]. Selected by
    /// [`AskConfig::default`] and by `SQLRITE_LLM_PROVIDER=anthropic`
    /// (case-insensitive).
    Anthropic,
}
151
152impl ProviderKind {
153    fn parse(s: &str) -> Result<Self, AskError> {
154        match s.to_ascii_lowercase().as_str() {
155            "anthropic" => Ok(ProviderKind::Anthropic),
156            other => Err(AskError::UnknownProvider(other.to_string())),
157        }
158    }
159}
160
/// Knobs for an `ask()` call. Construct directly, or via
/// [`AskConfig::from_env`] to pull defaults from the environment.
#[derive(Debug, Clone)]
pub struct AskConfig {
    /// Backend that [`ask_with_schema`] instantiates.
    pub provider: ProviderKind,
    /// Secret handed to the provider. May be `None` while building or
    /// inspecting a config, but [`ask_with_schema`] fails with
    /// [`AskError::MissingApiKey`] if it is still `None` at call time.
    pub api_key: Option<String>,
    /// Model identifier sent with every request; defaults to
    /// [`DEFAULT_MODEL`].
    pub model: String,
    /// `max_tokens` value sent with every request; defaults to
    /// [`DEFAULT_MAX_TOKENS`].
    pub max_tokens: u32,
    /// Cache marker policy for the schema block; defaults to
    /// [`CacheTtl::FiveMinutes`].
    pub cache_ttl: CacheTtl,
    /// Override the API base URL. Production callers leave this
    /// `None`; tests point it at a localhost mock.
    pub base_url: Option<String>,
}
174
175impl Default for AskConfig {
176    fn default() -> Self {
177        Self {
178            provider: ProviderKind::Anthropic,
179            api_key: None,
180            model: DEFAULT_MODEL.to_string(),
181            max_tokens: DEFAULT_MAX_TOKENS,
182            cache_ttl: CacheTtl::FiveMinutes,
183            base_url: None,
184        }
185    }
186}
187
188impl AskConfig {
189    /// Build a config from environment variables, with built-in
190    /// defaults for anything not set.
191    ///
192    /// Recognized vars:
193    /// - `SQLRITE_LLM_PROVIDER` — `anthropic` (only currently supported)
194    /// - `SQLRITE_LLM_API_KEY` — required at call time, but a missing
195    ///   var is not an error here (lets you build a config to inspect
196    ///   without the secret loaded)
197    /// - `SQLRITE_LLM_MODEL` — overrides [`DEFAULT_MODEL`]
198    /// - `SQLRITE_LLM_MAX_TOKENS` — overrides [`DEFAULT_MAX_TOKENS`]
199    /// - `SQLRITE_LLM_CACHE_TTL` — `5m` (default) | `1h` | `off`
200    pub fn from_env() -> Result<Self, AskError> {
201        let mut cfg = AskConfig::default();
202        if let Ok(p) = env::var("SQLRITE_LLM_PROVIDER") {
203            cfg.provider = ProviderKind::parse(&p)?;
204        }
205        if let Ok(k) = env::var("SQLRITE_LLM_API_KEY") {
206            if !k.is_empty() {
207                cfg.api_key = Some(k);
208            }
209        }
210        if let Ok(m) = env::var("SQLRITE_LLM_MODEL") {
211            if !m.is_empty() {
212                cfg.model = m;
213            }
214        }
215        if let Ok(t) = env::var("SQLRITE_LLM_MAX_TOKENS") {
216            cfg.max_tokens = t
217                .parse()
218                .map_err(|_| AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}")))?;
219        }
220        if let Ok(c) = env::var("SQLRITE_LLM_CACHE_TTL") {
221            cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
222                "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
223                "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
224                "off" | "none" | "disabled" => CacheTtl::Off,
225                other => {
226                    return Err(AskError::Config(format!(
227                        "SQLRITE_LLM_CACHE_TTL: unknown value '{other}'"
228                    )));
229                }
230            };
231        }
232        Ok(cfg)
233    }
234}
235
/// Errors `ask()` can return. Includes every failure mode along the
/// path: config / network / API / parsing.
#[derive(Debug, thiserror::Error)]
pub enum AskError {
    /// [`ask_with_schema`] was called with `AskConfig::api_key` unset.
    #[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
    MissingApiKey,

    /// An environment variable held a value [`AskConfig::from_env`]
    /// could not interpret; the payload is the human-readable reason.
    #[error("config error: {0}")]
    Config(String),

    /// The requested provider name is not supported; the payload is
    /// the (lowercased) name that was rejected.
    #[error("unknown provider: {0} (supported: anthropic)")]
    UnknownProvider(String),

    /// Transport-level failure. Not constructed in this file —
    /// presumably raised by the provider adapter; confirm there.
    #[error("HTTP transport error: {0}")]
    Http(String),

    /// The API answered with an error status. Not constructed in this
    /// file — presumably raised by the provider adapter; confirm there.
    #[error("API returned status {status}: {detail}")]
    ApiStatus { status: u16, detail: String },

    /// The API reply carried no text. Not constructed in this file —
    /// presumably raised by the provider adapter; confirm there.
    #[error("API returned no text content")]
    EmptyResponse,

    /// The model's reply could not be parsed as JSON, either directly
    /// or after extracting the first `{...}` block; the payload is the
    /// raw reply text.
    #[error("model output not valid JSON: {0}")]
    OutputNotJson(String),

    /// The reply parsed as JSON but the named field was absent or not
    /// a string.
    #[error("model output JSON missing required field '{0}'")]
    OutputMissingField(&'static str),

    /// Wrapper for `serde_json` errors; `#[from]` lets `?` convert
    /// them automatically.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),
}
267
268/// One-shot natural-language → SQL.
269///
270/// You pass the schema dump as a string (typically produced by the
271/// engine's `sqlrite::ConnectionAskExt` / `dump_schema_for_database`
272/// helper, but any string format the model can read is fine) and the
273/// user's question. Returns the generated SQL plus a one-sentence
274/// rationale plus token usage for cache-hit verification.
275///
276/// The library does **not** execute the returned SQL — that's the
277/// caller's call. See module docs for rationale.
278pub fn ask_with_schema(
279    schema_dump: &str,
280    question: &str,
281    config: &AskConfig,
282) -> Result<AskResponse, AskError> {
283    let api_key = config.api_key.clone().ok_or(AskError::MissingApiKey)?;
284
285    let provider = match config.provider {
286        ProviderKind::Anthropic => match &config.base_url {
287            Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
288            None => AnthropicProvider::new(api_key),
289        },
290    };
291
292    ask_with_schema_and_provider(schema_dump, question, config, &provider)
293}
294
295/// Lower-level entry point — same flow as [`ask_with_schema`], but
296/// you supply the provider directly.
297///
298/// Used by the test suite (which passes a `MockProvider`) and by
299/// advanced callers who want to drive a custom backend (an internal
300/// LLM gateway, a recorded-replay test harness, a non-Anthropic
301/// provider not yet wired into [`ProviderKind`], etc.). This is the
302/// canonical inner function — every other entry point in this module
303/// reduces to this one.
304pub fn ask_with_schema_and_provider<P: Provider>(
305    schema_dump: &str,
306    question: &str,
307    config: &AskConfig,
308    provider: &P,
309) -> Result<AskResponse, AskError> {
310    let system = build_system(schema_dump, config.cache_ttl.into_marker());
311    let messages = [UserMessage::new(question)];
312
313    let req = ProviderRequest {
314        model: &config.model,
315        max_tokens: config.max_tokens,
316        system: &system,
317        messages: &messages,
318    };
319
320    let resp = provider.complete(req)?;
321    parse_response(&resp.text, resp.usage)
322}
323
324/// Pull `sql` and `explanation` out of the model's reply.
325///
326/// We accept three shapes — strict JSON object, JSON wrapped in a
327/// fenced code block, or "almost JSON" with leading/trailing prose —
328/// because real LLM output drifts even with strict instructions. The
329/// fence/prose tolerance matches what real callers do (better-sqlite3,
330/// rusqlite, etc.) when interfacing with model output.
331fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
332    // 1. Strip markdown fences if the model wrapped its JSON.
333    let trimmed = raw.trim();
334    let body = strip_markdown_fence(trimmed).unwrap_or(trimmed);
335
336    // 2. Try strict JSON first.
337    if let Ok(value) = serde_json::from_str::<serde_json::Value>(body) {
338        return extract_fields(&value, usage);
339    }
340
341    // 3. Fallback: extract the first {...} block. Some models tack
342    // prose like "Here is the SQL:" before the JSON despite the
343    // prompt instruction. Find the first balanced object and try
344    // parsing that.
345    if let Some(json_block) = extract_first_json_object(body) {
346        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&json_block) {
347            return extract_fields(&value, usage);
348        }
349    }
350
351    Err(AskError::OutputNotJson(raw.to_string()))
352}
353
354fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
355    let sql = value
356        .get("sql")
357        .and_then(|v| v.as_str())
358        .ok_or(AskError::OutputMissingField("sql"))?
359        .trim()
360        .trim_end_matches(';')
361        .to_string();
362    let explanation = value
363        .get("explanation")
364        .and_then(|v| v.as_str())
365        .unwrap_or("")
366        .to_string();
367    Ok(AskResponse {
368        sql,
369        explanation,
370        usage,
371    })
372}
373
/// Remove a surrounding markdown code fence from `s`, if there is one.
///
/// The input must start (after trimming) with a line consisting of
/// three backticks optionally followed by a `json` language tag. The
/// tag is matched case-insensitively, and whitespace around it —
/// including the `\r` of a CRLF line ending — is ignored. A closing
/// fence at the end is removed when present but is not required.
///
/// Returns the trimmed text between the fences, or `None` when the
/// input does not begin with such an opener (no fence, a different
/// language tag, or no newline after the opener).
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let after_ticks = s.trim().strip_prefix("```")?;
    // Everything up to the first newline is the fence's info string.
    let (info, body) = after_ticks.split_once('\n')?;
    let info = info.trim();
    if !(info.is_empty() || info.eq_ignore_ascii_case("json")) {
        return None;
    }

    // Strip trailing ``` (with or without a final newline).
    let body = body.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
387
/// Return the first balanced `{...}` span in `s` as an owned string.
///
/// Scanning starts at the first `{`. Braces inside double-quoted
/// strings are ignored, and a backslash inside a string escapes the
/// byte that follows it (so `\"` does not end the string). Returns
/// `None` when `s` has no `{` or the braces never balance. The result
/// is not validated as JSON — that is left to the caller.
fn extract_first_json_object(s: &str) -> Option<String> {
    let start = s.find('{')?;
    let candidate = &s[start..];

    let mut depth = 0_i32;
    let mut in_string = false;
    let mut escaped = false;

    for (offset, byte) in candidate.bytes().enumerate() {
        if escaped {
            // This byte is the target of a backslash escape; skip it.
            escaped = false;
            continue;
        }
        if in_string {
            match byte {
                b'\\' => escaped = true,
                b'"' => in_string = false,
                _ => {}
            }
            continue;
        }
        match byte {
            b'"' => in_string = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    // `}` is ASCII, so slicing through it lands on a
                    // char boundary.
                    return Some(candidate[..=offset].to_string());
                }
            }
            _ => {}
        }
    }
    None
}
414
// Unit tests: request construction is checked through `MockProvider`
// (no network), and reply parsing is checked by calling
// `parse_response` directly.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    /// A small fixed schema string. After the v0.1.19 split this
    /// crate doesn't depend on `sqlrite-engine`, so we no longer
    /// open an in-memory DB to introspect — we just hand a literal
    /// schema dump in. (The engine-side helper that produces these
    /// from a `&Database` is tested separately under `sqlrite-engine
    /// ::ask::schema`.)
    const FIXTURE_SCHEMA: &str = "\
CREATE TABLE users (
  id INTEGER PRIMARY KEY,
  name TEXT
);
";

    // Default config plus a dummy API key.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    // Happy path: a strict-JSON reply is decoded into sql + explanation.
    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );
        let resp =
            ask_with_schema_and_provider(FIXTURE_SCHEMA, "how many users?", &cfg(), &provider)
                .unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    // The schema text passed in must show up in the captured request's
    // system blocks (index 1 is the one inspected here).
    #[test]
    fn schema_dump_appears_in_system_block() {
        let schema = "CREATE TABLE widgets (\n  id INTEGER PRIMARY KEY,\n  name TEXT\n);\n";
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(schema, "anything", &cfg(), &provider).unwrap();

        let captured = provider.last_request.borrow().clone().unwrap();
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    // `CacheTtl::Off` must produce a request without a cache marker.
    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    // The default TTL (five minutes, via `cfg()`) must set the marker.
    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    // The question must reach the provider verbatim, quotes included.
    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    // Without a key, `ask_with_schema` must return `MissingApiKey`.
    #[test]
    fn missing_api_key_errors_clearly() {
        // Default has api_key: None already; explicit for the reader.
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask_with_schema(FIXTURE_SCHEMA, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    // A trailing `;` in the model's SQL is removed.
    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // JSON wrapped in a ```json fence is still decoded.
    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // Prose before the JSON object is skipped.
    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    // A reply with no JSON at all yields `OutputNotJson`.
    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    // Valid JSON without a `sql` key yields `OutputMissingField("sql")`.
    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    // `explanation` is optional and defaults to the empty string.
    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    // Token usage handed to the parser comes back on the response.
    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        let resp =
            parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage.clone()).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}