Skip to main content

samkhya_core/
residual.rs

1//! Residual correction model.
2//!
3//! Optional feedback-driven correction layer. Takes a baseline cardinality
4//! estimate plus a feature vector (query plan + column stats) and returns
5//! a corrected estimate. Trained on observations recorded by
6//! [`crate::feedback`].
7//!
8//! Contracts every backend honors:
9//!
10//! - bounded — output never exceeds the LpBound ceiling ([`crate::lpbound`]); the corrector clamps.
11//! - sub-MB / sub-ms — model footprint and per-estimate latency are the architectural budget.
12//! - optional — engines opt in; with no model attached, samkhya behaves as portable stats + envelope.
13//!
14//! Concrete backends (all behind feature flags):
15//!
16//! - `gbt` — gradient-boosted trees (the `gbt` submodule, gated on the `gbt` cargo feature)
17//! - `additive_gbt` — additive gradient-boosted trees (the `additive` submodule, gated on `additive_gbt`)
18//! - `tabpfn` — foundation-model interface (see the `tabpfn` submodule,
19//!   gated on the `tabpfn_http` cargo feature)
20//! - `llm` — LLM-pluggable corrector backend (see the `llm` submodule,
21//!   gated on the `llm_http` cargo feature). Same wire contract as
22//!   `tabpfn`; the server picks an Anthropic / OpenAI / local-Ollama /
23//!   dummy provider via the `SAMKHYA_LLM_BACKEND` env var.
24//!
25//! # Foundation-model interface (Layer 5)
26//!
27//! The architecture reserves a pluggable backend slot for foundation tabular
28//! models such as TabPFN-2.5 (arXiv 2511.08667). The contract is identical
29//! to every other backend:
30//!
31//! > *feed [`CorrectionFeatures`], receive `Option<u64>` clamped to the
32//! > LpBound ceiling.*
33//!
34//! Two deployment shapes are scaffolded:
35//!
36//! 1. **localhost HTTP** — a Python TabPFN inference server runs out of
37//!    band; samkhya POSTs the feature vector as JSON and reads back an
38//!    `{"estimate": <u64>}` response. Implemented today behind the
39//!    `tabpfn_http` cargo feature (see `tabpfn::TabPfnHttpCorrector`).
40//! 2. **subprocess** — samkhya spawns a Python child, frames JSON over
41//!    stdin/stdout, and keeps the process warm across estimates. Deferred:
42//!    the scaffolding is present (umbrella `tabpfn` feature), the
43//!    transport itself is not implemented in this revision.
44//!
45//! A no-op [`TabPfnStub`] is **always** compiled in, regardless of
46//! features. Its job is purely architectural: downstream code can reference
47//! `TabPfnStub` to mark "TabPFN integration point, currently disabled"
48//! without taking the `tabpfn_http` feature dependency. This reflects the
49//! integration point in every build, so the contract is visible even when
50//! the transport is not.
51//!
52//! Failure policy across all TabPFN backends: any transport error, parse
53//! error, or timeout returns `Ok(None)` (never `Err`). The engine then
54//! falls back cleanly to the native estimate. This is the safety contract;
55//! a remote inference server going down must never surface as a query
56//! failure.
57
58use crate::Result;
59
/// Emit a `log::warn!` when an HTTP corrector backend is configured with a
/// plaintext `http://` URL that points at a non-loopback host. See
/// `documents/SECURITY-REVIEW-2026-05-17.md` (H2): such a URL means
/// features and any embedded baseline estimate travel unencrypted on the
/// wire. The warning is fire-and-forget — no behaviour change, so
/// well-configured operators (the typical case, defaults are localhost)
/// see nothing.
///
/// The check runs on every call (there is no once-guard), so each call
/// with an offending URL logs one warning.
///
/// Host extraction follows the RFC 3986 authority layout
/// `[userinfo@]host[:port]`:
///
/// - the authority ends at the first `/`, `?` or `#`;
/// - a `userinfo@` prefix is skipped, so `http://127.0.0.1@example.com/`
///   is judged by `example.com`, not by the user name;
/// - bracketed IPv6 literals are kept whole, so `http://[::1]:8765/` is
///   recognised as loopback instead of being cut at the first `:`.
///
/// A host counts as loopback when it is `localhost` (case-insensitive),
/// an IP literal for which [`std::net::IpAddr::is_loopback`] holds, or a
/// purely dotted-numeric name starting with `127.` (shorthand such as
/// `127.1`). Setting `SAMKHYA_ALLOW_REMOTE_HTTP=1` suppresses the warning.
#[cfg(any(feature = "tabpfn_http", feature = "llm_http"))]
fn warn_if_remote_plaintext_http(url: &str, backend: &'static str) {
    const SCHEME: &str = "http://";

    // `get` (rather than slicing) keeps short or non-ASCII input from
    // panicking; the scheme comparison is case-insensitive.
    let has_plaintext_scheme = url
        .get(..SCHEME.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(SCHEME));
    if !has_plaintext_scheme {
        return;
    }
    let rest = &url[SCHEME.len()..]; // safe: `get` above proved the boundary

    // Authority = everything up to the path, query or fragment.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // Drop `userinfo@`; the last `@` is the separator.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, tail)| tail);

    let host = if host_port.starts_with('[') {
        // Bracketed IPv6 literal: the colons inside are part of the host,
        // so keep everything through the closing bracket.
        host_port
            .find(']')
            .map_or(host_port, |close| &host_port[..=close])
    } else {
        // Otherwise the first `:` introduces the port.
        host_port
            .split_once(':')
            .map_or(host_port, |(name, _port)| name)
    };

    // Compare IP literals without their brackets.
    let bare = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let is_loopback = bare.eq_ignore_ascii_case("localhost")
        || bare
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback())
        // Numeric shorthand such as `127.1`; requiring digits and dots
        // only keeps DNS names like `127.evil.example` out.
        || (bare.starts_with("127.") && bare.bytes().all(|b| b.is_ascii_digit() || b == b'.'));
    if is_loopback {
        return;
    }
    if std::env::var("SAMKHYA_ALLOW_REMOTE_HTTP").as_deref() == Ok("1") {
        return;
    }
    log::warn!(
        "samkhya {backend} corrector configured with plaintext HTTP to non-loopback host {host}; \
         features and baseline_estimate will travel unencrypted. Use https:// or set \
         SAMKHYA_ALLOW_REMOTE_HTTP=1 to silence this warning."
    );
}
94
/// Inputs the corrector sees for a single estimate.
///
/// Deliberately small at v0.0.1: the baseline estimate, per-side input
/// row and distinct counts, and two operator-level counters. The set is
/// expected to grow as the feedback-collection surface widens.
///
/// # Examples
///
/// ```
/// use samkhya_core::residual::CorrectionFeatures;
///
/// let features = CorrectionFeatures {
///     baseline_estimate: 1000,
///     left_input_rows: Some(500),
///     right_input_rows: Some(2000),
///     predicate_count: 2,
///     join_depth: 1,
///     ..Default::default()
/// };
/// assert_eq!(features.to_vec().len(), CorrectionFeatures::FEATURE_LEN);
/// ```
#[derive(Debug, Clone, Default)]
pub struct CorrectionFeatures {
    /// Baseline cardinality estimate the corrector is asked to adjust.
    pub baseline_estimate: u64,
    /// Row count of the left input; `None` when unknown.
    pub left_input_rows: Option<u64>,
    /// Row count of the right input; `None` when unknown.
    pub right_input_rows: Option<u64>,
    /// Distinct count on the left side; `None` when unknown.
    pub left_distinct: Option<u64>,
    /// Distinct count on the right side; `None` when unknown.
    pub right_distinct: Option<u64>,
    /// Predicate counter (presumably for the operator being estimated —
    /// confirm against the caller that fills it in).
    pub predicate_count: u32,
    /// Join-depth counter (presumably the operator's depth in the join
    /// tree — confirm against the caller that fills it in).
    pub join_depth: u32,
}

