samkhya_core/residual.rs
1//! Residual correction model.
2//!
3//! Optional feedback-driven correction layer. Takes a baseline cardinality
4//! estimate plus a feature vector (query plan + column stats) and returns
5//! a corrected estimate. Trained on observations recorded by
6//! [`crate::feedback`].
7//!
8//! Contracts every backend honors:
9//!
10//! - bounded — output never exceeds the LpBound ceiling ([`crate::lpbound`]); the corrector clamps.
11//! - sub-MB / sub-ms — model footprint and per-estimate latency are the architectural budget.
12//! - optional — engines opt in; with no model attached, samkhya behaves as portable stats + envelope.
13//!
14//! Concrete backends (all behind feature flags):
15//!
16//! - `gbt` — gradient-boosted trees (the `gbt` submodule, gated on the `gbt` cargo feature)
17//! - `additive_gbt` — additive gradient-boosted trees (the `additive` submodule, gated on `additive_gbt`)
18//! - `tabpfn` — foundation-model interface (see the `tabpfn` submodule,
19//! gated on the `tabpfn_http` cargo feature)
20//! - `llm` — LLM-pluggable corrector backend (see the `llm` submodule,
21//! gated on the `llm_http` cargo feature). Same wire contract as
22//! `tabpfn`; the server picks an Anthropic / OpenAI / local-Ollama /
23//! dummy provider via the `SAMKHYA_LLM_BACKEND` env var.
24//!
25//! # Foundation-model interface (Layer 5)
26//!
27//! The architecture reserves a pluggable backend slot for foundation tabular
28//! models such as TabPFN-2.5 (arXiv 2511.08667). The contract is identical
29//! to every other backend:
30//!
31//! > *feed [`CorrectionFeatures`], receive `Option<u64>` clamped to the
32//! > LpBound ceiling.*
33//!
34//! Two deployment shapes are scaffolded:
35//!
36//! 1. **localhost HTTP** — a Python TabPFN inference server runs out of
37//! band; samkhya POSTs the feature vector as JSON and reads back an
38//! `{"estimate": <u64>}` response. Implemented today behind the
39//! `tabpfn_http` cargo feature (see `tabpfn::TabPfnHttpCorrector`).
40//! 2. **subprocess** — samkhya spawns a Python child, frames JSON over
41//! stdin/stdout, and keeps the process warm across estimates. Deferred:
42//! the scaffolding is present (umbrella `tabpfn` feature), the
43//! transport itself is not implemented in this revision.
44//!
45//! A no-op [`TabPfnStub`] is **always** compiled in, regardless of
46//! features. Its job is purely architectural: downstream code can reference
47//! `TabPfnStub` to mark "TabPFN integration point, currently disabled"
48//! without taking the `tabpfn_http` feature dependency. This reflects the
49//! integration point in every build, so the contract is visible even when
50//! the transport is not.
51//!
52//! Failure policy across all TabPFN backends: any transport error, parse
53//! error, or timeout returns `Ok(None)` (never `Err`). The engine then
54//! falls back cleanly to the native estimate. This is the safety contract;
55//! a remote inference server going down must never surface as a query
56//! failure.
57
58use crate::Result;
59
/// Emit a `log::warn!` when an HTTP corrector backend is configured with a
/// plaintext `http://` URL that points at a non-loopback host.
///
/// See `documents/SECURITY-REVIEW-2026-05-17.md` (H2): such a URL means
/// features and any embedded baseline estimate travel unencrypted on the
/// wire. The warning is advisory only — no behaviour change — and is emitted
/// on every call that sees such a URL, so well-configured operators (the
/// typical case, defaults are localhost) see nothing.
///
/// Host extraction follows the RFC 3986 authority layout
/// (`[userinfo@]host[:port]`):
///
/// - the scheme is matched case-insensitively;
/// - userinfo is skipped, so `http://user:pw@example.com/` reports
///   `example.com` rather than `user`;
/// - bracketed IPv6 literals are unwrapped, so `http://[::1]:8765/` is
///   recognised as loopback;
/// - "loopback" means `localhost` (case-insensitive) or an IP literal for
///   which [`std::net::IpAddr::is_loopback`] holds (`127.0.0.0/8`, `::1`).
///   A DNS name that merely starts with `127.` is *not* treated as loopback.
///
/// Setting `SAMKHYA_ALLOW_REMOTE_HTTP=1` silences the warning.
#[cfg(any(feature = "tabpfn_http", feature = "llm_http"))]
fn warn_if_remote_plaintext_http(url: &str, backend: &'static str) {
    const SCHEME: &str = "http://";

    // `get` (rather than slicing) keeps short URLs, and URLs whose first
    // bytes are multi-byte characters, from panicking.
    let Some(prefix) = url.get(..SCHEME.len()) else {
        return;
    };
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return;
    }
    // `get` succeeded above, so `SCHEME.len()` is a char boundary.
    let rest = &url[SCHEME.len()..];

    // The authority component ends at the first path, query or fragment
    // delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // Drop any `userinfo@` prefix; userinfo may itself contain ':'.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host_port)| host_port);

    // Bracketed IPv6 literals contain ':' themselves, so they must be
    // unwrapped before looking for the port separator.
    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => bracketed
            .split_once(']')
            .map_or(bracketed, |(addr, _port)| addr),
        None => host_port
            .split_once(':')
            .map_or(host_port, |(name, _port)| name),
    };

    let is_loopback = host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback());
    if is_loopback {
        return;
    }
    if std::env::var("SAMKHYA_ALLOW_REMOTE_HTTP").as_deref() == Ok("1") {
        return;
    }
    log::warn!(
        "samkhya {backend} corrector configured with plaintext HTTP to non-loopback host {host}; \
         features and baseline_estimate will travel unencrypted. Use https:// or set \
         SAMKHYA_ALLOW_REMOTE_HTTP=1 to silence this warning."
    );
}
94
/// Feature vector handed to the corrector at estimate time.
///
/// Intentionally minimal at v0.0.1: the baseline estimate, per-side input
/// row counts and distinct counts, and two operator-level features
/// (predicate count and join depth). Will grow as the feedback-collection
/// surface widens.
///
/// # Examples
///
/// ```
/// use samkhya_core::residual::CorrectionFeatures;
///
/// let features = CorrectionFeatures {
///     baseline_estimate: 1000,
///     left_input_rows: Some(500),
///     right_input_rows: Some(2000),
///     predicate_count: 2,
///     join_depth: 1,
///     ..Default::default()
/// };
/// assert_eq!(features.to_vec().len(), CorrectionFeatures::FEATURE_LEN);
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CorrectionFeatures {
    /// Baseline cardinality estimate the corrector is asked to adjust.
    pub baseline_estimate: u64,
    /// Row count of the left input; `None` when unknown.
    pub left_input_rows: Option<u64>,
    /// Row count of the right input; `None` when unknown.
    pub right_input_rows: Option<u64>,
    /// Distinct count for the left input; `None` when unknown.
    pub left_distinct: Option<u64>,
    /// Distinct count for the right input; `None` when unknown.
    pub right_distinct: Option<u64>,
    /// Number of predicates (operator-level feature).
    pub predicate_count: u32,
    /// Join depth (operator-level feature).
    pub join_depth: u32,
}

impl CorrectionFeatures {
    /// Number of entries [`to_vec`](Self::to_vec) produces.
    pub const FEATURE_LEN: usize = 7;

