Skip to main content

sqlrite_ask/
lib.rs

1//! `sqlrite-ask` — natural-language → SQL adapter for SQLRite.
2//!
3//! **Phase 7g.2 made this crate pure** — it no longer depends on
4//! `sqlrite-engine`. The canonical API takes a `&str` schema dump
5//! (built however you like — see `sqlrite::ConnectionAskExt::ask`
6//! for the engine-side helper that wraps it):
7//!
8//! ```no_run
9//! use sqlrite_ask::{ask_with_schema, AskConfig};
10//!
11//! let schema = "\
12//! CREATE TABLE users (\n  id INTEGER PRIMARY KEY,\n  name TEXT NOT NULL\n);\n";
13//! let cfg  = AskConfig::from_env()?;          // reads SQLRITE_LLM_API_KEY etc.
14//! let resp = ask_with_schema(schema, "How many users?", &cfg)?;
15//! println!("{}", resp.sql);
16//! # Ok::<(), sqlrite_ask::AskError>(())
17//! ```
18//!
19//! For the engine-integrated form (`conn.ask("...", &cfg)`), enable
20//! the `sqlrite-engine` crate's `ask` feature and bring its
21//! `ConnectionAskExt` trait into scope:
22//!
23//! ```ignore
24//! use sqlrite::{Connection, ConnectionAskExt};
25//! use sqlrite_ask::AskConfig;
26//!
27//! let conn = Connection::open("foo.sqlrite")?;
28//! let cfg  = AskConfig::from_env()?;
29//! let resp = conn.ask("How many users?", &cfg)?;
30//! ```
31//!
32//! ## Why the split (Phase 7g.2 retro)
33//!
34//! Wiring the REPL's `.ask` meta-command would have required the
35//! engine binary to depend on `sqlrite-ask`, but `sqlrite-ask`
36//! already depended on `sqlrite-engine` (for `Connection`,
37//! `Database`, `Table`). Cargo's static cycle detection rejects that
38//! shape even with `optional = true`. Solution: keep this crate pure
39//! over `&str` inputs, move the engine integration (schema dump +
40//! `ConnectionAskExt`) into `sqlrite-engine` itself behind an `ask`
41//! feature. Now there's one direction of dep flow: engine →
42//! sqlrite-ask, never the other way.
43//!
44//! ## What this crate is
45//!
46//! - Provider adapters (Anthropic now; OpenAI / Ollama later) that
47//!   POST one HTTP request to a chat-completion endpoint per `ask()`
48//!   call.
49//! - Prompt construction with a `cache_control: ephemeral`
50//!   breakpoint on the schema block, so repeat calls against the
//!   same schema are served from Anthropic's prompt cache.
52//! - Output parsing tolerant to fenced JSON / leading prose / strict
53//!   JSON (model output drifts even with strict instructions).
54//! - `AskConfig` (env vars + explicit overrides), `AskResponse {
55//!   sql, explanation, usage }`, `AskError`.
56//!
57//! ## What this crate is NOT
58//!
59//! - **Not an executor.** The caller decides whether to run the
//!   generated SQL. SDK convenience wrappers (`Connection.ask_run` in
//!   Python, `db.askRun` in Node, etc.) layer that on top.
62//! - **Not multi-turn.** Stateless — every call is a fresh prompt.
63//! - **Not engine-coupled.** Schema introspection lives on the
64//!   engine side as of v0.1.19 — see `sqlrite::ConnectionAskExt`.
65//!
66//! ## Configuration
67//!
68//! [`AskConfig`] resolves in this priority order:
69//! 1. Explicit values on the struct.
70//! 2. Environment variables (`SQLRITE_LLM_*`).
71//! 3. Built-in defaults (model = `claude-sonnet-4-6`, max_tokens = 1024,
72//!    cache TTL = 5 min).
73
74use std::env;
75
76// `prompt` and parts of `provider` are wasm-safe (pure serde +
77// trait definitions); `provider::anthropic` is HTTP-only and lives
78// behind the `http` feature flag. Phase 7g.7 made these modules
79// `pub` so the WASM SDK can reuse `build_system` / `parse_response`
80// / `Usage` without duplicating them.
81pub mod prompt;
82pub mod provider;
83
84#[cfg(feature = "http")]
85pub use provider::anthropic::AnthropicProvider;
86pub use provider::{Provider, Request, Response, Usage};
87
88use prompt::{CacheControl, UserMessage, build_system};
89use provider::Request as ProviderRequest;
90
/// Model used when neither `AskConfig::model` nor the
/// `SQLRITE_LLM_MODEL` env var says otherwise. Sonnet 4.6 is the
/// cost/quality pick for NL→SQL; `docs/phase-7-plan.md` has the
/// reasoning behind the choice.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";