impl CorrectionFeatures {
    /// Number of entries [`to_vec`](Self::to_vec) produces.
    pub const FEATURE_LEN: usize = 7;

    /// Produce the fixed-length numeric form of the features that a
    /// regression model consumes. Optional counts that are `None` become
    /// `0.0`; zero therefore means "unknown" rather than "literally zero
    /// rows", which is the convention the corrector is trained against.
    ///
    /// Layout (stable; new features must be appended, never reordered):
    ///
    /// 0. `baseline_estimate`
    /// 1. `left_input_rows`  (0 if `None`)
    /// 2. `right_input_rows` (0 if `None`)
    /// 3. `left_distinct`    (0 if `None`)
    /// 4. `right_distinct`   (0 if `None`)
    /// 5. `predicate_count`
    /// 6. `join_depth`
    ///
    /// # Examples
    ///
    /// ```
    /// use samkhya_core::residual::CorrectionFeatures;
    ///
    /// let f = CorrectionFeatures {
    ///     baseline_estimate: 100,
    ///     left_input_rows: Some(10),
    ///     predicate_count: 3,
    ///     ..Default::default()
    /// };
    /// let v = f.to_vec();
    /// assert_eq!(v[0], 100.0);
    /// assert_eq!(v[1], 10.0);
    /// assert_eq!(v[2], 0.0); // None → 0
    /// assert_eq!(v[5], 3.0);
    /// ```
    pub fn to_vec(&self) -> Vec<f64> {
        // Absent optional counts are encoded as zero.
        let count_or_zero = |slot: Option<u64>| slot.map_or(0.0, |n| n as f64);

        // Typing the array with `FEATURE_LEN` makes the compiler reject a
        // layout whose length drifts away from the advertised constant.
        let flat: [f64; Self::FEATURE_LEN] = [
            self.baseline_estimate as f64,
            count_or_zero(self.left_input_rows),
            count_or_zero(self.right_input_rows),
            count_or_zero(self.left_distinct),
            count_or_zero(self.right_distinct),
            f64::from(self.predicate_count),
            f64::from(self.join_depth),
        ];
        flat.to_vec()
    }
}
175
176/// A pluggable corrector. Engines call [`Corrector::correct`] on every
177/// estimate that passes through samkhya's optimizer hook.
178///
179/// Returning `Ok(None)` lets the engine fall back to the baseline estimate;
180/// returning `Ok(Some(_))` overrides it (subject to the LpBound envelope).
181///
182/// # Examples
183///
184/// ```
185/// use samkhya_core::residual::{CorrectionFeatures, Corrector, IdentityCorrector};
186///
187/// let corrector = IdentityCorrector;
188/// let features = CorrectionFeatures {
189///     baseline_estimate: 42,
190///     ..Default::default()
191/// };
192/// // The identity corrector passes the baseline through unchanged.
193/// assert_eq!(corrector.correct(&features).unwrap(), Some(42));
194/// ```
195pub trait Corrector: Send + Sync {
196    /// Return a corrected estimate, or `None` to fall back to the baseline.
197    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>>;
198
199    /// Stable identifier for logging / model-version tracking.
200    fn name(&self) -> &'static str;
201}
202
203/// Default zero-cost corrector: passes the baseline through unchanged.
204///
205/// Used when no feedback history exists yet (cold start) or when the
206/// caller opts out of feedback-driven correction entirely.
207///
208/// # Examples
209///
210/// ```
211/// use samkhya_core::residual::{CorrectionFeatures, Corrector, IdentityCorrector};
212///
213/// let c = IdentityCorrector;
214/// let f = CorrectionFeatures { baseline_estimate: 1234, ..Default::default() };
215/// assert_eq!(c.correct(&f).unwrap(), Some(1234));
216/// assert_eq!(c.name(), "identity");
217/// ```
218pub struct IdentityCorrector;
219
220impl Corrector for IdentityCorrector {
221    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
222        Ok(Some(features.baseline_estimate))
223    }
224
225    fn name(&self) -> &'static str {
226        "identity"
227    }
228}
229
230/// No-op stub for the foundation-model interface (Layer 5).
231///
232/// Always compiled, regardless of cargo features. Returns `Ok(None)` from
233/// every call, signalling the engine to fall back to its native estimate.
234///
235/// The point of an always-on stub is architectural: it lets downstream
236/// callers reference `TabPfnStub` to mark "TabPFN integration point,
237/// currently disabled" without taking the `tabpfn_http` feature
238/// dependency. The integration shape is visible in every build.
239///
240/// To wire in a real foundation-model backend, swap this for
241/// `tabpfn::TabPfnHttpCorrector` (gated on `tabpfn_http`) or a future
242/// subprocess adapter. The trait contract is identical, so the swap is
243/// a one-line change at the call site.
244///
245/// # Examples
246///
247/// ```
248/// use samkhya_core::residual::{CorrectionFeatures, Corrector, TabPfnStub};
249///
250/// let stub = TabPfnStub;
251/// // Stub always returns Ok(None) — the engine falls back to its native estimate.
252/// let f = CorrectionFeatures { baseline_estimate: 999, ..Default::default() };
253/// assert_eq!(stub.correct(&f).unwrap(), None);
254/// assert_eq!(stub.name(), "tabpfn-stub");
255/// ```
256pub struct TabPfnStub;
257
258impl Corrector for TabPfnStub {
259    fn correct(&self, _features: &CorrectionFeatures) -> Result<Option<u64>> {
260        // Deliberately `None`: the integration point is wired but
261        // disabled. The engine falls back to the native estimate.
262        Ok(None)
263    }
264
265    fn name(&self) -> &'static str {
266        "tabpfn-stub"
267    }
268}
269
270#[cfg(feature = "gbt")]
271pub mod gbt {
272    //! Gradient-boosted-tree residual corrector.
273    //!
274    //! Wraps the `gbdt` crate (Baidu / mesalock-linux,
275    //! <https://github.com/mesalock-linux/gbdt-rs>) — pure-Rust, no native
276    //! deps, builds on stable Rust 1.94 / edition 2024. Compiled in only
277    //! when the `gbt` cargo feature is enabled.
278    //!
279    //! Training target is `log(actual_rows / est_rows)` — the
280    //! multiplicative correction ratio in log-space. At prediction time
281    //! we exponentiate and multiply through the baseline, then clamp to
282    //! the configured LpBound ceiling via
283    //! [`crate::lpbound::saturating_clamp`] so the corrector cannot ever
284    //! violate the envelope contract.
285    //!
286    //! Observations with `est_rows == 0` or `actual_rows == 0` are
287    //! silently dropped (log of zero is undefined); we do not invent a
288    //! Laplace-style smoothing constant at the corrector layer.
289
290    use gbdt::config::{Config, Loss};
291    use gbdt::decision_tree::{Data, DataVec};
292    use gbdt::gradient_boost::GBDT;
293
294    use super::{CorrectionFeatures, Corrector};
295    use crate::feedback::Observation;
296    use crate::lpbound::saturating_clamp;
297    use crate::{Error, Result};
298
299    /// Tunables for [`GbtCorrector::train`]. Defaults are an MVP starting
300    /// point: shallow trees, modest depth, square-error loss.
301    #[derive(Debug, Clone)]
302    pub struct GbtOptions {
303        /// Shrinkage / learning rate applied to each tree's contribution.
304        pub learning_rate: f64,
305        /// Max depth of each regression tree. Root is depth 0.
306        pub max_depth: u32,
307        /// Number of boosting iterations (one tree per iteration).
308        pub num_trees: u32,
309        /// Inclusive upper bound applied to every corrected estimate.
310        /// Use `u64::MAX` to disable (the trait signature has no ceiling
311        /// slot, so we store it here at train time).
312        pub ceiling: u64,
313        /// Minimum samples per leaf — guards against overfitting tiny
314        /// feedback histories.
315        pub min_leaf_size: usize,
316    }
317
318    impl Default for GbtOptions {
319        fn default() -> Self {
320            Self {
321                learning_rate: 0.1,
322                max_depth: 4,
323                num_trees: 50,
324                ceiling: u64::MAX,
325                min_leaf_size: 1,
326            }
327        }
328    }
329
330    /// Trained GBT-backed residual corrector.
331    pub struct GbtCorrector {
332        model: GBDT,
333        ceiling: u64,
334    }
335
336    impl GbtCorrector {
337        /// Train a corrector from a slice of [`Observation`]s.
338        ///
339        /// Returns [`Error::Feedback`] if the observation slice is empty,
340        /// or if every observation is unusable (zero est_rows or zero
341        /// actual_rows). Non-positive-ratio observations are silently
342        /// filtered, matching the convention in [`Observation::q_error`].
343        pub fn train(observations: &[Observation], options: GbtOptions) -> Result<Self> {
344            if observations.is_empty() {
345                return Err(Error::Feedback(
346                    "cannot train GbtCorrector: observation slice is empty".into(),
347                ));
348            }
349
350            let mut training: DataVec = Vec::with_capacity(observations.len());
351            for obs in observations {
352                if obs.est_rows == 0 || obs.actual_rows == 0 {
353                    continue;
354                }
355                // Reconstruct a feature vector from the observation. The
356                // feedback table doesn't yet carry full plan features, so
357                // we synthesize the minimal `baseline_estimate`-only
358                // vector. As `Observation` gains columns the mapping
359                // below should grow in lockstep with `CorrectionFeatures`.
360                let features = CorrectionFeatures {
361                    baseline_estimate: obs.est_rows,
362                    ..Default::default()
363                };
364                let feature_f32: Vec<f32> =
365                    features.to_vec().into_iter().map(|v| v as f32).collect();
366                let ratio_log = (obs.actual_rows as f64 / obs.est_rows as f64).ln() as f32;
367                training.push(Data::new_training_data(feature_f32, 1.0, ratio_log, None));
368            }
369
370            if training.is_empty() {
371                return Err(Error::Feedback(
372                    "cannot train GbtCorrector: all observations had zero est or actual rows"
373                        .into(),
374                ));
375            }
376
377            let mut cfg = Config::new();
378            cfg.set_feature_size(CorrectionFeatures::FEATURE_LEN);
379            cfg.set_max_depth(options.max_depth);
380            cfg.set_iterations(options.num_trees as usize);
381            cfg.set_shrinkage(options.learning_rate as f32);
382            cfg.set_min_leaf_size(options.min_leaf_size);
383            cfg.set_loss(&loss_name(Loss::SquaredError));
384
385            let mut model = GBDT::new(&cfg);
386            model.fit(&mut training);
387
388            Ok(Self {
389                model,
390                ceiling: options.ceiling,
391            })
392        }
393
394        /// Predict the log-ratio correction for a single feature vector.
395        /// Exposed for diagnostics / unit tests; the production path is
396        /// [`Corrector::correct`].
397        pub fn predict_log_ratio(&self, features: &CorrectionFeatures) -> f64 {
398            let feature_f32: Vec<f32> = features.to_vec().into_iter().map(|v| v as f32).collect();
399            let probe: DataVec = vec![Data::new_test_data(feature_f32, None)];
400            let preds = self.model.predict(&probe);
401            preds.first().copied().unwrap_or(0.0) as f64
402        }
403
404        /// Configured upper bound. Set at training time; the trait method
405        /// [`Corrector::correct`] enforces it via `saturating_clamp`.
406        pub fn ceiling(&self) -> u64 {
407            self.ceiling
408        }
409    }
410
411    impl Corrector for GbtCorrector {
412        fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
413            let log_ratio = self.predict_log_ratio(features);
414            let ratio = log_ratio.exp();
415            let scaled = features.baseline_estimate as f64 * ratio;
416            Ok(Some(saturating_clamp(scaled, self.ceiling)))
417        }
418
419        fn name(&self) -> &'static str {
420            "gbt"
421        }
422    }
423
424    /// `gbdt::config::Config::set_loss` takes a string; this is the
425    /// canonical spelling for square-error in that crate.
426    fn loss_name(loss: Loss) -> String {
427        gbdt::config::loss2string(&loss)
428    }
429}
430
431#[cfg(feature = "additive_gbt")]
432pub mod additive {
433    //! Additive gradient-boosted-tree residual corrector.
434    //!
435    //! Sibling backend to [`super::gbt`]. The multiplicative form trains on
436    //! `log(actual / baseline_estimate)` and applies the correction as
437    //! `baseline * exp(predicted)`. That model is structurally trapped at
438    //! zero whenever the engine hands us `baseline_estimate = 0` — the
439    //! q=∞ regime where the upstream estimator has completely collapsed
440    //! (a common DataFusion 46 symptom on chained joins).
441    //!
    //! The additive backend sidesteps that trap by training the model to
    //! predict the **absolute** `actual_rows` instead of a ratio. The
    //! model is sized for the full [`CorrectionFeatures`] vector (all 7
    //! slots), although training currently populates only
    //! `baseline_estimate` because the feedback table does not yet carry
    //! the remaining plan features. The prediction is clamped to a
    //! non-negative value and then to the configured LpBound ceiling via
    //! [`crate::lpbound::saturating_clamp`], so the envelope contract is
    //! preserved.
449    //!
450    //! Cargo feature: `additive_gbt`. Independent of the `gbt` feature —
451    //! they can be enabled separately or together.
452
453    use gbdt::config::{Config, Loss};
454    use gbdt::decision_tree::{Data, DataVec};
455    use gbdt::gradient_boost::GBDT;
456    use std::sync::Mutex;
457
458    use super::{CorrectionFeatures, Corrector};
459    use crate::feedback::Observation;
460    use crate::lpbound::saturating_clamp;
461    use crate::{Error, Result};
462
463    /// Tunables for [`AdditiveGbtCorrector::train`]. Defaults mirror
464    /// [`super::gbt::GbtOptions`] so the two backends are
465    /// drop-in-comparable when benchmarking.
466    #[derive(Debug, Clone)]
467    pub struct AdditiveGbtOptions {
468        /// Shrinkage / learning rate applied to each tree's contribution.
469        pub learning_rate: f64,
470        /// Max depth of each regression tree. Root is depth 0.
471        pub max_depth: u32,
472        /// Number of boosting iterations (one tree per iteration).
473        pub num_trees: u32,
474        /// Inclusive upper bound applied to every corrected estimate.
475        /// Use `u64::MAX` to disable.
476        pub ceiling: u64,
477        /// Minimum samples per leaf — guards against overfitting tiny
478        /// feedback histories.
479        pub min_leaf_size: usize,
480    }
481
482    impl Default for AdditiveGbtOptions {
483        fn default() -> Self {
484            Self {
485                learning_rate: 0.1,
486                max_depth: 4,
487                num_trees: 50,
488                ceiling: u64::MAX,
489                min_leaf_size: 1,
490            }
491        }
492    }
493
494    /// Trained additive GBT corrector. Predicts absolute row counts.
495    ///
496    /// The model is wrapped in a [`Mutex`] because `gbdt::GBDT::predict`
497    /// takes `&mut self` on some configurations; the lock is held only
498    /// for the prediction call and is uncontended in the common single-
499    /// threaded estimate path.
500    pub struct AdditiveGbtCorrector {
501        model: Mutex<GBDT>,
502        ceiling: u64,
503    }
504
505    impl AdditiveGbtCorrector {
506        /// Train an additive corrector from a slice of [`Observation`]s.
507        ///
508        /// Returns [`Error::Feedback`] if the observation slice is empty.
509        /// Unlike the multiplicative backend, observations with
510        /// `est_rows == 0` are **kept** — they are precisely the q=∞
511        /// regime this backend exists to handle. Observations with
512        /// `actual_rows == 0` are also kept (a true-zero output is a
513        /// valid signal for an additive model).
514        pub fn train(observations: &[Observation], options: AdditiveGbtOptions) -> Result<Self> {
515            if observations.is_empty() {
516                return Err(Error::Feedback(
517                    "cannot train AdditiveGbtCorrector: observation slice is empty".into(),
518                ));
519            }
520
521            let mut training: DataVec = Vec::with_capacity(observations.len());
522            for obs in observations {
523                // Reconstruct a feature vector from the observation. The
524                // feedback table doesn't yet carry the full plan-shape
525                // feature set, so we synthesize from `est_rows`. As
526                // `Observation` gains columns, mirror the additions here.
527                let features = CorrectionFeatures {
528                    baseline_estimate: obs.est_rows,
529                    ..Default::default()
530                };
531                let feature_f32: Vec<f32> =
532                    features.to_vec().into_iter().map(|v| v as f32).collect();
533                let target = obs.actual_rows as f32;
534                training.push(Data::new_training_data(feature_f32, 1.0, target, None));
535            }
536
537            // Empty observations are caught above; the synthesized
538            // training set here is always non-empty.
539            debug_assert!(!training.is_empty());
540
541            let mut cfg = Config::new();
542            cfg.set_feature_size(CorrectionFeatures::FEATURE_LEN);
543            cfg.set_max_depth(options.max_depth);
544            cfg.set_iterations(options.num_trees as usize);
545            cfg.set_shrinkage(options.learning_rate as f32);
546            cfg.set_min_leaf_size(options.min_leaf_size);
547            cfg.set_loss(&gbdt::config::loss2string(&Loss::SquaredError));
548
549            let mut model = GBDT::new(&cfg);
550            model.fit(&mut training);
551
552            Ok(Self {
553                model: Mutex::new(model),
554                ceiling: options.ceiling,
555            })
556        }
557
558        /// Predict the absolute row count for a feature vector.
559        /// Exposed for diagnostics; the production path is
560        /// [`Corrector::correct`].
561        pub fn predict_rows(&self, features: &CorrectionFeatures) -> f64 {
562            let feature_f32: Vec<f32> = features.to_vec().into_iter().map(|v| v as f32).collect();
563            let probe: DataVec = vec![Data::new_test_data(feature_f32, None)];
564            let model = self.model.lock().expect("AdditiveGbtCorrector model lock");
565            let preds = model.predict(&probe);
566            preds.first().copied().unwrap_or(0.0) as f64
567        }
568
569        /// Configured upper bound. Set at training time; the trait method
570        /// [`Corrector::correct`] enforces it via `saturating_clamp`.
571        pub fn ceiling(&self) -> u64 {
572            self.ceiling
573        }
574    }
575
576    impl Corrector for AdditiveGbtCorrector {
577        fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
578            let raw = self.predict_rows(features).max(0.0);
579            Ok(Some(saturating_clamp(raw, self.ceiling)))
580        }
581
582        fn name(&self) -> &'static str {
583            "additive_gbt"
584        }
585    }
586}
587
588#[cfg(feature = "tabpfn_http")]
589pub mod tabpfn {
590    //! Foundation-model interface — HTTP transport.
591    //!
592    //! Posts a [`super::CorrectionFeatures`] vector as JSON to a
593    //! user-configured endpoint (e.g., a Python TabPFN inference server
594    //! listening on `http://localhost:8765/infer`), parses an
595    //! `{"estimate": <u64>}` reply, and clamps the result to the LpBound
596    //! ceiling via [`crate::lpbound::saturating_clamp`].
597    //!
598    //! Transport: pure-Rust `ureq` (rustls-only, no OpenSSL). Compiled in
599    //! only when the `tabpfn_http` cargo feature is enabled.
600    //!
601    //! # Safety contract
602    //!
603    //! Any failure — DNS, connection refused, HTTP non-2xx, body parse
604    //! error, timeout — returns `Ok(None)`. The engine falls back to the
605    //! native estimate. We never propagate transport errors to the
606    //! optimizer hot path; a remote inference server going down must not
607    //! surface as a query failure.
608    //!
609    //! Note on naming: this is *the foundation-model interface*, not a
610    //! "learned" or "AI" feature. The corrector is a pluggable backend
611    //! behind the same `Corrector` trait as every other backend in this
612    //! module.
613    //!
614    //! # Wire format
615    //!
616    //! Request body (JSON):
617    //!
618    //! ```json
619    //! {
620    //!   "features": [<f64>, <f64>, ...],
621    //!   "baseline_estimate": <u64>
622    //! }
623    //! ```
624    //!
625    //! Response body (JSON):
626    //!
627    //! ```json
628    //! { "estimate": <u64> }
629    //! ```
630    //!
631    //! Any extra fields in the response are ignored, so server
632    //! implementations are free to add diagnostics without breaking the
633    //! client.
634    //!
635    //! # See also
636    //!
637    //! - [`super::TabPfnStub`] — always-on no-op for the same integration
638    //!   slot, no transport dependency.
639
640    use serde::{Deserialize, Serialize};
641    use std::time::Duration;
642
643    use super::{CorrectionFeatures, Corrector};
644    use crate::Result;
645    use crate::lpbound::saturating_clamp;
646
647    /// Configuration for [`TabPfnHttpCorrector`].
648    #[derive(Debug, Clone)]
649    pub struct TabPfnHttpOptions {
650        /// Inference endpoint URL. The corrector POSTs here on every
651        /// `correct()` call. Example: `http://localhost:8765/infer`.
652        pub base_url: String,
653        /// Per-request timeout. Applies independently to the connect and
654        /// read phases. Bounded by the architecture's sub-ms budget for
655        /// the production path, but configurable so users can dial it up
656        /// for diagnostics.
657        pub timeout_ms: u64,
658        /// Inclusive upper bound applied to every corrected estimate via
659        /// [`saturating_clamp`]. The Layer 3 safety guarantee — corrections
660        /// can never exceed this regardless of what the remote backend
661        /// returns. Use `u64::MAX` to disable.
662        pub ceiling: u64,
663    }
664
665    impl Default for TabPfnHttpOptions {
666        fn default() -> Self {
667            Self {
668                base_url: "http://localhost:8765/infer".into(),
669                timeout_ms: 50,
670                ceiling: u64::MAX,
671            }
672        }
673    }
674
675    /// JSON request body sent to the inference endpoint.
676    #[derive(Serialize)]
677    struct InferRequest<'a> {
678        features: &'a [f64],
679        baseline_estimate: u64,
680    }
681
682    /// JSON response body. Extra fields are ignored.
683    #[derive(Deserialize)]
684    struct InferResponse {
685        estimate: u64,
686    }
687
688    /// HTTP-backed foundation-model corrector.
689    ///
690    /// Holds a tiny client config and a base URL. The `ureq` agent is
691    /// constructed per-call: the per-estimate cost is dominated by network
692    /// round-trip, not agent allocation, and per-call agents keep the
693    /// struct cheaply `Send + Sync` without interior mutability.
694    pub struct TabPfnHttpCorrector {
695        options: TabPfnHttpOptions,
696    }
697
698    impl TabPfnHttpCorrector {
699        /// Build a corrector from explicit options.
700        pub fn new(options: TabPfnHttpOptions) -> Self {
701            super::warn_if_remote_plaintext_http(&options.base_url, "tabpfn_http");
702            Self { options }
703        }
704
705        /// Convenience constructor: default options with the supplied URL.
706        pub fn with_url(base_url: impl Into<String>) -> Self {
707            let opts = TabPfnHttpOptions {
708                base_url: base_url.into(),
709                ..TabPfnHttpOptions::default()
710            };
711            super::warn_if_remote_plaintext_http(&opts.base_url, "tabpfn_http");
712            Self { options: opts }
713        }
714
715        /// Configured options (for diagnostics / logging).
716        pub fn options(&self) -> &TabPfnHttpOptions {
717            &self.options
718        }
719
720        /// Attempt one inference call. Returns `None` on any failure
721        /// (network, parse, non-2xx). The `correct()` trait method wraps
722        /// this and applies the LpBound clamp.
723        fn try_infer(&self, features: &CorrectionFeatures) -> Option<u64> {
724            let feature_vec = features.to_vec();
725            let payload = InferRequest {
726                features: &feature_vec,
727                baseline_estimate: features.baseline_estimate,
728            };
729
730            let timeout = Duration::from_millis(self.options.timeout_ms);
731            let agent = ureq::AgentBuilder::new()
732                .timeout_connect(timeout)
733                .timeout_read(timeout)
734                .timeout_write(timeout)
735                .build();
736
737            let response = match agent.post(&self.options.base_url).send_json(&payload) {
738                Ok(r) => r,
739                Err(err) => {
740                    // Map every transport error to None and log at debug.
741                    // The Error::Feedback diagnostic carries the URL plus
742                    // the underlying message so callers tailing logs can
743                    // see what failed without us aborting the query.
744                    log::debug!(
745                        "tabpfn_http: request to {} failed: {}",
746                        self.options.base_url,
747                        err
748                    );
749                    return None;
750                }
751            };
752
753            match response.into_json::<InferResponse>() {
754                Ok(body) => Some(body.estimate),
755                Err(err) => {
756                    log::debug!(
757                        "tabpfn_http: response from {} failed to parse: {}",
758                        self.options.base_url,
759                        err
760                    );
761                    None
762                }
763            }
764        }
765    }
766
767    impl Corrector for TabPfnHttpCorrector {
768        fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
769            // Safety contract: every failure returns Ok(None), not Err.
770            // The engine then transparently falls back to the native
771            // estimate. We use Result here to honour the trait shape and
772            // to keep a door open for future *non-fallback* error modes
773            // (e.g. a deliberate misconfiguration check), but on the hot
774            // path failures are absorbed.
775            let Some(raw) = self.try_infer(features) else {
776                return Ok(None);
777            };
778            Ok(Some(saturating_clamp(raw as f64, self.options.ceiling)))
779        }
780
781        fn name(&self) -> &'static str {
782            "tabpfn-http"
783        }
784    }
785}
786
#[cfg(feature = "llm_http")]
pub mod llm {
    //! LLM-pluggable corrector backend — HTTP transport.
    //!
    //! Posts a [`super::CorrectionFeatures`] vector as JSON to a
    //! user-configured endpoint (e.g., a Python LLM inference server
    //! listening on `http://localhost:8766/infer`) and parses an
    //! `{"estimate": <u64>}` reply. The server-side LLM provider
    //! (Anthropic, OpenAI, local Ollama, dummy) is selected by the
    //! `SAMKHYA_LLM_BACKEND` env var on the server process — the wire
    //! contract is identical regardless of which provider is configured.
    //!
    //! Transport: pure-Rust `ureq` (rustls-only, no OpenSSL). Compiled in
    //! only when the `llm_http` cargo feature is enabled.
    //!
    //! # Naming
    //!
    //! This is *the LLM-pluggable corrector backend* — a transport-level
    //! integration that lets a foundation language model serve as the
    //! cardinality corrector behind the same `Corrector` trait as every
    //! other backend in this module. It is **not** an "AI", "adaptive",
    //! or "learned" feature; the samkhya envelope still dominates the
    //! safety contract and the LLM is strictly an opt-in pluggable
    //! backend. The default samkhya build does not pull this in.
    //!
    //! # Safety contract
    //!
    //! Any failure — DNS, connection refused, HTTP non-2xx, body parse
    //! error, timeout — returns `Ok(None)`. The engine falls back to the
    //! native estimate. We never propagate transport errors to the
    //! optimizer hot path; a remote inference server going down must not
    //! surface as a query failure. Mirrors the
    //! [`super::tabpfn::TabPfnHttpCorrector`] contract exactly.
    //!
    //! # Wire format
    //!
    //! Request body (JSON):
    //!
    //! ```json
    //! {
    //!   "features": [<f64>, <f64>, ...],
    //!   "baseline_estimate": <u64>
    //! }
    //! ```
    //!
    //! Response body (JSON):
    //!
    //! ```json
    //! { "estimate": <u64> }
    //! ```
    //!
    //! Any extra fields in the response are ignored, so server
    //! implementations are free to add diagnostics (e.g., the LLM's raw
    //! text reply, parse-status flags) without breaking the client.
    //!
    //! # Latency expectations
    //!
    //! LLM round-trips are one to two orders of magnitude slower than the
    //! TabPFN tier (P95 in the 0.3–2 s range vs. ~30 ms for TabPFN). The
    //! default per-request timeout is therefore 2 000 ms (vs. 50 ms for
    //! TabPFN), with a 60 s hard cap available for cold-cache
    //! diagnostics. The `llm_http` backend is intended for *offline /
    //! overnight* re-validation and schema-introspection use cases, not
    //! the online query hot path. See
    //! `bench-results/19_llm_corrector.md` §6 for routing guidance.

    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    use super::{CorrectionFeatures, Corrector};
    use crate::Result;
    use crate::lpbound::saturating_clamp;