    /// Flatten the feature struct into a fixed-length numeric vector for a
    /// regression model. `Option<u64>` slots are zero-filled when absent —
    /// callers should treat zero as "unknown" rather than "literally zero
    /// rows", which is the convention the corrector is trained against.
    ///
    /// Layout (stable; new features must be appended, never reordered):
    ///
    /// 0. `baseline_estimate`
    /// 1. `left_input_rows` (0 if `None`)
    /// 2. `right_input_rows` (0 if `None`)
    /// 3. `left_distinct` (0 if `None`)
    /// 4. `right_distinct` (0 if `None`)
    /// 5. `predicate_count`
    /// 6. `join_depth`
    ///
    /// # Examples
    ///
    /// ```
    /// use samkhya_core::residual::CorrectionFeatures;
    ///
    /// let f = CorrectionFeatures {
    ///     baseline_estimate: 100,
    ///     left_input_rows: Some(10),
    ///     predicate_count: 3,
    ///     ..Default::default()
    /// };
    /// let v = f.to_vec();
    /// assert_eq!(v[0], 100.0);
    /// assert_eq!(v[1], 10.0);
    /// assert_eq!(v[2], 0.0); // None → 0
    /// assert_eq!(v[5], 3.0);
    /// ```
    pub fn to_vec(&self) -> Vec<f64> {
        // Built as a fixed-size array first so the compiler rejects any
        // drift between the slot list below and `FEATURE_LEN`.
        let slots: [f64; Self::FEATURE_LEN] = [
            self.baseline_estimate as f64,
            self.left_input_rows.unwrap_or(0) as f64,
            self.right_input_rows.unwrap_or(0) as f64,
            self.left_distinct.unwrap_or(0) as f64,
            self.right_distinct.unwrap_or(0) as f64,
            f64::from(self.predicate_count),
            f64::from(self.join_depth),
        ];
        slots.to_vec()
    }
}
175
/// A pluggable corrector. Engines call [`Corrector::correct`] on every
/// estimate that passes through samkhya's optimizer hook.
///
/// Returning `Ok(None)` lets the engine fall back to the baseline estimate;
/// returning `Ok(Some(_))` overrides it (subject to the LpBound envelope).
///
/// The `Send + Sync` supertraits require every implementor to be shareable
/// across threads.
///
/// # Examples
///
/// ```
/// use samkhya_core::residual::{CorrectionFeatures, Corrector, IdentityCorrector};
///
/// let corrector = IdentityCorrector;
/// let features = CorrectionFeatures {
///     baseline_estimate: 42,
///     ..Default::default()
/// };
/// // The identity corrector passes the baseline through unchanged.
/// assert_eq!(corrector.correct(&features).unwrap(), Some(42));
/// ```
pub trait Corrector: Send + Sync {
    /// Return a corrected estimate, or `None` to fall back to the baseline.
    ///
    /// `features` carries the baseline estimate together with the other
    /// inputs the backend may consult.
    ///
    /// # Errors
    ///
    /// The signature permits `Err`, but the backends visible in this module
    /// always return `Ok`; the HTTP backend maps transport and parse
    /// failures to `Ok(None)`.
    fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>>;