/// Output-token budget used when `SQLRITE_LLM_MAX_TOKENS` is unset.
/// A single SQL statement plus a one-sentence explanation rarely
/// exceeds ~500 tokens, so 1024 leaves headroom while staying under
/// the SDK timeout cap (no streaming required).
pub const DEFAULT_MAX_TOKENS: u32 = 1_024;
101
102/// Result returned from a successful [`ask`] call.
103///
104/// `sql` is the generated query text — empty string if the model
105/// determined the question can't be answered against the schema.
106/// `explanation` is the model's one-sentence rationale; useful in
107/// REPL "confirm before run" UIs.
108///
109/// `usage` surfaces token counts (input/output/cache hit/cache write).
110/// Inspect it to verify prompt-caching is actually working — see
111/// `docs/phase-7-plan.md` Q3-adjacent for the audit checklist.
112#[derive(Debug, Clone)]
113pub struct AskResponse {
114    pub sql: String,
115    pub explanation: String,
116    pub usage: Usage,
117}
118
/// How long Anthropic should keep the schema prefix in its
/// `ephemeral` prompt cache (the `cache_ttl` field of [`AskConfig`]).
///
/// Cost trade-off between the two TTLs:
/// - [`CacheTtl::FiveMinutes`] writes at a 1.25× premium and pays for
///   itself after 2 calls per cached prefix — a good fit for a REPL
///   session with a handful of questions.
/// - [`CacheTtl::OneHour`] writes at a 2× premium and needs 3+ calls
///   per prefix to break even — better for long-lived editor/desktop
///   sessions that query the same DB now and then over an hour.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheTtl {
    /// Default 5-minute ephemeral cache.
    FiveMinutes,
    /// Extended 1-hour ephemeral cache.
    OneHour,
    /// Send the schema block with no `cache_control` marker at all.
    /// Handy when the schema is smaller than the model's minimum
    /// cacheable prefix (~2K tokens for Sonnet, ~4K for Haiku/Opus),
    /// where a marker would have no effect anyway.
    Off,
}
139
140impl CacheTtl {
141    fn into_marker(self) -> Option<CacheControl> {
142        match self {
143            CacheTtl::FiveMinutes => Some(CacheControl::ephemeral()),
144            CacheTtl::OneHour => Some(CacheControl::ephemeral_1h()),
145            CacheTtl::Off => None,
146        }
147    }
148}
149
150/// Which LLM provider [`ask`] talks to. Anthropic-only in 7g.1; the
151/// enum is here so adding OpenAI/Ollama later doesn't break the
152/// `AskConfig` shape.
153#[derive(Debug, Clone, Copy, PartialEq, Eq)]
154pub enum ProviderKind {
155    Anthropic,
156}
157
158impl ProviderKind {
159    fn parse(s: &str) -> Result<Self, AskError> {
160        match s.to_ascii_lowercase().as_str() {
161            "anthropic" => Ok(ProviderKind::Anthropic),
162            other => Err(AskError::UnknownProvider(other.to_string())),
163        }
164    }
165}
166
167/// Knobs for an `ask()` call. Construct directly, or via
168/// [`AskConfig::from_env`] to pull defaults from the environment.
169#[derive(Debug, Clone)]
170pub struct AskConfig {
171    pub provider: ProviderKind,
172    pub api_key: Option<String>,
173    pub model: String,
174    pub max_tokens: u32,
175    pub cache_ttl: CacheTtl,
176    /// Override the API base URL. Production callers leave this
177    /// `None`; tests point it at a localhost mock.
178    pub base_url: Option<String>,
179}
180
181impl Default for AskConfig {
182    fn default() -> Self {
183        Self {
184            provider: ProviderKind::Anthropic,
185            api_key: None,
186            model: DEFAULT_MODEL.to_string(),
187            max_tokens: DEFAULT_MAX_TOKENS,
188            cache_ttl: CacheTtl::FiveMinutes,
189            base_url: None,
190        }
191    }
192}
193
194impl AskConfig {
195    /// Build a config from environment variables, with built-in
196    /// defaults for anything not set.
197    ///
198    /// Recognized vars:
199    /// - `SQLRITE_LLM_PROVIDER` — `anthropic` (only currently supported)
200    /// - `SQLRITE_LLM_API_KEY` — required at call time, but a missing
201    ///   var is not an error here (lets you build a config to inspect
202    ///   without the secret loaded)
203    /// - `SQLRITE_LLM_MODEL` — overrides [`DEFAULT_MODEL`]
204    /// - `SQLRITE_LLM_MAX_TOKENS` — overrides [`DEFAULT_MAX_TOKENS`]
205    /// - `SQLRITE_LLM_CACHE_TTL` — `5m` (default) | `1h` | `off`
206    pub fn from_env() -> Result<Self, AskError> {
207        let mut cfg = AskConfig::default();
208        if let Ok(p) = env::var("SQLRITE_LLM_PROVIDER") {
209            cfg.provider = ProviderKind::parse(&p)?;
210        }
211        if let Ok(k) = env::var("SQLRITE_LLM_API_KEY") {
212            if !k.is_empty() {
213                cfg.api_key = Some(k);
214            }
215        }
216        if let Ok(m) = env::var("SQLRITE_LLM_MODEL") {
217            if !m.is_empty() {
218                cfg.model = m;
219            }
220        }
221        if let Ok(t) = env::var("SQLRITE_LLM_MAX_TOKENS") {
222            cfg.max_tokens = t
223                .parse()
224                .map_err(|_| AskError::Config(format!("SQLRITE_LLM_MAX_TOKENS not a u32: {t}")))?;
225        }
226        if let Ok(c) = env::var("SQLRITE_LLM_CACHE_TTL") {
227            cfg.cache_ttl = match c.to_ascii_lowercase().as_str() {
228                "5m" | "5min" | "5minutes" => CacheTtl::FiveMinutes,
229                "1h" | "1hr" | "1hour" => CacheTtl::OneHour,
230                "off" | "none" | "disabled" => CacheTtl::Off,
231                other => {
232                    return Err(AskError::Config(format!(
233                        "SQLRITE_LLM_CACHE_TTL: unknown value '{other}'"
234                    )));
235                }
236            };
237        }
238        Ok(cfg)
239    }
240}
241
242/// Errors `ask()` can return. Includes every failure mode along the
243/// path: config / network / API / parsing.
244#[derive(Debug, thiserror::Error)]
245pub enum AskError {
246    #[error("missing API key (set SQLRITE_LLM_API_KEY or AskConfig.api_key)")]
247    MissingApiKey,
248
249    #[error("config error: {0}")]
250    Config(String),
251
252    #[error("unknown provider: {0} (supported: anthropic)")]
253    UnknownProvider(String),
254
255    #[error("HTTP transport error: {0}")]
256    Http(String),
257
258    #[error("API returned status {status}: {detail}")]
259    ApiStatus { status: u16, detail: String },
260
261    #[error("API returned no text content")]
262    EmptyResponse,
263
264    #[error("model output not valid JSON: {0}")]
265    OutputNotJson(String),
266
267    #[error("model output JSON missing required field '{0}'")]
268    OutputMissingField(&'static str),
269
270    #[error("JSON serialization error: {0}")]
271    Json(#[from] serde_json::Error),
272}
273
/// One-shot natural-language → SQL.
///
/// You pass the schema dump as a string (typically produced by the
/// engine's `sqlrite::ConnectionAskExt` / `dump_schema_for_database`
/// helper, but any string format the model can read is fine) and the
/// user's question. Returns the generated SQL plus a one-sentence
/// rationale plus token usage for cache-hit verification.
///
/// The library does **not** execute the returned SQL — that's the
/// caller's call. See module docs for rationale.
///
/// **Feature-gated under `http`** (default-on) — wraps the built-in
/// `AnthropicProvider`, which uses ureq and isn't wasm-safe. WASM
/// callers should use [`ask_with_schema_and_provider`] with a
/// caller-supplied provider, or skip this crate entirely and use
/// the WASM SDK's `db.askPrompt()` / `db.askParse()` shape (Q9).
///
/// # Errors
///
/// - [`AskError::MissingApiKey`] if `config.api_key` is `None`, empty,
///   or whitespace-only. These are rejected locally rather than sent
///   to the API as a blank credential.
/// - Any error from the provider or from [`parse_response`].
#[cfg(feature = "http")]
pub fn ask_with_schema(
    schema_dump: &str,
    question: &str,
    config: &AskConfig,
) -> Result<AskResponse, AskError> {
    let api_key = config
        .api_key
        .as_deref()
        .filter(|key| !key.trim().is_empty())
        .ok_or(AskError::MissingApiKey)?
        .to_string();

    let provider = match config.provider {
        ProviderKind::Anthropic => match &config.base_url {
            Some(url) => AnthropicProvider::with_base_url(api_key, url.clone()),
            None => AnthropicProvider::new(api_key),
        },
    };

    ask_with_schema_and_provider(schema_dump, question, config, &provider)
}
307
308/// Lower-level entry point — same flow as [`ask_with_schema`], but
309/// you supply the provider directly.
310///
311/// Used by the test suite (which passes a `MockProvider`) and by
312/// advanced callers who want to drive a custom backend (an internal
313/// LLM gateway, a recorded-replay test harness, a non-Anthropic
314/// provider not yet wired into [`ProviderKind`], etc.). This is the
315/// canonical inner function — every other entry point in this module
316/// reduces to this one.
317pub fn ask_with_schema_and_provider<P: Provider>(
318    schema_dump: &str,
319    question: &str,
320    config: &AskConfig,
321    provider: &P,
322) -> Result<AskResponse, AskError> {
323    let system = build_system(schema_dump, config.cache_ttl.into_marker());
324    let messages = [UserMessage::new(question)];
325
326    let req = ProviderRequest {
327        model: &config.model,
328        max_tokens: config.max_tokens,
329        system: &system,
330        messages: &messages,
331    };
332
333    let resp = provider.complete(req)?;
334    parse_response(&resp.text, resp.usage)
335}
336
337/// Pull `sql` and `explanation` out of the model's reply.
338///
339/// We accept three shapes — strict JSON object, JSON wrapped in a
340/// fenced code block, or "almost JSON" with leading/trailing prose —
341/// because real LLM output drifts even with strict instructions. The
342/// fence/prose tolerance matches what real callers do (better-sqlite3,
343/// rusqlite, etc.) when interfacing with model output.
344///
345/// **Public as of Phase 7g.7** so the WASM SDK can call this on the
346/// model-text portion of an LLM API response that JS retrieved (per
347/// Q9 the WASM module never makes the HTTP call itself; the JS
348/// caller hands the raw response back through `db.askParse()`).
349pub fn parse_response(raw: &str, usage: Usage) -> Result<AskResponse, AskError> {
350    // 1. Strip markdown fences if the model wrapped its JSON.
351    let trimmed = raw.trim();
352    let body = strip_markdown_fence(trimmed).unwrap_or(trimmed);
353
354    // 2. Try strict JSON first.
355    if let Ok(value) = serde_json::from_str::<serde_json::Value>(body) {
356        return extract_fields(&value, usage);
357    }
358
359    // 3. Fallback: extract the first {...} block. Some models tack
360    // prose like "Here is the SQL:" before the JSON despite the
361    // prompt instruction. Find the first balanced object and try
362    // parsing that.
363    if let Some(json_block) = extract_first_json_object(body) {
364        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&json_block) {
365            return extract_fields(&value, usage);
366        }
367    }
368
369    Err(AskError::OutputNotJson(raw.to_string()))
370}
371
372fn extract_fields(value: &serde_json::Value, usage: Usage) -> Result<AskResponse, AskError> {
373    let sql = value
374        .get("sql")
375        .and_then(|v| v.as_str())
376        .ok_or(AskError::OutputMissingField("sql"))?
377        .trim()
378        .trim_end_matches(';')
379        .to_string();
380    let explanation = value
381        .get("explanation")
382        .and_then(|v| v.as_str())
383        .unwrap_or("")
384        .to_string();
385    Ok(AskResponse {
386        sql,
387        explanation,
388        usage,
389    })
390}
391
/// If `s` is wrapped in a Markdown code fence, return the text between
/// the fences. The fence's info string must be empty or `json` (any
/// ASCII case).
///
/// CRLF line endings and a missing closing fence are tolerated.
/// Returns `None` in two cases, and the caller then uses `s` as-is:
/// - `s` does not start with a fence;
/// - the fence is tagged with some other language.
fn strip_markdown_fence(s: &str) -> Option<&str> {
    let s = s.trim();
    let after_ticks = s.strip_prefix("```")?;
    // The opening line must be terminated; the body starts after it.
    let (info, rest) = after_ticks.split_once('\n')?;
    // `trim` also drops the `\r` left over from a CRLF opening line.
    let info = info.trim();
    if !info.is_empty() && !info.eq_ignore_ascii_case("json") {
        return None;
    }
    // Strip trailing ``` (with or without a final newline).
    let body = rest.trim_end();
    let body = body.strip_suffix("```").unwrap_or(body);
    Some(body.trim())
}
405
/// Return the first brace-balanced `{...}` span in `s`, or `None` if
/// there is no `{` or the first object is never closed.
///
/// Braces inside JSON string literals are ignored, as are `\`-escaped
/// characters within those literals. Scanning is byte-wise, which is
/// safe for UTF-8: `{`, `}`, `"` and `\` are ASCII and never occur
/// inside a multi-byte sequence, so the returned slice ends on a char
/// boundary.
fn extract_first_json_object(s: &str) -> Option<String> {
    let open = s.find('{')?;
    let tail = &s[open..];
    let mut depth: i32 = 0;
    let mut inside_str = false;
    let mut escaped = false;

    for (offset, byte) in tail.bytes().enumerate() {
        if escaped {
            escaped = false;
            continue;
        }
        if inside_str {
            match byte {
                b'\\' => escaped = true,
                b'"' => inside_str = false,
                _ => {}
            }
            continue;
        }
        match byte {
            b'"' => inside_str = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    return Some(tail[..=offset].to_string());
                }
            }
            _ => {}
        }
    }
    None
}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    /// A small fixed schema string. After the v0.1.19 split this
    /// crate doesn't depend on `sqlrite-engine`, so we no longer
    /// open an in-memory DB to introspect — we just hand a literal
    /// schema dump in. (The engine-side helper that produces these
    /// from a `&Database` is tested separately under `sqlrite-engine
    /// ::ask::schema`.)
    const FIXTURE_SCHEMA: &str = "\
CREATE TABLE users (
  id INTEGER PRIMARY KEY,
  name TEXT
);
";

    /// Default config with a dummy key so provider-driven tests never
    /// trip `MissingApiKey`.
    fn cfg() -> AskConfig {
        AskConfig {
            api_key: Some("test-key".to_string()),
            ..AskConfig::default()
        }
    }

    #[test]
    fn ask_with_mock_provider_returns_parsed_sql() {
        let provider = MockProvider::new(
            r#"{"sql": "SELECT COUNT(*) FROM users", "explanation": "counts users"}"#,
        );
        let resp =
            ask_with_schema_and_provider(FIXTURE_SCHEMA, "how many users?", &cfg(), &provider)
                .unwrap();
        assert_eq!(resp.sql, "SELECT COUNT(*) FROM users");
        assert_eq!(resp.explanation, "counts users");
    }

    #[test]
    fn schema_dump_appears_in_system_block() {
        let schema = "CREATE TABLE widgets (\n  id INTEGER PRIMARY KEY,\n  name TEXT\n);\n";
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(schema, "anything", &cfg(), &provider).unwrap();

        let captured = provider.last_request.borrow().clone().unwrap();
        let schema_block = &captured.system_blocks[1];
        assert!(
            schema_block.contains("CREATE TABLE widgets"),
            "got: {schema_block}"
        );
        assert!(schema_block.contains("name TEXT"), "got: {schema_block}");
    }

    #[test]
    fn cache_ttl_off_omits_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let mut config = cfg();
        config.cache_ttl = CacheTtl::Off;
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &config, &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(!captured.schema_block_has_cache_control);
    }

    #[test]
    fn cache_ttl_5m_sets_cache_control() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, "test", &cfg(), &provider).unwrap();
        let captured = provider.last_request.borrow().clone().unwrap();
        assert!(captured.schema_block_has_cache_control);
    }

    #[test]
    fn user_question_arrives_in_messages_unchanged() {
        let provider = MockProvider::new(r#"{"sql": "", "explanation": ""}"#);
        let q = "Find users with email containing '@example.com'";
        let _ = ask_with_schema_and_provider(FIXTURE_SCHEMA, q, &cfg(), &provider).unwrap();
        assert_eq!(
            provider
                .last_request
                .borrow()
                .as_ref()
                .unwrap()
                .user_message,
            q
        );
    }

    // `ask_with_schema` only exists with the `http` feature. Gate the
    // test the same way so `cargo test --no-default-features` (the
    // wasm-oriented build) still compiles.
    #[cfg(feature = "http")]
    #[test]
    fn missing_api_key_errors_clearly() {
        // Default has api_key: None already; explicit for the reader.
        let config = AskConfig {
            api_key: None,
            ..AskConfig::default()
        };
        let err = ask_with_schema(FIXTURE_SCHEMA, "test", &config).unwrap_err();
        match err {
            AskError::MissingApiKey => {}
            other => panic!("expected MissingApiKey, got {other:?}"),
        }
    }

    #[test]
    fn parse_response_strips_trailing_semicolon() {
        let resp = parse_response(
            r#"{"sql": "SELECT 1;", "explanation": "demo"}"#,
            Usage::default(),
        )
        .unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_handles_markdown_fence() {
        let raw = "```json\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}\n```";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_handles_leading_prose() {
        let raw =
            "Here is the query you asked for:\n{\"sql\": \"SELECT 1\", \"explanation\": \"x\"}";
        let resp = parse_response(raw, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
    }

    #[test]
    fn parse_response_rejects_non_json() {
        let err = parse_response("just some prose, no JSON here", Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputNotJson(_)));
    }

    #[test]
    fn parse_response_rejects_missing_sql_field() {
        let err = parse_response(r#"{"explanation": "no sql key"}"#, Usage::default()).unwrap_err();
        assert!(matches!(err, AskError::OutputMissingField("sql")));
    }

    #[test]
    fn parse_response_allows_missing_explanation() {
        let resp = parse_response(r#"{"sql": "SELECT 1"}"#, Usage::default()).unwrap();
        assert_eq!(resp.sql, "SELECT 1");
        assert_eq!(resp.explanation, "");
    }

    #[test]
    fn parse_response_passes_usage_through() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 80,
            cache_read_input_tokens: 0,
        };
        let resp =
            parse_response(r#"{"sql": "SELECT 1", "explanation": ""}"#, usage.clone()).unwrap();
        assert_eq!(resp.usage.input_tokens, 100);
        assert_eq!(resp.usage.cache_creation_input_tokens, 80);
    }
}