    /// Default per-request timeout for the LLM HTTP backend (milliseconds).
    /// LLM round-trips are far slower than TabPFN (see the module-level
    /// latency notes); the 2 s default is the smallest budget that
    /// consistently covers warm-cache Anthropic Claude / OpenAI
    /// GPT-4o-mini calls without spurious timeouts in measurement.
    pub const DEFAULT_TIMEOUT_MS: u64 = 2_000;

    /// Hard per-request ceiling (milliseconds). Constructors that accept
    /// a `timeout_ms` saturate to this value so a misconfigured caller
    /// cannot pin the optimizer for longer than 60 s on a single call.
    pub const MAX_TIMEOUT_MS: u64 = 60_000;

    /// Default inference endpoint. Distinct from the TabPFN default port
    /// (`8765`) so an operator can run both servers side-by-side without
    /// collision.
    pub const DEFAULT_URL: &str = "http://127.0.0.1:8766/infer";

    /// Configuration for [`LlmHttpCorrector`].
    #[derive(Debug, Clone)]
    pub struct LlmHttpOptions {
        /// Inference endpoint URL. The corrector POSTs here on every
        /// `correct()` call. Example: `http://localhost:8766/infer`.
        pub base_url: String,
        /// Per-request timeout in milliseconds, enforced as a deadline
        /// covering the connect, write, and read phases together.
        /// Capped at [`MAX_TIMEOUT_MS`] by [`LlmHttpCorrector::new`] so a
        /// misconfigured caller cannot stall the optimizer indefinitely.
        pub timeout_ms: u64,
        /// Inclusive upper bound applied to every corrected estimate via
        /// [`saturating_clamp`]. The Layer 3 safety guarantee —
        /// corrections can never exceed this regardless of what the
        /// remote LLM returns. Use `u64::MAX` to disable.
        pub ceiling: u64,
    }