    /// Stable identifier for logging / model-version tracking.
    fn name(&self) -> &'static str;
}
202
203/// Default zero-cost corrector: passes the baseline through unchanged.
204///
205/// Used when no feedback history exists yet (cold start) or when the
206/// caller opts out of feedback-driven correction entirely.
207///
208/// # Examples
209///
210/// ```
211/// use samkhya_core::residual::{CorrectionFeatures, Corrector, IdentityCorrector};
212///
213/// let c = IdentityCorrector;
214/// let f = CorrectionFeatures { baseline_estimate: 1234, ..Default::default() };
215/// assert_eq!(c.correct(&f).unwrap(), Some(1234));
216/// assert_eq!(c.name(), "identity");
217/// ```
218pub struct IdentityCorrector;
219
220impl Corrector for IdentityCorrector {
221 fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
222 Ok(Some(features.baseline_estimate))
223 }
224
225 fn name(&self) -> &'static str {
226 "identity"
227 }
228}
229
230/// No-op stub for the foundation-model interface (Layer 5).
231///
232/// Always compiled, regardless of cargo features. Returns `Ok(None)` from
233/// every call, signalling the engine to fall back to its native estimate.
234///
235/// The point of an always-on stub is architectural: it lets downstream
236/// callers reference `TabPfnStub` to mark "TabPFN integration point,
237/// currently disabled" without taking the `tabpfn_http` feature
238/// dependency. The integration shape is visible in every build.
239///
240/// To wire in a real foundation-model backend, swap this for
241/// `tabpfn::TabPfnHttpCorrector` (gated on `tabpfn_http`) or a future
242/// subprocess adapter. The trait contract is identical, so the swap is
243/// a one-line change at the call site.
244///
245/// # Examples
246///
247/// ```
248/// use samkhya_core::residual::{CorrectionFeatures, Corrector, TabPfnStub};
249///
250/// let stub = TabPfnStub;
251/// // Stub always returns Ok(None) — the engine falls back to its native estimate.
252/// let f = CorrectionFeatures { baseline_estimate: 999, ..Default::default() };
253/// assert_eq!(stub.correct(&f).unwrap(), None);
254/// assert_eq!(stub.name(), "tabpfn-stub");
255/// ```
256pub struct TabPfnStub;
257
258impl Corrector for TabPfnStub {
259 fn correct(&self, _features: &CorrectionFeatures) -> Result<Option<u64>> {
260 // Deliberately `None`: the integration point is wired but
261 // disabled. The engine falls back to the native estimate.
262 Ok(None)
263 }
264
265 fn name(&self) -> &'static str {
266 "tabpfn-stub"
267 }
268}
269
270#[cfg(feature = "gbt")]
271pub mod gbt {
272 //! Gradient-boosted-tree residual corrector.
273 //!
274 //! Wraps the `gbdt` crate (Baidu / mesalock-linux,
275 //! <https://github.com/mesalock-linux/gbdt-rs>) — pure-Rust, no native
276 //! deps, builds on stable Rust 1.94 / edition 2024. Compiled in only
277 //! when the `gbt` cargo feature is enabled.
278 //!
279 //! Training target is `log(actual_rows / est_rows)` — the
280 //! multiplicative correction ratio in log-space. At prediction time
281 //! we exponentiate and multiply through the baseline, then clamp to
282 //! the configured LpBound ceiling via
283 //! [`crate::lpbound::saturating_clamp`] so the corrector cannot ever
284 //! violate the envelope contract.
285 //!
286 //! Observations with `est_rows == 0` or `actual_rows == 0` are
287 //! silently dropped (log of zero is undefined); we do not invent a
288 //! Laplace-style smoothing constant at the corrector layer.
289
290 use gbdt::config::{Config, Loss};
291 use gbdt::decision_tree::{Data, DataVec};
292 use gbdt::gradient_boost::GBDT;
293
294 use super::{CorrectionFeatures, Corrector};
295 use crate::feedback::Observation;
296 use crate::lpbound::saturating_clamp;
297 use crate::{Error, Result};
298
299 /// Tunables for [`GbtCorrector::train`]. Defaults are an MVP starting
300 /// point: shallow trees, modest depth, square-error loss.
301 #[derive(Debug, Clone)]
302 pub struct GbtOptions {
303 /// Shrinkage / learning rate applied to each tree's contribution.
304 pub learning_rate: f64,
305 /// Max depth of each regression tree. Root is depth 0.
306 pub max_depth: u32,
307 /// Number of boosting iterations (one tree per iteration).
308 pub num_trees: u32,
309 /// Inclusive upper bound applied to every corrected estimate.
310 /// Use `u64::MAX` to disable (the trait signature has no ceiling
311 /// slot, so we store it here at train time).
312 pub ceiling: u64,
313 /// Minimum samples per leaf — guards against overfitting tiny
314 /// feedback histories.
315 pub min_leaf_size: usize,
316 }
317
318 impl Default for GbtOptions {
319 fn default() -> Self {
320 Self {
321 learning_rate: 0.1,
322 max_depth: 4,
323 num_trees: 50,
324 ceiling: u64::MAX,
325 min_leaf_size: 1,
326 }
327 }
328 }
329
330 /// Trained GBT-backed residual corrector.
331 pub struct GbtCorrector {
332 model: GBDT,
333 ceiling: u64,
334 }
335
336 impl GbtCorrector {
337 /// Train a corrector from a slice of [`Observation`]s.
338 ///
339 /// Returns [`Error::Feedback`] if the observation slice is empty,
340 /// or if every observation is unusable (zero est_rows or zero
341 /// actual_rows). Non-positive-ratio observations are silently
342 /// filtered, matching the convention in [`Observation::q_error`].
343 pub fn train(observations: &[Observation], options: GbtOptions) -> Result<Self> {
344 if observations.is_empty() {
345 return Err(Error::Feedback(
346 "cannot train GbtCorrector: observation slice is empty".into(),
347 ));
348 }
349
350 let mut training: DataVec = Vec::with_capacity(observations.len());
351 for obs in observations {
352 if obs.est_rows == 0 || obs.actual_rows == 0 {
353 continue;
354 }
355 // Reconstruct a feature vector from the observation. The
356 // feedback table doesn't yet carry full plan features, so
357 // we synthesize the minimal `baseline_estimate`-only
358 // vector. As `Observation` gains columns the mapping
359 // below should grow in lockstep with `CorrectionFeatures`.
360 let features = CorrectionFeatures {
361 baseline_estimate: obs.est_rows,
362 ..Default::default()
363 };
364 let feature_f32: Vec<f32> =
365 features.to_vec().into_iter().map(|v| v as f32).collect();
366 let ratio_log = (obs.actual_rows as f64 / obs.est_rows as f64).ln() as f32;
367 training.push(Data::new_training_data(feature_f32, 1.0, ratio_log, None));
368 }
369
370 if training.is_empty() {
371 return Err(Error::Feedback(
372 "cannot train GbtCorrector: all observations had zero est or actual rows"
373 .into(),
374 ));
375 }
376
377 let mut cfg = Config::new();
378 cfg.set_feature_size(CorrectionFeatures::FEATURE_LEN);
379 cfg.set_max_depth(options.max_depth);
380 cfg.set_iterations(options.num_trees as usize);
381 cfg.set_shrinkage(options.learning_rate as f32);
382 cfg.set_min_leaf_size(options.min_leaf_size);
383 cfg.set_loss(&loss_name(Loss::SquaredError));
384
385 let mut model = GBDT::new(&cfg);
386 model.fit(&mut training);
387
388 Ok(Self {
389 model,
390 ceiling: options.ceiling,
391 })
392 }
393
394 /// Predict the log-ratio correction for a single feature vector.
395 /// Exposed for diagnostics / unit tests; the production path is
396 /// [`Corrector::correct`].
397 pub fn predict_log_ratio(&self, features: &CorrectionFeatures) -> f64 {
398 let feature_f32: Vec<f32> = features.to_vec().into_iter().map(|v| v as f32).collect();
399 let probe: DataVec = vec![Data::new_test_data(feature_f32, None)];
400 let preds = self.model.predict(&probe);
401 preds.first().copied().unwrap_or(0.0) as f64
402 }
403
404 /// Configured upper bound. Set at training time; the trait method
405 /// [`Corrector::correct`] enforces it via `saturating_clamp`.
406 pub fn ceiling(&self) -> u64 {
407 self.ceiling
408 }
409 }
410
411 impl Corrector for GbtCorrector {
412 fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
413 let log_ratio = self.predict_log_ratio(features);
414 let ratio = log_ratio.exp();
415 let scaled = features.baseline_estimate as f64 * ratio;
416 Ok(Some(saturating_clamp(scaled, self.ceiling)))
417 }
418
419 fn name(&self) -> &'static str {
420 "gbt"
421 }
422 }
423
424 /// `gbdt::config::Config::set_loss` takes a string; this is the
425 /// canonical spelling for square-error in that crate.
426 fn loss_name(loss: Loss) -> String {
427 gbdt::config::loss2string(&loss)
428 }
429}
430
431#[cfg(feature = "additive_gbt")]
432pub mod additive {
433 //! Additive gradient-boosted-tree residual corrector.
434 //!
435 //! Sibling backend to [`super::gbt`]. The multiplicative form trains on
436 //! `log(actual / baseline_estimate)` and applies the correction as
437 //! `baseline * exp(predicted)`. That model is structurally trapped at
438 //! zero whenever the engine hands us `baseline_estimate = 0` — the
439 //! q=∞ regime where the upstream estimator has completely collapsed
440 //! (a common DataFusion 46 symptom on chained joins).
441 //!
    //! The additive backend sidesteps that trap by training the model to
    //! predict the **absolute** `actual_rows` from the full
    //! [`CorrectionFeatures`] vector (all 7 features, not just the
    //! baseline). Note that training currently populates only
    //! `baseline_estimate` — the other six slots stay zero until
    //! `Observation` carries plan-shape features. The prediction is
    //! clamped to a non-negative integer and then to the configured
    //! LpBound ceiling via [`crate::lpbound::saturating_clamp`], so the
    //! envelope contract is preserved.
449 //!
450 //! Cargo feature: `additive_gbt`. Independent of the `gbt` feature —
451 //! they can be enabled separately or together.
452
453 use gbdt::config::{Config, Loss};
454 use gbdt::decision_tree::{Data, DataVec};
455 use gbdt::gradient_boost::GBDT;
456 use std::sync::Mutex;
457
458 use super::{CorrectionFeatures, Corrector};
459 use crate::feedback::Observation;
460 use crate::lpbound::saturating_clamp;
461 use crate::{Error, Result};
462
463 /// Tunables for [`AdditiveGbtCorrector::train`]. Defaults mirror
464 /// [`super::gbt::GbtOptions`] so the two backends are
465 /// drop-in-comparable when benchmarking.
466 #[derive(Debug, Clone)]
467 pub struct AdditiveGbtOptions {
468 /// Shrinkage / learning rate applied to each tree's contribution.
469 pub learning_rate: f64,
470 /// Max depth of each regression tree. Root is depth 0.
471 pub max_depth: u32,
472 /// Number of boosting iterations (one tree per iteration).
473 pub num_trees: u32,
474 /// Inclusive upper bound applied to every corrected estimate.
475 /// Use `u64::MAX` to disable.
476 pub ceiling: u64,
477 /// Minimum samples per leaf — guards against overfitting tiny
478 /// feedback histories.
479 pub min_leaf_size: usize,
480 }
481
482 impl Default for AdditiveGbtOptions {
483 fn default() -> Self {
484 Self {
485 learning_rate: 0.1,
486 max_depth: 4,
487 num_trees: 50,
488 ceiling: u64::MAX,
489 min_leaf_size: 1,
490 }
491 }
492 }
493
494 /// Trained additive GBT corrector. Predicts absolute row counts.
495 ///
496 /// The model is wrapped in a [`Mutex`] because `gbdt::GBDT::predict`
497 /// takes `&mut self` on some configurations; the lock is held only
498 /// for the prediction call and is uncontended in the common single-
499 /// threaded estimate path.
500 pub struct AdditiveGbtCorrector {
501 model: Mutex<GBDT>,
502 ceiling: u64,
503 }
504
505 impl AdditiveGbtCorrector {
506 /// Train an additive corrector from a slice of [`Observation`]s.
507 ///
508 /// Returns [`Error::Feedback`] if the observation slice is empty.
509 /// Unlike the multiplicative backend, observations with
510 /// `est_rows == 0` are **kept** — they are precisely the q=∞
511 /// regime this backend exists to handle. Observations with
512 /// `actual_rows == 0` are also kept (a true-zero output is a
513 /// valid signal for an additive model).
514 pub fn train(observations: &[Observation], options: AdditiveGbtOptions) -> Result<Self> {
515 if observations.is_empty() {
516 return Err(Error::Feedback(
517 "cannot train AdditiveGbtCorrector: observation slice is empty".into(),
518 ));
519 }
520
521 let mut training: DataVec = Vec::with_capacity(observations.len());
522 for obs in observations {
523 // Reconstruct a feature vector from the observation. The
524 // feedback table doesn't yet carry the full plan-shape
525 // feature set, so we synthesize from `est_rows`. As
526 // `Observation` gains columns, mirror the additions here.
527 let features = CorrectionFeatures {
528 baseline_estimate: obs.est_rows,
529 ..Default::default()
530 };
531 let feature_f32: Vec<f32> =
532 features.to_vec().into_iter().map(|v| v as f32).collect();
533 let target = obs.actual_rows as f32;
534 training.push(Data::new_training_data(feature_f32, 1.0, target, None));
535 }
536
537 // Empty observations are caught above; the synthesized
538 // training set here is always non-empty.
539 debug_assert!(!training.is_empty());
540
541 let mut cfg = Config::new();
542 cfg.set_feature_size(CorrectionFeatures::FEATURE_LEN);
543 cfg.set_max_depth(options.max_depth);
544 cfg.set_iterations(options.num_trees as usize);
545 cfg.set_shrinkage(options.learning_rate as f32);
546 cfg.set_min_leaf_size(options.min_leaf_size);
547 cfg.set_loss(&gbdt::config::loss2string(&Loss::SquaredError));
548
549 let mut model = GBDT::new(&cfg);
550 model.fit(&mut training);
551
552 Ok(Self {
553 model: Mutex::new(model),
554 ceiling: options.ceiling,
555 })
556 }
557
558 /// Predict the absolute row count for a feature vector.
559 /// Exposed for diagnostics; the production path is
560 /// [`Corrector::correct`].
561 pub fn predict_rows(&self, features: &CorrectionFeatures) -> f64 {
562 let feature_f32: Vec<f32> = features.to_vec().into_iter().map(|v| v as f32).collect();
563 let probe: DataVec = vec![Data::new_test_data(feature_f32, None)];
564 let model = self.model.lock().expect("AdditiveGbtCorrector model lock");
565 let preds = model.predict(&probe);
566 preds.first().copied().unwrap_or(0.0) as f64
567 }
568
569 /// Configured upper bound. Set at training time; the trait method
570 /// [`Corrector::correct`] enforces it via `saturating_clamp`.
571 pub fn ceiling(&self) -> u64 {
572 self.ceiling
573 }
574 }
575
576 impl Corrector for AdditiveGbtCorrector {
577 fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
578 let raw = self.predict_rows(features).max(0.0);
579 Ok(Some(saturating_clamp(raw, self.ceiling)))
580 }
581
582 fn name(&self) -> &'static str {
583 "additive_gbt"
584 }
585 }
586}
587
588#[cfg(feature = "tabpfn_http")]
589pub mod tabpfn {
590 //! Foundation-model interface — HTTP transport.
591 //!
592 //! Posts a [`super::CorrectionFeatures`] vector as JSON to a
593 //! user-configured endpoint (e.g., a Python TabPFN inference server
594 //! listening on `http://localhost:8765/infer`), parses an
595 //! `{"estimate": <u64>}` reply, and clamps the result to the LpBound
596 //! ceiling via [`crate::lpbound::saturating_clamp`].
597 //!
598 //! Transport: pure-Rust `ureq` (rustls-only, no OpenSSL). Compiled in
599 //! only when the `tabpfn_http` cargo feature is enabled.
600 //!
601 //! # Safety contract
602 //!
603 //! Any failure — DNS, connection refused, HTTP non-2xx, body parse
604 //! error, timeout — returns `Ok(None)`. The engine falls back to the
605 //! native estimate. We never propagate transport errors to the
606 //! optimizer hot path; a remote inference server going down must not
607 //! surface as a query failure.
608 //!
609 //! Note on naming: this is *the foundation-model interface*, not a
610 //! "learned" or "AI" feature. The corrector is a pluggable backend
611 //! behind the same `Corrector` trait as every other backend in this
612 //! module.
613 //!
614 //! # Wire format
615 //!
616 //! Request body (JSON):
617 //!
618 //! ```json
619 //! {
620 //! "features": [<f64>, <f64>, ...],
621 //! "baseline_estimate": <u64>
622 //! }
623 //! ```
624 //!
625 //! Response body (JSON):
626 //!
627 //! ```json
628 //! { "estimate": <u64> }
629 //! ```
630 //!
631 //! Any extra fields in the response are ignored, so server
632 //! implementations are free to add diagnostics without breaking the
633 //! client.
634 //!
635 //! # See also
636 //!
637 //! - [`super::TabPfnStub`] — always-on no-op for the same integration
638 //! slot, no transport dependency.
639
640 use serde::{Deserialize, Serialize};
641 use std::time::Duration;
642
643 use super::{CorrectionFeatures, Corrector};
644 use crate::Result;
645 use crate::lpbound::saturating_clamp;
646
647 /// Configuration for [`TabPfnHttpCorrector`].
648 #[derive(Debug, Clone)]
649 pub struct TabPfnHttpOptions {
650 /// Inference endpoint URL. The corrector POSTs here on every
651 /// `correct()` call. Example: `http://localhost:8765/infer`.
652 pub base_url: String,
653 /// Per-request timeout. Applies independently to the connect and
654 /// read phases. Bounded by the architecture's sub-ms budget for
655 /// the production path, but configurable so users can dial it up
656 /// for diagnostics.
657 pub timeout_ms: u64,
658 /// Inclusive upper bound applied to every corrected estimate via
659 /// [`saturating_clamp`]. The Layer 3 safety guarantee — corrections
660 /// can never exceed this regardless of what the remote backend
661 /// returns. Use `u64::MAX` to disable.
662 pub ceiling: u64,
663 }
664
665 impl Default for TabPfnHttpOptions {
666 fn default() -> Self {
667 Self {
668 base_url: "http://localhost:8765/infer".into(),
669 timeout_ms: 50,
670 ceiling: u64::MAX,
671 }
672 }
673 }
674
    /// JSON request body sent to the inference endpoint.
    ///
    /// Serializes as `{"features": [...], "baseline_estimate": <u64>}`.
    #[derive(Serialize)]
    struct InferRequest<'a> {
        /// Flattened feature vector, borrowed for the lifetime of the request.
        features: &'a [f64],
        /// Baseline estimate, sent alongside the feature vector.
        baseline_estimate: u64,
    }

    /// JSON response body. Extra fields are ignored.
    #[derive(Deserialize)]
    struct InferResponse {
        /// Estimate reported by the server, before any ceiling clamp.
        estimate: u64,
    }