    impl Default for LlmHttpOptions {
        /// [`DEFAULT_URL`], [`DEFAULT_TIMEOUT_MS`], and no ceiling
        /// (`u64::MAX`).
        fn default() -> Self {
            Self {
                base_url: DEFAULT_URL.into(),
                timeout_ms: DEFAULT_TIMEOUT_MS,
                ceiling: u64::MAX,
            }
        }
    }

    /// JSON request body sent to the inference endpoint.
    #[derive(Serialize)]
    struct InferRequest<'a> {
        /// Flattened feature vector, as produced by
        /// `CorrectionFeatures::to_vec`.
        features: &'a [f64],
        /// The engine's uncorrected estimate, repeated as a named field.
        baseline_estimate: u64,
    }

    /// JSON response body. Extra fields are ignored.
    #[derive(Deserialize)]
    struct InferResponse {
        /// Corrected estimate proposed by the server (not yet clamped).
        estimate: u64,
    }

    /// HTTP-backed LLM-pluggable corrector.
    ///
    /// Holds a tiny client config and a base URL. The `ureq` agent is
    /// constructed per-call: the per-estimate cost is dominated by LLM
    /// inference (hundreds of milliseconds), not agent allocation, and
    /// per-call agents keep the struct cheaply `Send + Sync` without
    /// interior mutability.
    #[derive(Debug, Clone)]
    pub struct LlmHttpCorrector {
        options: LlmHttpOptions,
    }

    impl LlmHttpCorrector {
        /// Build a corrector from explicit options. The `timeout_ms`
        /// value is saturated to [`MAX_TIMEOUT_MS`] so misconfigured
        /// callers cannot stall the optimizer for longer than that.
        pub fn new(mut options: LlmHttpOptions) -> Self {
            options.timeout_ms = options.timeout_ms.min(MAX_TIMEOUT_MS);
            super::warn_if_remote_plaintext_http(&options.base_url, "llm_http");
            Self { options }
        }

        /// Convenience constructor: default options with the supplied
        /// URL. Useful for ad-hoc bench / smoke clients.
        pub fn with_url(base_url: impl Into<String>) -> Self {
            Self::new(LlmHttpOptions {
                base_url: base_url.into(),
                ..LlmHttpOptions::default()
            })
        }

        /// Configured options (for diagnostics / logging). The
        /// `timeout_ms` reported here is the value after saturation.
        pub fn options(&self) -> &LlmHttpOptions {
            &self.options
        }

        /// Attempt one inference call. Returns `None` on any failure
        /// (network, timeout, parse, non-2xx), logging the cause at
        /// `debug` level. The `correct()` trait method wraps this and
        /// applies the LpBound clamp.
        fn try_infer(&self, features: &CorrectionFeatures) -> Option<u64> {
            let feature_vec = features.to_vec();
            let payload = InferRequest {
                features: &feature_vec,
                baseline_estimate: features.baseline_estimate,
            };

            // `timeout_read` / `timeout_write` only bound each individual
            // socket operation, so a server that trickles its reply could
            // exceed `timeout_ms` (and therefore `MAX_TIMEOUT_MS`) by an
            // arbitrary amount. `timeout` bounds the whole request —
            // connect, send, and reading the body — and takes precedence
            // over the per-operation read/write timeouts.
            // `timeout_connect` is kept so ureq's own connect default is
            // never used.
            let deadline = Duration::from_millis(self.options.timeout_ms);
            let agent = ureq::AgentBuilder::new()
                .timeout_connect(deadline)
                .timeout(deadline)
                .build();

            let response = match agent.post(&self.options.base_url).send_json(&payload) {
                Ok(response) => response,
                Err(err) => {
                    log::debug!(
                        "llm_http: request to {} failed: {}",
                        self.options.base_url,
                        err
                    );
                    return None;
                }
            };

            match response.into_json::<InferResponse>() {
                Ok(body) => Some(body.estimate),
                Err(err) => {
                    log::debug!(
                        "llm_http: response from {} failed to parse: {}",
                        self.options.base_url,
                        err
                    );
                    None
                }
            }
        }
    }

    impl Corrector for LlmHttpCorrector {
        /// Ask the remote endpoint for an estimate and clamp it to
        /// `options.ceiling`. Never returns `Err`; a failed inference
        /// call yields `Ok(None)`.
        fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
            // Safety contract: every failure returns Ok(None), not Err.
            // Mirrors `TabPfnHttpCorrector::correct`. On the optimizer's
            // hot path a remote LLM going down (rate limit, network
            // partition, mis-config) must never surface as a query
            // failure.
            let Some(raw) = self.try_infer(features) else {
                return Ok(None);
            };
            // NOTE(review): `raw as f64` is inexact above 2^53 — confirm
            // `saturating_clamp` tolerates that rounding.
            Ok(Some(saturating_clamp(raw as f64, self.options.ceiling)))
        }

        /// Stable backend identifier.
        fn name(&self) -> &'static str {
            "llm-http"
        }
    }
}
1015
#[cfg(test)]
mod tests {
    use super::*;