687
688 /// HTTP-backed foundation-model corrector.
689 ///
690 /// Holds a tiny client config and a base URL. The `ureq` agent is
691 /// constructed per-call: the per-estimate cost is dominated by network
692 /// round-trip, not agent allocation, and per-call agents keep the
693 /// struct cheaply `Send + Sync` without interior mutability.
694 pub struct TabPfnHttpCorrector {
695 options: TabPfnHttpOptions,
696 }
697
698 impl TabPfnHttpCorrector {
699 /// Build a corrector from explicit options.
700 pub fn new(options: TabPfnHttpOptions) -> Self {
701 super::warn_if_remote_plaintext_http(&options.base_url, "tabpfn_http");
702 Self { options }
703 }
704
705 /// Convenience constructor: default options with the supplied URL.
706 pub fn with_url(base_url: impl Into<String>) -> Self {
707 let opts = TabPfnHttpOptions {
708 base_url: base_url.into(),
709 ..TabPfnHttpOptions::default()
710 };
711 super::warn_if_remote_plaintext_http(&opts.base_url, "tabpfn_http");
712 Self { options: opts }
713 }
714
715 /// Configured options (for diagnostics / logging).
716 pub fn options(&self) -> &TabPfnHttpOptions {
717 &self.options
718 }
719
720 /// Attempt one inference call. Returns `None` on any failure
721 /// (network, parse, non-2xx). The `correct()` trait method wraps
722 /// this and applies the LpBound clamp.
723 fn try_infer(&self, features: &CorrectionFeatures) -> Option<u64> {
724 let feature_vec = features.to_vec();
725 let payload = InferRequest {
726 features: &feature_vec,
727 baseline_estimate: features.baseline_estimate,
728 };
729
730 let timeout = Duration::from_millis(self.options.timeout_ms);
731 let agent = ureq::AgentBuilder::new()
732 .timeout_connect(timeout)
733 .timeout_read(timeout)
734 .timeout_write(timeout)
735 .build();
736
737 let response = match agent.post(&self.options.base_url).send_json(&payload) {
738 Ok(r) => r,
739 Err(err) => {
740 // Map every transport error to None and log at debug.
741 // The Error::Feedback diagnostic carries the URL plus
742 // the underlying message so callers tailing logs can
743 // see what failed without us aborting the query.
744 log::debug!(
745 "tabpfn_http: request to {} failed: {}",
746 self.options.base_url,
747 err
748 );
749 return None;
750 }
751 };
752
753 match response.into_json::<InferResponse>() {
754 Ok(body) => Some(body.estimate),
755 Err(err) => {
756 log::debug!(
757 "tabpfn_http: response from {} failed to parse: {}",
758 self.options.base_url,
759 err
760 );
761 None
762 }
763 }
764 }
765 }
766
767 impl Corrector for TabPfnHttpCorrector {
768 fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
769 // Safety contract: every failure returns Ok(None), not Err.
770 // The engine then transparently falls back to the native
771 // estimate. We use Result here to honour the trait shape and
772 // to keep a door open for future *non-fallback* error modes
773 // (e.g. a deliberate misconfiguration check), but on the hot
774 // path failures are absorbed.
775 let Some(raw) = self.try_infer(features) else {
776 return Ok(None);
777 };
778 Ok(Some(saturating_clamp(raw as f64, self.options.ceiling)))
779 }
780
781 fn name(&self) -> &'static str {
782 "tabpfn-http"
783 }
784 }
785}
786
787#[cfg(feature = "llm_http")]
788pub mod llm {
789 //! LLM-pluggable corrector backend — HTTP transport.
790 //!
791 //! Posts a [`super::CorrectionFeatures`] vector as JSON to a
792 //! user-configured endpoint (e.g., a Python LLM inference server
793 //! listening on `http://localhost:8766/infer`) and parses an
794 //! `{"estimate": <u64>}` reply. The server-side LLM provider
795 //! (Anthropic, OpenAI, local Ollama, dummy) is selected by the
796 //! `SAMKHYA_LLM_BACKEND` env var on the server process — the wire
797 //! contract is identical regardless of which provider is configured.
798 //!
799 //! Transport: pure-Rust `ureq` (rustls-only, no OpenSSL). Compiled in
800 //! only when the `llm_http` cargo feature is enabled.
801 //!
802 //! # Naming
803 //!
804 //! This is *the LLM-pluggable corrector backend* — a transport-level
805 //! integration that lets a foundation language model serve as the
806 //! cardinality corrector behind the same `Corrector` trait as every
807 //! other backend in this module. It is **not** an "AI", "adaptive",
808 //! or "learned" feature; the samkhya envelope still dominates the
809 //! safety contract and the LLM is strictly an opt-in pluggable
810 //! backend. The default samkhya build does not pull this in.
811 //!
812 //! # Safety contract
813 //!
814 //! Any failure — DNS, connection refused, HTTP non-2xx, body parse
815 //! error, timeout — returns `Ok(None)`. The engine falls back to the
816 //! native estimate. We never propagate transport errors to the
817 //! optimizer hot path; a remote inference server going down must not
818 //! surface as a query failure. Mirrors the
819 //! [`super::tabpfn::TabPfnHttpCorrector`] contract exactly.
820 //!
821 //! # Wire format
822 //!
823 //! Request body (JSON):
824 //!
825 //! ```json
826 //! {
827 //! "features": [<f64>, <f64>, ...],
828 //! "baseline_estimate": <u64>
829 //! }
830 //! ```
831 //!
832 //! Response body (JSON):
833 //!
834 //! ```json
835 //! { "estimate": <u64> }
836 //! ```
837 //!
838 //! Any extra fields in the response are ignored, so server
839 //! implementations are free to add diagnostics (e.g., the LLM's raw
840 //! text reply, parse-status flags) without breaking the client.
841 //!
842 //! # Latency expectations
843 //!
    //! LLM round-trips are one to two orders of magnitude slower than the
    //! TabPFN tier (P95 in the 0.3–2 s range vs. ~30 ms for TabPFN). The
    //! default per-request timeout is therefore 2 000 ms (vs. 50 ms for
    //! TabPFN), with a 60 s hard cap available for cold-cache diagnostics. The
848 //! `llm_http` backend is intended for *offline / overnight*
849 //! re-validation and schema-introspection use cases, not the online
850 //! query hot path. See `bench-results/19_llm_corrector.md` §6 for
851 //! routing guidance.
852
853 use serde::{Deserialize, Serialize};
854 use std::time::Duration;
855
856 use super::{CorrectionFeatures, Corrector};
857 use crate::Result;
858 use crate::lpbound::saturating_clamp;
859
/// Per-request timeout, in milliseconds, that [`LlmHttpOptions::default`]
/// hands to the LLM HTTP backend.
///
/// LLM inference is far slower than the TabPFN tier, so the budget is a
/// full two seconds — picked as the smallest value that covers
/// warm-cache hosted-LLM calls without spurious timeouts.
pub const DEFAULT_TIMEOUT_MS: u64 = 2 * 1_000;
866
/// Upper limit, in milliseconds, on any per-request timeout.
///
/// Constructors that take a `timeout_ms` lower larger values to this
/// limit, so a misconfigured caller cannot request more than a 60 s
/// budget for a single call.
pub const MAX_TIMEOUT_MS: u64 = 60 * 1_000;
871
/// Inference endpoint used when no URL is configured.
///
/// The port differs from the TabPFN default (`8765`), which lets an
/// operator run both inference servers on one host without a clash.
pub const DEFAULT_URL: &str = "http://127.0.0.1:8766/infer";
876
/// Settings consumed by [`LlmHttpCorrector`].
#[derive(Clone, Debug)]
pub struct LlmHttpOptions {
    /// URL of the inference endpoint. Every `correct()` call sends one
    /// POST to it. Example: `http://localhost:8766/infer`.
    pub base_url: String,
    /// Per-request timeout in milliseconds, handed to the HTTP client.
    /// [`LlmHttpCorrector::new`] lowers anything above
    /// [`MAX_TIMEOUT_MS`] to that cap, so a misconfigured caller cannot
    /// stall the optimizer indefinitely.
    pub timeout_ms: u64,
    /// Inclusive ceiling enforced on every corrected estimate through
    /// [`saturating_clamp`]. This is the Layer 3 safety guarantee:
    /// whatever the remote LLM answers, the result never exceeds this
    /// value. Set it to `u64::MAX` to switch the clamp off.
    pub ceiling: u64,
}
893
894 impl Default for LlmHttpOptions {
895 fn default() -> Self {
896 Self {
897 base_url: DEFAULT_URL.into(),
898 timeout_ms: DEFAULT_TIMEOUT_MS,
899 ceiling: u64::MAX,
900 }
901 }
902 }
903
904 /// JSON request body sent to the inference endpoint.
905 #[derive(Serialize)]
906 struct InferRequest<'a> {
907 features: &'a [f64],
908 baseline_estimate: u64,
909 }
910
911 /// JSON response body. Extra fields are ignored.
912 #[derive(Deserialize)]
913 struct InferResponse {
914 estimate: u64,
915 }
916
917 /// HTTP-backed LLM-pluggable corrector.
918 ///
919 /// Holds a tiny client config and a base URL. The `ureq` agent is
920 /// constructed per-call: the per-estimate cost is dominated by LLM
921 /// inference (hundreds of milliseconds), not agent allocation, and
922 /// per-call agents keep the struct cheaply `Send + Sync` without
923 /// interior mutability.
924 pub struct LlmHttpCorrector {
925 options: LlmHttpOptions,
926 }
927
928 impl LlmHttpCorrector {
929 /// Build a corrector from explicit options. The `timeout_ms`
930 /// value is saturated to [`MAX_TIMEOUT_MS`] so misconfigured
931 /// callers cannot stall the optimizer for longer than that.
932 pub fn new(mut options: LlmHttpOptions) -> Self {
933 if options.timeout_ms > MAX_TIMEOUT_MS {
934 options.timeout_ms = MAX_TIMEOUT_MS;
935 }
936 super::warn_if_remote_plaintext_http(&options.base_url, "llm_http");
937 Self { options }
938 }
939
940 /// Convenience constructor: default options with the supplied
941 /// URL. Useful for ad-hoc bench / smoke clients.
942 pub fn with_url(base_url: impl Into<String>) -> Self {
943 Self::new(LlmHttpOptions {
944 base_url: base_url.into(),
945 ..LlmHttpOptions::default()
946 })
947 }
948
949 /// Configured options (for diagnostics / logging).
950 pub fn options(&self) -> &LlmHttpOptions {
951 &self.options
952 }
953
954 /// Attempt one inference call. Returns `None` on any failure
955 /// (network, parse, non-2xx). The `correct()` trait method wraps
956 /// this and applies the LpBound clamp.
957 fn try_infer(&self, features: &CorrectionFeatures) -> Option<u64> {
958 let feature_vec = features.to_vec();
959 let payload = InferRequest {
960 features: &feature_vec,
961 baseline_estimate: features.baseline_estimate,
962 };
963
964 let timeout = Duration::from_millis(self.options.timeout_ms);
965 let agent = ureq::AgentBuilder::new()
966 .timeout_connect(timeout)
967 .timeout_read(timeout)
968 .timeout_write(timeout)
969 .build();
970
971 let response = match agent.post(&self.options.base_url).send_json(&payload) {
972 Ok(r) => r,
973 Err(err) => {
974 log::debug!(
975 "llm_http: request to {} failed: {}",
976 self.options.base_url,
977 err
978 );
979 return None;
980 }
981 };
982
983 match response.into_json::<InferResponse>() {
984 Ok(body) => Some(body.estimate),
985 Err(err) => {
986 log::debug!(
987 "llm_http: response from {} failed to parse: {}",
988 self.options.base_url,
989 err
990 );
991 None
992 }
993 }
994 }
995 }
996
997 impl Corrector for LlmHttpCorrector {
998 fn correct(&self, features: &CorrectionFeatures) -> Result<Option<u64>> {
999 // Safety contract: every failure returns Ok(None), not Err.
1000 // Mirrors `TabPfnHttpCorrector::correct`. On the optimizer's
1001 // hot path a remote LLM going down (rate limit, network
1002 // partition, mis-config) must never surface as a query
1003 // failure.
1004 let Some(raw) = self.try_infer(features) else {
1005 return Ok(None);
1006 };
1007 Ok(Some(saturating_clamp(raw as f64, self.options.ceiling)))
1008 }
1009
1010 fn name(&self) -> &'static str {
1011 "llm-http"
1012 }
1013 }
1014}
1015
#[cfg(test)]
mod tests {
    use super::*;

    /// Features that carry only a baseline estimate; every other field
    /// keeps its default.
    fn baseline_only(baseline_estimate: u64) -> CorrectionFeatures {
        CorrectionFeatures {
            baseline_estimate,
            ..Default::default()
        }
    }

    #[test]
    fn identity_returns_baseline() {
        let identity = IdentityCorrector;
        let corrected = identity.correct(&baseline_only(1234)).unwrap();
        assert_eq!(corrected, Some(1234));
        assert_eq!(identity.name(), "identity");
    }

    #[test]
    fn tabpfn_stub_always_returns_none() {
        let stub = TabPfnStub;
        let populated = baseline_only(9999);
        assert_eq!(
            stub.correct(&populated).unwrap(),
            None,
            "TabPfnStub must always return Ok(None) — it documents the integration point"
        );
        assert_eq!(stub.name(), "tabpfn-stub");

        // All-default features must be handled the same way: the stub
        // answers None without looking at its input.
        let untouched = CorrectionFeatures::default();
        assert_eq!(stub.correct(&untouched).unwrap(), None);
    }

    #[test]
    fn feature_vec_layout_is_stable() {
        let features = CorrectionFeatures {
            baseline_estimate: 100,
            left_input_rows: Some(10),
            right_input_rows: None,
            left_distinct: Some(7),
            right_distinct: None,
            predicate_count: 3,
            join_depth: 2,
        };
        let layout = features.to_vec();
        assert_eq!(layout.len(), CorrectionFeatures::FEATURE_LEN);

        // Slot order mirrors the field order above; `None` encodes as 0.
        let expected = [100.0, 10.0, 0.0, 7.0, 0.0, 3.0, 2.0];
        for (slot, want) in expected.into_iter().enumerate() {
            assert_eq!(layout[slot], want);
        }
    }
}
1073
#[cfg(all(test, feature = "gbt"))]
mod gbt_tests {
    use super::gbt::{GbtCorrector, GbtOptions};
    use super::{CorrectionFeatures, Corrector};
    use crate::feedback::Observation;