    /// Features with only the baseline populated; every other field
    /// keeps its `Default` value.
    fn baseline_only(baseline_estimate: u64) -> CorrectionFeatures {
        CorrectionFeatures {
            baseline_estimate,
            ..Default::default()
        }
    }

    /// The identity backend hands the baseline straight back.
    #[test]
    fn identity_returns_baseline() {
        let passthrough = IdentityCorrector;
        let corrected = passthrough.correct(&baseline_only(1234)).unwrap();
        assert_eq!(corrected, Some(1234));
        assert_eq!(passthrough.name(), "identity");
    }

    /// The stub never produces a correction, whatever it is fed.
    #[test]
    fn tabpfn_stub_always_returns_none() {
        let stub = TabPfnStub;

        let populated = baseline_only(9999);
        assert_eq!(
            stub.correct(&populated).unwrap(),
            None,
            "TabPfnStub must always return Ok(None) — it documents the integration point"
        );
        assert_eq!(stub.name(), "tabpfn-stub");

        // An all-default feature set must be ignored in the same way:
        // the stub answers None without looking at its input.
        let blank = CorrectionFeatures::default();
        assert_eq!(stub.correct(&blank).unwrap(), None);
    }

    /// Pins the slot order of the flattened feature vector.
    #[test]
    fn feature_vec_layout_is_stable() {
        let features = CorrectionFeatures {
            baseline_estimate: 100,
            left_input_rows: Some(10),
            right_input_rows: None,
            left_distinct: Some(7),
            right_distinct: None,
            predicate_count: 3,
            join_depth: 2,
        };
        let layout = features.to_vec();
        assert_eq!(layout.len(), CorrectionFeatures::FEATURE_LEN);

        // Slot-by-slot expectation; absent (`None`) inputs encode as 0.
        let expected = [100.0, 10.0, 0.0, 7.0, 0.0, 3.0, 2.0];
        for (slot, want) in expected.iter().enumerate() {
            assert_eq!(layout[slot], *want);
        }
    }
}
1073
#[cfg(all(test, feature = "gbt"))]
mod gbt_tests {
    use super::gbt::{GbtCorrector, GbtOptions};
    use super::{CorrectionFeatures, Corrector};
    use crate::feedback::Observation;