    /// `n` synthetic observations with `actual = est * 2`, for estimates
    /// 10, 20, …, `n * 10`. A strong, consistent signal for the trees.
    fn synthetic_double(n: u64) -> Vec<Observation> {
        let mut observations = Vec::new();
        for step in 1..=n {
            let est_rows = step * 10;
            observations.push(Observation {
                template_hash: "syn".into(),
                plan_fingerprint: "p".into(),
                est_rows,
                actual_rows: est_rows * 2,
                latency_ms: None,
            });
        }
        observations
    }

    /// Training options shared by the prediction tests; only the
    /// ceiling varies between them.
    fn tuned_options(ceiling: u64) -> GbtOptions {
        GbtOptions {
            learning_rate: 0.3,
            max_depth: 4,
            num_trees: 50,
            ceiling,
            min_leaf_size: 1,
        }
    }

    /// Run `corrector` on features whose only non-default field is a
    /// baseline estimate of 500, unwrapping down to the corrected value.
    fn correct_baseline_500(corrector: &GbtCorrector) -> u64 {
        let features = CorrectionFeatures {
            baseline_estimate: 500,
            ..Default::default()
        };
        corrector
            .correct(&features)
            .expect("correct")
            .expect("Some")
    }

    /// One observation with the given estimated / actual row counts.
    fn zero_case(est_rows: u64, actual_rows: u64) -> Observation {
        Observation {
            template_hash: "z".into(),
            plan_fingerprint: "p".into(),
            est_rows,
            actual_rows,
            latency_ms: None,
        }
    }

    #[test]
    fn predicts_roughly_double_when_training_says_double() {
        let obs = synthetic_double(200);
        let corrector = GbtCorrector::train(&obs, tuned_options(u64::MAX)).expect("training");

        let corrected = correct_baseline_500(&corrector);
        // The ideal answer is 1000; boosted trees only approximate it,
        // so accept anything within 25%.
        let ratio = corrected as f64 / 1000.0;
        assert!(
            (0.75..=1.25).contains(&ratio),
            "expected ~1000, got {} (ratio {})",
            corrected,
            ratio
        );
        assert_eq!(corrector.name(), "gbt");
    }

    #[test]
    fn ceiling_clamps_when_prediction_exceeds_it() {
        let obs = synthetic_double(200);
        // A ceiling of 100 sits far below 2 × baseline.
        let corrector = GbtCorrector::train(&obs, tuned_options(100)).expect("training");

        let corrected = correct_baseline_500(&corrector);
        assert_eq!(corrected, 100, "ceiling must clamp the corrected estimate");
        assert_eq!(corrector.ceiling(), 100);
    }

    #[test]
    fn empty_observations_errors() {
        let Err(err) = GbtCorrector::train(&[], GbtOptions::default()) else {
            panic!("expected error on empty observations");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }

    #[test]
    fn all_zero_observations_errors() {
        // Each observation has a zero on one side (estimate or actual).
        let obs = vec![zero_case(0, 5), zero_case(5, 0)];
        let Err(err) = GbtCorrector::train(&obs, GbtOptions::default()) else {
            panic!("expected error when all observations are zero");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }
}
1181
#[cfg(all(test, feature = "additive_gbt"))]
mod additive_tests {
    use super::additive::{AdditiveGbtCorrector, AdditiveGbtOptions};
    use super::{CorrectionFeatures, Corrector};
    use crate::feedback::Observation;

    /// `n` synthetic observations whose actual row count is always
    /// `target`, while the estimate walks through 10, 20, …, `n * 10`.
    /// An additive model trained on them should regress toward `target`
    /// whatever the input features are.
    fn synthetic_constant(n: u64, target: u64) -> Vec<Observation> {
        let mut observations = Vec::new();
        for step in 1..=n {
            observations.push(Observation {
                template_hash: "syn-add".into(),
                plan_fingerprint: "p".into(),
                est_rows: step * 10,
                actual_rows: target,
                latency_ms: None,
            });
        }
        observations
    }

    /// Training options shared by the prediction tests; only the
    /// ceiling varies between them.
    fn tuned_options(ceiling: u64) -> AdditiveGbtOptions {
        AdditiveGbtOptions {
            learning_rate: 0.3,
            max_depth: 4,
            num_trees: 50,
            ceiling,
            min_leaf_size: 1,
        }
    }

    /// Run `corrector` on features whose only non-default field is the
    /// given baseline estimate, unwrapping down to the corrected value.
    fn correct_with_baseline(corrector: &AdditiveGbtCorrector, baseline_estimate: u64) -> u64 {
        let features = CorrectionFeatures {
            baseline_estimate,
            ..Default::default()
        };
        corrector
            .correct(&features)
            .expect("correct")
            .expect("Some")
    }

    #[test]
    fn predicts_near_constant_when_training_is_constant() {
        let obs = synthetic_constant(200, 1000);
        let corrector = AdditiveGbtCorrector::train(&obs, tuned_options(u64::MAX))
            .expect("training additive corrector");

        let corrected = correct_with_baseline(&corrector, 500);
        assert!(
            (800..=1200).contains(&corrected),
            "expected ~1000, got {corrected}"
        );
        assert_eq!(corrector.name(), "additive_gbt");
    }

    #[test]
    fn ceiling_clamps_when_prediction_exceeds_it() {
        let obs = synthetic_constant(200, 1000);
        // A ceiling of 100 sits far below the trained constant.
        let corrector = AdditiveGbtCorrector::train(&obs, tuned_options(100)).expect("training");

        let corrected = correct_with_baseline(&corrector, 500);
        assert_eq!(corrected, 100, "ceiling must clamp the additive correction");
        assert_eq!(corrector.ceiling(), 100);
    }

    #[test]
    fn corrects_nonzero_even_when_baseline_estimate_is_zero() {
        // Proof of the q=∞ fix. A multiplicative correction scales the
        // baseline (baseline * exp(predicted)), so a zero baseline
        // stays zero. The additive backend must not fall into that trap.
        let obs = synthetic_constant(200, 1000);
        let corrector =
            AdditiveGbtCorrector::train(&obs, AdditiveGbtOptions::default()).expect("training");

        let corrected = correct_with_baseline(&corrector, 0);
        assert!(
            corrected > 0,
            "additive corrector must return non-zero even when baseline_estimate = 0; got {corrected}"
        );
    }

    #[test]
    fn empty_observations_errors() {
        let Err(err) = AdditiveGbtCorrector::train(&[], AdditiveGbtOptions::default()) else {
            panic!("expected error on empty observations");
        };
        assert!(matches!(err, crate::Error::Feedback(_)));
    }
}
1286
#[cfg(all(test, feature = "tabpfn_http"))]
mod tabpfn_http_tests {
    use super::tabpfn::{TabPfnHttpCorrector, TabPfnHttpOptions};
    use super::{CorrectionFeatures, Corrector};

    /// Loopback port 1 is used as a target that refuses connections on
    /// Linux/macOS, so no real inference server is needed. The safety
    /// contract under test: a transport failure must come back as
    /// `Ok(None)` — not `Err`, not a panic.
    #[test]
    fn http_failure_returns_none_not_error() {
        let options = TabPfnHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: 50,
            ceiling: u64::MAX,
        };
        let corrector = TabPfnHttpCorrector::new(options);
        let features = CorrectionFeatures {
            baseline_estimate: 1234,
            ..Default::default()
        };

        let result = corrector.correct(&features);
        assert!(
            result.is_ok(),
            "tabpfn-http transport failure must not propagate as Err; got {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            None,
            "tabpfn-http transport failure must yield Ok(None) so the engine falls back cleanly"
        );
        assert_eq!(corrector.name(), "tabpfn-http");
    }

    #[test]
    fn malformed_url_returns_none() {
        // The string is not a URL at all, so the request fails before
        // anything is sent. That failure has to be absorbed exactly
        // like any other transport error.
        let corrector = TabPfnHttpCorrector::with_url("not a url at all");
        let outcome = corrector
            .correct(&CorrectionFeatures::default())
            .expect("never Err");
        assert_eq!(outcome, None);
    }

    #[test]
    fn options_default_is_localhost() {
        let defaults = TabPfnHttpOptions::default();
        assert!(defaults.base_url.starts_with("http://"));
        assert!(defaults.timeout_ms > 0);
        assert_eq!(defaults.ceiling, u64::MAX);
    }
}
1340
#[cfg(all(test, feature = "llm_http"))]
mod llm_http_tests {
    use super::llm::{
        DEFAULT_TIMEOUT_MS, DEFAULT_URL, LlmHttpCorrector, LlmHttpOptions, MAX_TIMEOUT_MS,
    };
    use super::{CorrectionFeatures, Corrector};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Offset just past the blank line that terminates the HTTP header
    /// section, or `None` while the headers are still incomplete.
    fn header_end(received: &[u8]) -> Option<usize> {
        received
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|start| start + 4)
    }