    /// One observation with plan fingerprint `"p"` and no latency.
    fn observation(template_hash: &'static str, est_rows: u64, actual_rows: u64) -> Observation {
        Observation {
            template_hash: template_hash.into(),
            plan_fingerprint: "p".into(),
            est_rows,
            actual_rows,
            latency_ms: None,
        }
    }

    /// N synthetic observations with estimates 10, 20, …, 10·N whose
    /// actual row count is always exactly twice the estimate. Plenty of
    /// signal for the trees to latch on.
    fn synthetic_double(n: u64) -> Vec<Observation> {
        let mut observations = Vec::new();
        for step in 1..=n {
            let est_rows = step * 10;
            observations.push(observation("syn", est_rows, est_rows * 2));
        }
        observations
    }

    /// Training options shared by the prediction tests; only the
    /// ceiling varies between them.
    fn tuned_options(ceiling: u64) -> GbtOptions {
        GbtOptions {
            learning_rate: 0.3,
            max_depth: 4,
            num_trees: 50,
            ceiling,
            min_leaf_size: 1,
        }
    }

    /// Train on the doubling data set and correct a baseline of 500.
    fn trained_on_doubles(ceiling: u64) -> (GbtCorrector, u64) {
        let training_set = synthetic_double(200);
        let corrector = GbtCorrector::train(&training_set, tuned_options(ceiling)).expect("training");
        let probe = CorrectionFeatures {
            baseline_estimate: 500,
            ..Default::default()
        };
        let corrected = corrector.correct(&probe).expect("correct").expect("Some");
        (corrector, corrected)
    }