    /// Whether `received` already holds one complete HTTP request:
    /// the full header section plus the body announced by
    /// `Content-Length` (or, for a chunked body, the terminating
    /// zero-length chunk). A request with neither header is treated as
    /// body-less.
    fn request_is_complete(received: &[u8]) -> bool {
        let Some(body_start) = header_end(received) else {
            return false;
        };
        // Header names are case-insensitive, so compare in lowercase.
        let head = String::from_utf8_lossy(&received[..body_start]).to_ascii_lowercase();
        let body = &received[body_start..];

        let chunked = head
            .lines()
            .any(|line| line.starts_with("transfer-encoding:") && line.contains("chunked"));
        if chunked {
            return body.ends_with(b"0\r\n\r\n");
        }

        let declared = head
            .lines()
            .find_map(|line| line.strip_prefix("content-length:"))
            .and_then(|value| value.trim().parse::<usize>().ok())
            .unwrap_or(0);
        body.len() >= declared
    }

    /// Read from `stream` until a whole request has arrived, the peer
    /// closes, or a read fails (including the socket read timeout).
    ///
    /// A single `read` is not enough: headers and body can arrive in
    /// separate segments, and closing the socket while request bytes
    /// are still unread can reset the connection before the client has
    /// read the response.
    fn drain_request(stream: &mut impl Read) {
        let mut received = Vec::with_capacity(512);
        let mut chunk = [0u8; 1024];
        while !request_is_complete(&received) {
            match stream.read(&mut chunk) {
                Ok(0) | Err(_) => break,
                Ok(n) => received.extend_from_slice(&chunk[..n]),
            }
        }
    }

    /// Tiny hand-rolled mock HTTP server. It serves at most
    /// `max_requests` accepted connections, one request each, on a
    /// background thread. We avoid pulling `mockito` (not currently a
    /// dep) and keep the test binary lean.
    ///
    /// For every connection the server drains the full request, then
    /// answers `200 OK` with the bytes produced by `responder`, which
    /// receives the zero-based index of the request. The same
    /// scaffolding therefore serves both success and parse-error cases.
    ///
    /// Returns the endpoint URL and a counter of requests served.
    fn spawn_mock(
        responder: impl Fn(usize) -> Vec<u8> + Send + Sync + 'static,
        max_requests: usize,
    ) -> (String, Arc<AtomicUsize>) {
        // Port 0 lets the OS pick a free port, so tests can run in parallel.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback");
        let port = listener.local_addr().expect("mock listener address").port();
        let url = format!("http://127.0.0.1:{port}/infer");
        let counter = Arc::new(AtomicUsize::new(0));
        let served = Arc::clone(&counter);

        // The responder is moved into the single server thread; no
        // sharing or locking is needed.
        thread::spawn(move || {
            for stream in listener.incoming().take(max_requests) {
                let Ok(mut stream) = stream else { continue };
                // Bound every socket operation so a misbehaving client
                // cannot hang the mock thread.
                let _ = stream.set_read_timeout(Some(Duration::from_secs(2)));
                let _ = stream.set_write_timeout(Some(Duration::from_secs(2)));

                drain_request(&mut stream);

                let idx = served.fetch_add(1, Ordering::SeqCst);
                let body = responder(idx);
                let header = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    body.len()
                );
                let _ = stream.write_all(header.as_bytes());
                let _ = stream.write_all(&body);
                let _ = stream.flush();
            }
        });
        (url, counter)
    }

    /// The mock must not consider a request finished until the body
    /// announced by `Content-Length` has fully arrived.
    #[test]
    fn mock_waits_for_declared_body() {
        let head = b"POST /infer HTTP/1.1\r\nContent-Length: 4\r\n\r\n";
        assert!(!request_is_complete(b"POST /infer HTTP/1.1\r\n"));
        assert!(!request_is_complete(head));

        let mut full = head.to_vec();
        full.extend_from_slice(b"{}{}");
        assert!(request_is_complete(&full));
    }

    /// Loopback port 1 is used as a target that refuses connections on
    /// Linux/macOS. The safety contract says: any transport failure
    /// must surface as `Ok(None)`, never `Err`, never a panic.
    #[test]
    fn http_failure_returns_none_not_error() {
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: 50,
            ceiling: u64::MAX,
        });
        let features = CorrectionFeatures {
            baseline_estimate: 1234,
            ..Default::default()
        };
        let result = corrector.correct(&features);
        assert!(
            result.is_ok(),
            "llm-http transport failure must not propagate as Err; got {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            None,
            "llm-http transport failure must yield Ok(None) so the engine falls back cleanly"
        );
        assert_eq!(corrector.name(), "llm-http");
    }

    #[test]
    fn malformed_url_returns_none() {
        let corrector = LlmHttpCorrector::with_url("not a url at all");
        let features = CorrectionFeatures::default();
        let result = corrector.correct(&features).expect("never Err");
        assert_eq!(result, None);
    }

    #[test]
    fn options_default_is_localhost_on_llm_port() {
        let opts = LlmHttpOptions::default();
        assert_eq!(opts.base_url, DEFAULT_URL);
        assert!(opts.base_url.contains(":8766"));
        assert_eq!(opts.timeout_ms, DEFAULT_TIMEOUT_MS);
        assert_eq!(opts.ceiling, u64::MAX);
    }

    #[test]
    fn timeout_is_saturated_to_max() {
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: "http://127.0.0.1:1/infer".into(),
            timeout_ms: MAX_TIMEOUT_MS * 10,
            ceiling: u64::MAX,
        });
        assert_eq!(corrector.options().timeout_ms, MAX_TIMEOUT_MS);
    }

    #[test]
    fn mock_success_returns_clamped_estimate() {
        let (url, counter) = spawn_mock(|_| br#"{"estimate": 4242}"#.to_vec(), 2);
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: url,
            timeout_ms: 2_000,
            ceiling: 1_000_000,
        });
        let features = CorrectionFeatures {
            baseline_estimate: 1_000,
            ..Default::default()
        };
        let result = corrector.correct(&features).expect("ok");
        assert_eq!(result, Some(4242));
        assert!(counter.load(Ordering::SeqCst) >= 1);
    }

    #[test]
    fn mock_clamps_to_ceiling() {
        let (url, _counter) = spawn_mock(|_| br#"{"estimate": 99999999}"#.to_vec(), 2);
        let corrector = LlmHttpCorrector::new(LlmHttpOptions {
            base_url: url,
            timeout_ms: 2_000,
            ceiling: 500,
        });
        let result = corrector
            .correct(&CorrectionFeatures::default())
            .expect("ok");
        assert_eq!(result, Some(500));
    }

    #[test]
    fn mock_parse_error_returns_none() {
        let (url, _counter) = spawn_mock(|_| b"not json at all".to_vec(), 2);
        let corrector = LlmHttpCorrector::with_url(url);
        let result = corrector
            .correct(&CorrectionFeatures::default())
            .expect("ok");
        assert_eq!(result, None);
    }
}
1493}