    #[test]
    fn predicts_roughly_double_when_training_says_double() {
        let (corrector, corrected) = trained_on_doubles(u64::MAX);
        // True target is 1000. Trees won't be exact; require within 25%.
        let ratio = corrected as f64 / 1000.0;
        let within_tolerance = (0.75..=1.25).contains(&ratio);
        assert!(
            within_tolerance,
            "expected ~1000, got {} (ratio {})",
            corrected,
            ratio
        );
        assert_eq!(corrector.name(), "gbt");
    }

    #[test]
    fn ceiling_clamps_when_prediction_exceeds_it() {
        // A ceiling of 100 sits far below 2 × baseline.
        let (corrector, corrected) = trained_on_doubles(100);
        assert_eq!(corrected, 100, "ceiling must clamp the corrected estimate");
        assert_eq!(corrector.ceiling(), 100);
    }

    #[test]
    fn empty_observations_errors() {
        let Err(err) = GbtCorrector::train(&[], GbtOptions::default()) else {
            panic!("expected error on empty observations");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }

    #[test]
    fn all_zero_observations_errors() {
        // Each observation has a zero on one side of the pair.
        let degenerate = vec![observation("z", 0, 5), observation("z", 5, 0)];
        let Err(err) = GbtCorrector::train(&degenerate, GbtOptions::default()) else {
            panic!("expected error when all observations are zero");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }
}
1181
#[cfg(all(test, feature = "additive_gbt"))]
mod additive_tests {
    use super::additive::{AdditiveGbtCorrector, AdditiveGbtOptions};
    use super::{CorrectionFeatures, Corrector};
    use crate::feedback::Observation;

    /// N synthetic observations with estimates 10, 20, …, 10·N that all
    /// share the same actual row count `target`. An additive model
    /// trained on this should regress toward `target` regardless of the
    /// input features.
    fn synthetic_constant(n: u64, target: u64) -> Vec<Observation> {
        let mut observations = Vec::new();
        for step in 1..=n {
            observations.push(Observation {
                template_hash: "syn-add".into(),
                plan_fingerprint: "p".into(),
                est_rows: step * 10,
                actual_rows: target,
                latency_ms: None,
            });
        }
        observations
    }

    /// Training options shared by the prediction tests; only the
    /// ceiling varies between them.
    fn tuned_options(ceiling: u64) -> AdditiveGbtOptions {
        AdditiveGbtOptions {
            learning_rate: 0.3,
            max_depth: 4,
            num_trees: 50,
            ceiling,
            min_leaf_size: 1,
        }
    }

    /// Run one correction for the given baseline and unwrap the value.
    fn corrected_for(corrector: &AdditiveGbtCorrector, baseline_estimate: u64) -> u64 {
        let probe = CorrectionFeatures {
            baseline_estimate,
            ..Default::default()
        };
        corrector.correct(&probe).expect("correct").expect("Some")
    }

    #[test]
    fn predicts_near_constant_when_training_is_constant() {
        let training_set = synthetic_constant(200, 1000);
        let corrector = AdditiveGbtCorrector::train(&training_set, tuned_options(u64::MAX))
            .expect("training additive corrector");

        let corrected = corrected_for(&corrector, 500);
        assert!(
            (800..=1200).contains(&corrected),
            "expected ~1000, got {corrected}"
        );
        assert_eq!(corrector.name(), "additive_gbt");
    }

    #[test]
    fn ceiling_clamps_when_prediction_exceeds_it() {
        let training_set = synthetic_constant(200, 1000);
        // A ceiling of 100 sits far below the trained constant.
        let corrector =
            AdditiveGbtCorrector::train(&training_set, tuned_options(100)).expect("training");

        let corrected = corrected_for(&corrector, 500);
        assert_eq!(corrected, 100, "ceiling must clamp the additive correction");
        assert_eq!(corrector.ceiling(), 100);
    }

    #[test]
    fn corrects_nonzero_even_when_baseline_estimate_is_zero() {
        // This is the q=∞ fix proof. The multiplicative GbtCorrector
        // would return 0 here (baseline * exp(predicted) = 0 * _ = 0).
        // The additive backend must escape that trap.
        let training_set = synthetic_constant(200, 1000);
        let corrector = AdditiveGbtCorrector::train(&training_set, AdditiveGbtOptions::default())
            .expect("training");

        let corrected = corrected_for(&corrector, 0);
        assert!(
            corrected > 0,
            "additive corrector must return non-zero even when baseline_estimate = 0; got {corrected}"
        );
    }

    #[test]
    fn empty_observations_errors() {
        let Err(err) = AdditiveGbtCorrector::train(&[], AdditiveGbtOptions::default()) else {
            panic!("expected error on empty observations");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }
}
1286
#[cfg(all(test, feature = "tabpfn_http"))]
mod tabpfn_http_tests {
    use super::tabpfn::{TabPfnHttpCorrector, TabPfnHttpOptions};
    use super::{CorrectionFeatures, Corrector};

    /// Port 1 on the loopback interface is the canonical
    /// "guaranteed-to-refuse-connection" target on Linux/macOS. The
    /// safety contract requires every transport failure to come back
    /// as `Ok(None)` — never `Err`, never a panic — and this checks it
    /// without standing up a real inference server.
    #[test]
    fn http_failure_returns_none_not_error() {
        let refused = TabPfnHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: 50,
            ceiling: u64::MAX,
        };
        let corrector = TabPfnHttpCorrector::new(refused);
        let probe = CorrectionFeatures {
            baseline_estimate: 1234,
            ..Default::default()
        };

        let result = corrector.correct(&probe);
        assert!(
            result.is_ok(),
            "tabpfn-http transport failure must not propagate as Err; got {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            None,
            "tabpfn-http transport failure must yield Ok(None) so the engine falls back cleanly"
        );
        assert_eq!(corrector.name(), "tabpfn-http");
    }

    /// A string that is not a URL at all must be absorbed by the same
    /// error path as any other transport failure.
    #[test]
    fn malformed_url_returns_none() {
        let corrector = TabPfnHttpCorrector::with_url("not a url at all");
        let outcome = corrector
            .correct(&CorrectionFeatures::default())
            .expect("never Err");
        assert_eq!(outcome, None);
    }

    /// Defaults: plain-HTTP endpoint, non-zero timeout, no ceiling.
    #[test]
    fn options_default_is_localhost() {
        let TabPfnHttpOptions {
            base_url,
            timeout_ms,
            ceiling,
        } = TabPfnHttpOptions::default();
        assert!(base_url.starts_with("http://"));
        assert!(timeout_ms > 0);
        assert_eq!(ceiling, u64::MAX);
    }
}
1340
#[cfg(all(test, feature = "llm_http"))]
mod llm_http_tests {
    use super::llm::{
        DEFAULT_TIMEOUT_MS, DEFAULT_URL, LlmHttpCorrector, LlmHttpOptions, MAX_TIMEOUT_MS,
    };
    use super::{CorrectionFeatures, Corrector};
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Byte offset just past the blank line that ends the HTTP header
    /// section, or `None` if that terminator has not arrived yet.
    fn header_section_len(received: &[u8]) -> Option<usize> {
        received
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|start| start + 4)
    }

    /// Value of the `Content-Length` header in `headers`, or 0 when the
    /// header is absent or unparsable.
    fn declared_body_len(headers: &[u8]) -> usize {
        String::from_utf8_lossy(headers)
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
            .and_then(|(_, value)| value.trim().parse::<usize>().ok())
            .unwrap_or(0)
    }

    /// Read one complete HTTP request (header section plus the body
    /// announced by `Content-Length`) and discard it.
    ///
    /// A single `read` is not enough: the client may deliver headers
    /// and body in separate writes, and replying + closing while part
    /// of the request is still unread can reset the connection before
    /// the client has seen the response. Stops early on EOF or on any
    /// read error (including the stream's read timeout).
    fn drain_request(stream: &mut TcpStream) {
        let mut received = Vec::new();
        let mut chunk = [0u8; 4096];
        let mut expected_len: Option<usize> = None;
        loop {
            if matches!(expected_len, Some(len) if received.len() >= len) {
                return;
            }
            let read = match stream.read(&mut chunk) {
                Ok(0) | Err(_) => return,
                Ok(read) => read,
            };
            received.extend_from_slice(&chunk[..read]);
            if expected_len.is_none() {
                expected_len = header_section_len(&received)
                    .map(|head| head + declared_body_len(&received[..head]));
            }
        }
    }

    /// Tiny hand-rolled mock HTTP server. We avoid pulling `mockito`
    /// (not currently a dep) and keep the test binary lean.
    ///
    /// Binds an ephemeral loopback port and serves at most
    /// `max_requests` connections on a background thread. For each one
    /// it reads the full request, then writes a `200 OK` JSON response
    /// whose body is `responder(idx)`, where `idx` is the zero-based
    /// request index. The responder decides what to send so the same
    /// scaffolding serves both success and parse-error cases.
    ///
    /// Returns the endpoint URL and a counter of requests served so far.
    fn spawn_mock(
        responder: impl Fn(usize) -> Vec<u8> + Send + Sync + 'static,
        max_requests: usize,
    ) -> (String, Arc<AtomicUsize>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback");
        let port = listener.local_addr().unwrap().port();
        let url = format!("http://127.0.0.1:{port}/infer");
        let counter = Arc::new(AtomicUsize::new(0));
        let served = Arc::clone(&counter);
        thread::spawn(move || {
            for stream in listener.incoming().take(max_requests) {
                let Ok(mut stream) = stream else { continue };
                // Bound every socket operation so a misbehaving client
                // cannot wedge the mock thread.
                let _ = stream.set_read_timeout(Some(Duration::from_secs(2)));
                let _ = stream.set_write_timeout(Some(Duration::from_secs(2)));
                drain_request(&mut stream);
                let idx = served.fetch_add(1, Ordering::SeqCst);
                let body = responder(idx);
                let header = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    body.len()
                );
                let _ = stream.write_all(header.as_bytes());
                let _ = stream.write_all(&body);
                let _ = stream.flush();
            }
        });
        (url, counter)
    }

    /// Pointing at port 1 on the loopback interface is the canonical
    /// "guaranteed-to-refuse-connection" target on Linux/macOS. The
    /// safety contract says: any transport failure must surface as
    /// `Ok(None)`, never `Err`, never a panic.
    #[test]
    fn http_failure_returns_none_not_error() {
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: 50,
            ceiling: u64::MAX,
        });
        let features = CorrectionFeatures {
            baseline_estimate: 1234,
            ..Default::default()
        };
        let result = corrector.correct(&features);
        assert!(
            result.is_ok(),
            "llm-http transport failure must not propagate as Err; got {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            None,
            "llm-http transport failure must yield Ok(None) so the engine falls back cleanly"
        );
        assert_eq!(corrector.name(), "llm-http");
    }

    /// A string that is not a URL at all must be absorbed like any
    /// other transport failure.
    #[test]
    fn malformed_url_returns_none() {
        let corrector = LlmHttpCorrector::with_url("not a url at all");
        let features = CorrectionFeatures::default();
        let result = corrector.correct(&features).expect("never Err");
        assert_eq!(result, None);
    }

    /// Defaults point at the LLM port with the default timeout and no
    /// ceiling.
    #[test]
    fn options_default_is_localhost_on_llm_port() {
        let opts = LlmHttpOptions::default();
        assert_eq!(opts.base_url, DEFAULT_URL);
        assert!(opts.base_url.contains(":8766"));
        assert_eq!(opts.timeout_ms, DEFAULT_TIMEOUT_MS);
        assert_eq!(opts.ceiling, u64::MAX);
    }

    /// An oversized timeout is capped at `MAX_TIMEOUT_MS` on
    /// construction.
    #[test]
    fn timeout_is_saturated_to_max() {
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: MAX_TIMEOUT_MS * 10,
            ceiling: u64::MAX,
        });
        assert_eq!(corrector.options().timeout_ms, MAX_TIMEOUT_MS);
    }

    /// An estimate below the ceiling passes through unchanged.
    #[test]
    fn mock_success_returns_clamped_estimate() {
        let (url, counter) = spawn_mock(|_| br#"{"estimate": 4242}"#.to_vec(), 2);
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: url,
            timeout_ms: 2_000,
            ceiling: 1_000_000,
        });
        let features = CorrectionFeatures {
            baseline_estimate: 1_000,
            ..Default::default()
        };
        let result = corrector.correct(&features).expect("ok");
        assert_eq!(result, Some(4242));
        assert!(counter.load(Ordering::SeqCst) >= 1);
    }

    /// An estimate above the ceiling is clamped down to it.
    #[test]
    fn mock_clamps_to_ceiling() {
        let (url, _counter) = spawn_mock(|_| br#"{"estimate": 99999999}"#.to_vec(), 2);
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: url,
            timeout_ms: 2_000,
            ceiling: 500,
        });
        let result = corrector
            .correct(&CorrectionFeatures::default())
            .expect("ok");
        assert_eq!(result, Some(500));
    }

    /// A 200 response whose body is not JSON falls back to `None`.
    #[test]
    fn mock_parse_error_returns_none() {
        let (url, _counter) = spawn_mock(|_| b"not json at all".to_vec(), 2);
        let corrector = LlmHttpCorrector::with_url(url);
        let result = corrector
            .correct(&CorrectionFeatures::default())
            .expect("ok");
        assert_eq!(result, None);
    }
}