Skip to main content

solo_api/llm/
sampling.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingLlmClient`] — `LlmClient` impl backed by an MCP client's
4//! `sampling/createMessage` capability.
5//!
6//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
7//! §6 "Sampling-backed LLM client" / MAJOR 1 + MAJOR 3 resolutions):
8//!
9//!   * Steward holds an `Arc<dyn LlmClient>`. When `LlmConfig::McpSampling`
10//!     is configured, the Steward's `LlmClient` is a `SamplingLlmClient`
11//!     constructed at MCP `initialize` time (when the live peer becomes
12//!     available — the `TenantHandle::steward_slot` LATE-population path).
13//!
14//!   * `SamplingLlmClient::complete()` translates the workspace's
15//!     `Message` → `rmcp::SamplingMessage`, calls
16//!     `peer.create_message(params).await`, extracts the assistant's
17//!     text from the returned `CreateMessageResult`, and emits a
18//!     per-call `AuditOperation::LlmSamplingCall` row through the
19//!     tenant's `WriteHandle` (lesson #30: sync in writer-actor tx
20//!     for ACID).
21//!
//!   * **Privacy invariant**: the audit `details_json` carries metadata
//!     only — model hint, message count, max_tokens, duration_ms, and the
//!     total prompt / output character counts (each rounded up to the next
//!     power of two before persistence). **The raw prompt content MUST NOT
//!     appear in the audit row**. Pinned by
//!     [`tests::audit_row_omits_raw_prompt_text`].
27//!
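//!   * Error paths land structured audit rows:
//!     - Client refusal → `result = "forbidden"`,
//!       `details_json.reason = "client_refused"`. (As written,
//!       [`SamplingError::classify`] only recognizes refusals coming from
//!       the test fixture; every rmcp `ServiceError` is recorded as
//!       `reason = "transport_error"`.)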
31//!     - Timeout → `result = "error"`,
32//!       `details_json.reason = "timeout"`.
33//!     - Other transport / malformed-response → `result = "error"`,
34//!       `details_json.reason = <category>`.
35//!
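//!   * Request coalescing lives in `SamplingCoordinator`, which
//!     [`build_sampling_steward`] wraps around the live peer (v0.9.0 P5).
//!     `SamplingLlmClient` itself issues one `SamplingClient::create_message`
//!     call per `complete()` call.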
38
39use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44    CreateMessageRequestParams, CreateMessageResult, ModelHint,
45    ModelPreferences, Role as RmcpRole, SamplingMessage,
46    SamplingMessageContent,
47};
48use rmcp::service::{Peer, RoleServer, ServiceError};
49use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
50use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
51
52/// Default per-call timeout. Drives the bounded wait around
53/// `peer.create_message`; if the client refuses or stalls, the caller
54/// sees a structured timeout error instead of an indefinite hang.
55///
56/// 30 seconds matches the consolidate-timer's cadence margins: an
57/// LLM call slower than this would already starve the Steward batch
58/// in P4's coordinator. Configurable per-construct via
59/// [`SamplingLlmClient::with_timeout`].
60pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_secs(30);
61
62/// Default max_tokens for sampling completions. Matches
63/// `solo-steward::StewardConfig::default().abstraction_max_tokens`
64/// so the wire shape is identical to what the Steward would have
65/// requested from any other backend.
66const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
67
68/// Error surface for [`SamplingClient::create_message`]. Combines the
69/// real rmcp `ServiceError` (when wrapping a live `Peer<RoleServer>`)
70/// with [`super::super::test_support::fake_mcp_client::FakeSamplingError`]
71/// (when driving the fixture from tests).
72#[derive(Debug)]
73pub enum SamplingError {
74    /// Forwarded from `rmcp::Peer::create_message`.
75    Service(ServiceError),
76    /// Routed from [`super::super::test_support::fake_mcp_client::
77    /// FakeSamplingError`] in test paths.
78    #[cfg(any(test, feature = "test-support"))]
79    Fake(crate::test_support::FakeSamplingError),
80}
81
82impl std::fmt::Display for SamplingError {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        match self {
85            Self::Service(e) => write!(f, "{e}"),
86            #[cfg(any(test, feature = "test-support"))]
87            Self::Fake(e) => write!(f, "{e}"),
88        }
89    }
90}
91
92impl std::error::Error for SamplingError {}
93
94impl SamplingError {
95    /// Classifier used by [`SamplingLlmClient::complete`] to map the
96    /// transport-level error to an audit-row category + a Solo
97    /// [`CoreError`] variant.
98    ///
99    /// `(category_for_audit, treat_as_forbidden)` — `forbidden` becomes
100    /// `AuditResult::Forbidden` + `CoreError::Forbidden`; everything
101    /// else is `AuditResult::Error` + `CoreError::Llm`.
102    pub fn classify(&self) -> (&'static str, bool) {
103        match self {
104            Self::Service(_) => ("transport_error", false),
105            #[cfg(any(test, feature = "test-support"))]
106            Self::Fake(e) => match e {
107                crate::test_support::FakeSamplingError::Refused { .. } => {
108                    ("client_refused", true)
109                }
110                crate::test_support::FakeSamplingError::Transport { .. } => {
111                    ("transport_error", false)
112                }
113                crate::test_support::FakeSamplingError::MalformedResponse {
114                    ..
115                } => ("malformed_response", false),
116            },
117        }
118    }
119}
120
121/// Trait abstracting the `sampling/createMessage` RPC. The production
122/// impl wraps `Arc<Peer<RoleServer>>`; the test impl is
123/// [`super::super::test_support::fake_mcp_client::FakeMcpClient`].
124///
125/// Separating the trait from the concrete `Peer<RoleServer>` is the
126/// way around rmcp's `Peer` having private constructors — we can't
127/// build a fake `Peer` for tests, so we inject behind a trait.
128#[async_trait]
129pub trait SamplingClient: Send + Sync {
130    async fn create_message(
131        &self,
132        params: CreateMessageRequestParams,
133    ) -> Result<CreateMessageResult, SamplingError>;
134}
135
136/// Production wrapper around `rmcp::Peer<RoleServer>`. The Peer is
137/// cheap to clone (internally `Arc`-backed) and stays valid for the
138/// lifetime of the MCP session.
139pub struct PeerSamplingClient {
140    peer: Peer<RoleServer>,
141}
142
143impl PeerSamplingClient {
144    pub fn new(peer: Peer<RoleServer>) -> Self {
145        Self { peer }
146    }
147}
148
149#[async_trait]
150impl SamplingClient for PeerSamplingClient {
151    async fn create_message(
152        &self,
153        params: CreateMessageRequestParams,
154    ) -> Result<CreateMessageResult, SamplingError> {
155        self.peer
156            .create_message(params)
157            .await
158            .map_err(SamplingError::Service)
159    }
160}
161
162/// `LlmClient` impl whose `complete()` calls back via the connected
163/// MCP client's sampling capability.
164///
165/// Construct via [`SamplingLlmClient::new`] (production path: wraps a
166/// real `Peer<RoleServer>`) or [`SamplingLlmClient::with_sampling_client`]
167/// (test path: takes the abstracted [`SamplingClient`] trait object
168/// directly so [`super::super::test_support::fake_mcp_client::
169/// FakeMcpClient`] can drive it).
170///
171/// Cheap to clone — every field is `Arc`-shared.
172#[derive(Clone)]
173pub struct SamplingLlmClient {
174    /// The RPC channel back to the MCP client.
175    sampling_client: Arc<dyn SamplingClient>,
176    /// Per-tenant `WriteHandle` for the synchronous audit emit. Routes
177    /// through the writer-actor's mpsc so the
178    /// `AuditOperation::LlmSamplingCall` INSERT lands in a dedicated
179    /// `BEGIN IMMEDIATE` transaction on the writer's connection.
180    write_handle: WriteHandle,
181    /// Cached audit `principal_subject` for the MCP session. Resolved
182    /// at session init time (see `mcp::resolve_mcp_principal`); `None`
183    /// for unauthenticated stdio sessions.
184    audit_principal: Option<String>,
185    /// `max_tokens` value sent on every `sampling/createMessage`.
186    /// Defaults to [`DEFAULT_SAMPLING_MAX_TOKENS`]; configurable via
187    /// [`Self::with_max_tokens`].
188    max_tokens: u32,
189    /// Bounded wait on `create_message`. See
190    /// [`DEFAULT_SAMPLING_TIMEOUT`].
191    timeout: Duration,
192}
193
194impl SamplingLlmClient {
195    /// Build a client wrapping a real `Peer<RoleServer>`. Production
196    /// path — called from
197    /// [`crate::mcp::SoloMcpServer::populate_sampling_steward`] when an MCP
198    /// session reaches `initialize` with a sampling-capable peer.
199    pub fn new(
200        peer: Peer<RoleServer>,
201        write_handle: WriteHandle,
202        audit_principal: Option<String>,
203    ) -> Self {
204        Self::with_sampling_client(
205            Arc::new(PeerSamplingClient::new(peer)),
206            write_handle,
207            audit_principal,
208        )
209    }
210
211    /// Test-friendly constructor accepting any [`SamplingClient`]
212    /// implementation. Pair with
213    /// [`super::super::test_support::fake_mcp_client::FakeMcpClient`]
214    /// in tests.
215    pub fn with_sampling_client(
216        sampling_client: Arc<dyn SamplingClient>,
217        write_handle: WriteHandle,
218        audit_principal: Option<String>,
219    ) -> Self {
220        Self {
221            sampling_client,
222            write_handle,
223            audit_principal,
224            max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
225            timeout: DEFAULT_SAMPLING_TIMEOUT,
226        }
227    }
228
229    /// Override the per-call `max_tokens` cap.
230    pub fn with_max_tokens(mut self, n: u32) -> Self {
231        self.max_tokens = n.max(1);
232        self
233    }
234
235    /// Override the per-call timeout.
236    pub fn with_timeout(mut self, t: Duration) -> Self {
237        self.timeout = t;
238        self
239    }
240
241    /// Build the `CreateMessageRequestParams` from Solo's `Message`
242    /// vec. Splits out `Role::System` into the `system_prompt` field
243    /// (rmcp's `SamplingMessage::role` is only User / Assistant) and
244    /// hints the user's MCP client toward a Claude-class model.
245    fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
246        // Split system messages out of the conversation history; the
247        // sampling protocol carries the system prompt as a top-level
248        // field rather than inline.
249        let mut system_parts: Vec<String> = Vec::new();
250        let mut samp_messages: Vec<SamplingMessage> = Vec::new();
251        for m in messages {
252            match m.role {
253                Role::System => system_parts.push(m.content.clone()),
254                Role::User => {
255                    samp_messages.push(SamplingMessage::user_text(&m.content));
256                }
257                Role::Assistant => {
258                    samp_messages
259                        .push(SamplingMessage::assistant_text(&m.content));
260                }
261            }
262        }
263        // rmcp 1.7's struct literals are non-exhaustive across crate
264        // boundaries; build via the typed constructors + builders.
265        let preferences = ModelPreferences::new()
266            .with_hints(vec![ModelHint::new("claude")])
267            .with_intelligence_priority(0.7)
268            .with_speed_priority(0.3)
269            .with_cost_priority(0.4);
270        let mut params =
271            CreateMessageRequestParams::new(samp_messages, self.max_tokens)
272                .with_model_preferences(preferences);
273        if !system_parts.is_empty() {
274            params = params.with_system_prompt(system_parts.join("\n\n"));
275        }
276        params
277    }
278
279    /// Build the audit `AuditEvent` carrying ONLY metadata. No raw
280    /// prompt content lands in `details_json`.
281    ///
282    /// Pinned by [`tests::audit_row_omits_raw_prompt_text`].
283    ///
284    /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the raw character count
285    /// of the prompt is itself a side-channel — a 6-char prompt
286    /// uniquely identifies very-short refusal paths (e.g. a leaked
287    /// password length). `prompt_chars` and `input_tokens_est` are
288    /// rounded up to the next power of two before persistence. This
289    /// preserves operator capacity-planning (the bucket is within ~2x
290    /// of the real size for any sufficiently large prompt) while
291    /// removing the per-character precision.
292    ///
293    /// Buckets: `0, 1, 2, 4, 8, 16, 32, 64, ..., 1024, 2048, ...`
294    /// (next power of two `>= n`). 0 stays 0.
295    fn audit_event(
296        &self,
297        params: &CreateMessageRequestParams,
298        outcome: SamplingOutcome,
299    ) -> AuditEvent {
300        let raw_prompt_chars: usize = params
301            .messages
302            .iter()
303            .flat_map(|m| m.content.iter())
304            .filter_map(|c| c.as_text().map(|t| t.text.len()))
305            .sum::<usize>()
306            + params
307                .system_prompt
308                .as_ref()
309                .map(|s| s.len())
310                .unwrap_or(0);
311        // v0.9.1 P1 Fix 4: bucket the raw count to the next power of
312        // two. Pinned by `tests::audit_row_bucket_prompt_chars_to_pow2`.
313        let prompt_chars = next_pow2_bucket(raw_prompt_chars);
314        // ~4 chars per token for the rough English-text estimate used
315        // by `solo doctor --check-llm` and Anthropic's docs. Recorded
316        // for operator capacity-planning. Bucketed for the same
317        // privacy reason — and to stay consistent with `prompt_chars`.
318        let input_tokens_est = next_pow2_bucket(raw_prompt_chars / 4) as u64;
319        let model_hint = params
320            .model_preferences
321            .as_ref()
322            .and_then(|p| p.hints.as_ref())
323            .and_then(|h| h.first())
324            .and_then(|h| h.name.clone())
325            .unwrap_or_else(|| "(none)".to_string());
326
327        let mut details = serde_json::Map::new();
328        details.insert(
329            "model_hint".to_string(),
330            serde_json::Value::String(model_hint),
331        );
332        details.insert(
333            "messages_count".to_string(),
334            serde_json::Value::Number(params.messages.len().into()),
335        );
336        details.insert(
337            "max_tokens".to_string(),
338            serde_json::Value::Number(params.max_tokens.into()),
339        );
340        details.insert(
341            "prompt_chars".to_string(),
342            serde_json::Value::Number(prompt_chars.into()),
343        );
344        details.insert(
345            "input_tokens_est".to_string(),
346            serde_json::Value::Number(input_tokens_est.into()),
347        );
348
349        let result = match &outcome {
350            SamplingOutcome::Ok {
351                duration_ms,
352                model,
353                output_chars,
354            } => {
355                // v0.9.1 P1 Fix 4: same power-of-2 bucketing as
356                // `prompt_chars` for the output side. A model that
357                // always replies with a one-token refusal (e.g. an
358                // assistant trained to say "no.") would otherwise leak
359                // the response-length distribution; bucketing
360                // collapses 1-2-3-4 chars all into bucket 4.
361                let bucketed_output_chars = next_pow2_bucket(*output_chars);
362                let output_tokens_est = next_pow2_bucket(*output_chars / 4) as u64;
363                details.insert(
364                    "duration_ms".to_string(),
365                    serde_json::Value::Number((*duration_ms).into()),
366                );
367                details.insert(
368                    "model".to_string(),
369                    serde_json::Value::String(model.clone()),
370                );
371                details.insert(
372                    "output_chars".to_string(),
373                    serde_json::Value::Number(bucketed_output_chars.into()),
374                );
375                details.insert(
376                    "output_tokens_est".to_string(),
377                    serde_json::Value::Number(output_tokens_est.into()),
378                );
379                AuditResult::Ok
380            }
381            SamplingOutcome::Forbidden {
382                reason,
383                duration_ms,
384            } => {
385                details.insert(
386                    "duration_ms".to_string(),
387                    serde_json::Value::Number((*duration_ms).into()),
388                );
389                details.insert(
390                    "reason".to_string(),
391                    serde_json::Value::String(reason.to_string()),
392                );
393                AuditResult::Forbidden
394            }
395            SamplingOutcome::Error {
396                reason,
397                duration_ms,
398            } => {
399                details.insert(
400                    "duration_ms".to_string(),
401                    serde_json::Value::Number((*duration_ms).into()),
402                );
403                details.insert(
404                    "reason".to_string(),
405                    serde_json::Value::String(reason.to_string()),
406                );
407                AuditResult::Error
408            }
409        };
410
411        AuditEvent {
412            ts_ms: chrono::Utc::now().timestamp_millis(),
413            principal_subject: self.audit_principal.clone(),
414            operation: AuditOperation::LlmSamplingCall,
415            target_id: None,
416            result,
417            details: Some(serde_json::Value::Object(details)),
418        }
419    }
420}
421
422/// Internal outcome category for the audit-row builder.
423enum SamplingOutcome {
424    Ok {
425        duration_ms: u64,
426        model: String,
427        output_chars: usize,
428    },
429    Forbidden {
430        reason: &'static str,
431        duration_ms: u64,
432    },
433    Error {
434        reason: &'static str,
435        duration_ms: u64,
436    },
437}
438
439#[async_trait]
440impl LlmClient for SamplingLlmClient {
441    fn name(&self) -> &str {
442        "mcp-sampling"
443    }
444
445    async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
446        let params = self.build_request(messages);
447        let start = Instant::now();
448
449        // Bounded wait on `peer.create_message`. The fold of (rmcp
450        // ServiceError | FakeError | tokio timeout) into the
451        // `Outcome` enum keeps the audit path single-sourced.
452        let rpc = tokio::time::timeout(
453            self.timeout,
454            self.sampling_client.create_message(params.clone()),
455        )
456        .await;
457        let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX))
458            as u64;
459
460        let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) =
461            match rpc {
462                Ok(Ok(result)) => {
463                    match extract_text(&result) {
464                        Ok(text) => {
465                            let output_chars = text.len();
466                            let outcome = SamplingOutcome::Ok {
467                                duration_ms,
468                                model: result.model.clone(),
469                                output_chars,
470                            };
471                            (Ok(Message::assistant(text)), outcome)
472                        }
473                        Err(reason) => (
474                            Err(CoreError::llm(format!(
475                                "mcp sampling: malformed response: {reason}"
476                            ))),
477                            SamplingOutcome::Error {
478                                reason: "malformed_response",
479                                duration_ms,
480                            },
481                        ),
482                    }
483                }
484                Ok(Err(e)) => {
485                    let (category, is_forbidden) = e.classify();
486                    let outcome = if is_forbidden {
487                        SamplingOutcome::Forbidden {
488                            reason: category,
489                            duration_ms,
490                        }
491                    } else {
492                        SamplingOutcome::Error {
493                            reason: category,
494                            duration_ms,
495                        }
496                    };
497                    let err = if is_forbidden {
498                        CoreError::forbidden(format!("mcp sampling: {e}"))
499                    } else {
500                        CoreError::llm(format!("mcp sampling: {e}"))
501                    };
502                    (Err(err), outcome)
503                }
504                Err(_elapsed) => (
505                    Err(CoreError::llm(format!(
506                        "mcp sampling: timeout after {}ms",
507                        duration_ms
508                    ))),
509                    SamplingOutcome::Error {
510                        reason: "timeout",
511                        duration_ms,
512                    },
513                ),
514            };
515
516        // Synchronous audit emit through the writer-actor (lesson #30).
517        // Failure to land the audit row is operator-visible: the
518        // sampling call's caller sees the storage error and can decide
519        // whether to abort (we DO abort here — without the audit row
520        // we have no record of the call).
521        //
522        // v0.9.1 P1 Fix 3 (F4 Result-shadowing): when BOTH `core_result`
523        // is `Err(..)` AND the audit emit also fails, return the
524        // ORIGINAL LLM-side error (more actionable for callers — they
525        // can retry the LLM call, or decide whether the upstream
526        // refusal/timeout is recoverable). Surface the audit failure
527        // via `tracing::error!` for operator visibility — operators
528        // alarming on storage errors see it; callers see the actionable
529        // error.
530        //
531        // Policy summary:
532        //   * RPC Ok  + audit Ok  → return Ok(text)
533        //   * RPC Ok  + audit Err → return Err(storage) [audit failure
534        //                            wins — no undocumented sampling
535        //                            calls per lesson #30]
536        //   * RPC Err + audit Ok  → return Err(llm/forbidden) [unchanged]
537        //   * RPC Err + audit Err → return Err(llm/forbidden) AND log
538        //                            audit failure at error level
539        //                            [v0.9.1 P1 Fix 3]
540        let event = self.audit_event(&params, outcome);
541        match (
542            core_result,
543            self.write_handle.emit_llm_sampling_audit(event).await,
544        ) {
545            (Ok(text), Ok(())) => Ok(text),
546            (Ok(_text), Err(audit_err)) => {
547                // RPC succeeded but the audit row didn't land. Drop
548                // the success — without a durable audit row we can't
549                // honor the "every sampling call leaves a trace"
550                // invariant.
551                Err(CoreError::storage(format!(
552                    "mcp sampling: audit emit failed: {audit_err}"
553                )))
554            }
555            (Err(core_err), Ok(())) => Err(core_err),
556            (Err(core_err), Err(audit_err)) => {
557                // Both failed. Return the LLM-side error (the caller's
558                // most actionable signal); log the audit failure so an
559                // operator who alarms on storage errors still sees it.
560                tracing::error!(
561                    audit_error = %audit_err,
562                    core_error = %core_err,
563                    "mcp sampling: audit emit failed alongside core \
564                     error; surfacing core error to caller"
565                );
566                Err(core_err)
567            }
568        }
569    }
570}
571
572/// Round `n` up to the next power of two. Used to bucket
573/// `prompt_chars` / `output_chars` / `*_tokens_est` in the
574/// `LlmSamplingCall` audit row's `details_json` (v0.9.1 P1 Fix 4
575/// "F6" — `prompt_chars` was a privacy side-channel for short
576/// prompts).
577///
578/// Buckets: `0 → 0`, `1 → 1`, `2 → 2`, `3 → 4`, `4 → 4`, `5..=8 → 8`,
579/// `9..=16 → 16`, `17..=32 → 32`, ... — within a bucket all values
580/// collapse to the same persisted number. The worst-case fidelity
581/// loss is just under 2x (e.g. 9 chars persists as 16) which is well
582/// within the precision capacity-planning needs.
583///
584/// Pinned by [`tests::next_pow2_bucket_*`] and
585/// [`tests::audit_row_bucket_prompt_chars_to_pow2`].
586fn next_pow2_bucket(n: usize) -> usize {
587    if n == 0 {
588        return 0;
589    }
590    // `next_power_of_two` saturates at `usize::MAX` if `n` is past the
591    // last representable power. For our use (char counts on a Solo
592    // prompt) the absolute upper bound is the LLM model's context
593    // window — well below `usize::MAX` on every Solo-supported target.
594    n.next_power_of_two()
595}
596
597/// Pull the assistant's text out of the rmcp result. Walks every text
598/// content block in the message (the spec allows either a single
599/// `SamplingContent::Single` or a `SamplingContent::Multiple`) and
600/// concatenates them with newlines. Returns `Err(reason)` if no text
601/// blocks were present — the malformed-response path.
602fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
603    if result.message.role != RmcpRole::Assistant {
604        return Err("response role was not Assistant");
605    }
606    let mut out = String::new();
607    for content in result.message.content.iter() {
608        if let SamplingMessageContent::Text(text) = content {
609            if !out.is_empty() {
610                out.push('\n');
611            }
612            out.push_str(&text.text);
613        }
614    }
615    if out.is_empty() {
616        Err("no text content blocks")
617    } else {
618        Ok(out)
619    }
620}
621
622/// v0.9.0 P2: build a sampling-backed `Arc<Steward>` for a tenant that
623/// has resolved `LlmConfig::McpSampling` and just attached an MCP
624/// session.
625///
626/// Called from [`crate::mcp::SoloMcpServer::populate_sampling_steward`] at
627/// MCP `initialize` time once the peer's sampling capability is
628/// confirmed. The returned `Arc<Steward>` is written into
629/// `tenant.steward_slot()` so the writer-actor + consolidate timer
630/// can read a populated slot on their next tick.
631///
632/// v0.9.0 P5 (M3 wiring): the live `PeerSamplingClient` is now wrapped
633/// in a [`super::SamplingCoordinator`] before being handed to
634/// `SamplingLlmClient`. Concurrent `complete()` calls within the
635/// coalesce window collapse into one `peer.create_message` RPC and the
636/// response is demultiplexed back per-task — matching the
637/// `[sampling] coalesce_window_ms` / `coalesce_max_requests` config the
638/// operator wrote in `solo.config.toml`. Per-call audit emit semantics
639/// are unchanged: every logical request still lands one
640/// `AuditOperation::LlmSamplingCall` row, no raw prompt content escapes
641/// to the audit row.
642///
643/// Edge case (clamping): the `[sampling]` block accepts values that
644/// effectively disable batching — `coalesce_max_requests = 1` and / or
645/// `coalesce_window_ms = 0` reduce the coordinator to pass-through (one
646/// inner call per submission). The coordinator's
647/// [`super::SamplingCoordinator::with_settings`] clamps `max_batch` to
648/// `max(1)` so a zero value still produces a single-element flush
649/// immediately rather than panicking or deadlocking.
650pub fn build_sampling_steward(
651    peer: Peer<RoleServer>,
652    write_handle: WriteHandle,
653    audit_principal: Option<String>,
654    steward_config: solo_steward::StewardConfig,
655    sampling_config: solo_storage::SamplingConfig,
656) -> Arc<solo_steward::Steward> {
657    let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
658    let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
659        inner,
660        std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
661        sampling_config.coalesce_max_requests as usize,
662    );
663    let client = SamplingLlmClient::with_sampling_client(
664        coordinator,
665        write_handle,
666        audit_principal,
667    )
668    .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
669    Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
670}
671
672#[cfg(test)]
673mod tests {
674    use super::*;
675    use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
676    use rmcp::model::CreateMessageResult;
677    use solo_core::TenantId;
678    use solo_storage::{
679        EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder,
680        TenantHandle, TenantRegistry, TenantRegistryParams, init,
681        open_sqlcipher,
682    };
683    use std::path::PathBuf;
684    use std::sync::Arc;
685    use tempfile::TempDir;
686    use zeroize::Zeroizing;
687
688    const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
689
690    /// Bootstrap a per-tenant `TenantHandle` whose writer-actor accepts
691    /// the new `WriteCommand::EmitLlmSamplingAudit` variant.
692    ///
693    /// Mirrors the v0.8.x test discipline (see
694    /// `crates/solo-storage/src/tenants/handle_registry_tests.rs`'s
695    /// `fresh_init_dir`): build a real tenant DB on disk via the same
696    /// `init()` helper users invoke, wrap in a `TenantRegistry`, and
697    /// surface the `WriteHandle` for direct `SamplingLlmClient`
698    /// wiring.
699    struct Harness {
700        _tmp: TempDir,
701        _registry: Arc<TenantRegistry>,
702        _tenant: Arc<TenantHandle>,
703        write_handle: solo_storage::WriteHandle,
704        db_path: PathBuf,
705        key: KeyMaterial,
706    }
707
708    async fn harness() -> Harness {
709        let tmp = TempDir::new().expect("tempdir");
710        let data_dir = tmp.path().to_path_buf();
711        let _ = init(InitParams {
712            data_dir: data_dir.clone(),
713            passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
714            force: false,
715            embedder: EmbedderConfig {
716                name: "stub".into(),
717                version: "v1".into(),
718                dim: 32,
719                dtype: "f32".into(),
720            },
721        })
722        .expect("init");
723
724        let cfg = solo_storage::SoloConfig::read(
725            &data_dir.join("solo.config.toml"),
726        )
727        .expect("read cfg");
728        let key = KeyMaterial::derive(
729            TEST_PASSPHRASE,
730            &cfg.salt_bytes().expect("salt"),
731        )
732        .expect("derive key");
733
734        let embedder: Arc<dyn solo_core::Embedder> =
735            Arc::new(StubEmbedder::new("stub", "v1", 32));
736        let registry = Arc::new(
737            TenantRegistry::open(TenantRegistryParams {
738                data_dir: data_dir.clone(),
739                key: key.clone(),
740                embedder: embedder.clone(),
741                hnsw_params: HnswParams::default(),
742                steward: None,
743                runtime_handle: Some(tokio::runtime::Handle::current()),
744                steward_factory: None,
745                triples_batch_signal: None,
746            })
747            .expect("open registry"),
748        );
749
750        let tenant_id = TenantId::default_tenant();
751        let tenant = registry
752            .get_or_open(&tenant_id)
753            .await
754            .expect("get_or_open default tenant");
755        let write_handle = tenant.write().clone();
756        let db_path = tenant.db_path().to_path_buf();
757
758        Harness {
759            _tmp: tmp,
760            _registry: registry,
761            _tenant: tenant,
762            write_handle,
763            db_path,
764            key,
765        }
766    }
767
768    /// Helper: count the `audit_events` rows whose `operation` is the
769    /// given string. Opens a fresh connection to the tenant DB so we
770    /// avoid contention with the writer-actor's own connection.
771    fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
772        let conn = open_sqlcipher(db_path, key).expect("open db");
773        conn.query_row(
774            "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
775            rusqlite::params![op],
776            |r| r.get(0),
777        )
778        .expect("count")
779    }
780
781    /// Helper: load the most-recent `llm.sampling_call` audit row and
782    /// return `(result, principal_subject, details_json)`.
783    fn latest_sampling_audit_details(
784        db_path: &std::path::Path,
785        key: &KeyMaterial,
786    ) -> (String, Option<String>, serde_json::Value) {
787        let conn = open_sqlcipher(db_path, key).expect("open db");
788        let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
789            .query_row(
790                "SELECT result, principal_subject, details_json
791                 FROM audit_events
792                 WHERE operation = 'llm.sampling_call'
793                 ORDER BY ts_ms DESC, rowid DESC
794                 LIMIT 1",
795                [],
796                |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
797            )
798            .expect("query");
799        let details: serde_json::Value =
800            serde_json::from_str(&details_str.expect("details_json present"))
801                .expect("parse details");
802        (result, principal, details)
803    }
804
805    /// Happy path: a successful `create_message` round-trip returns
806    /// the assistant text wrapped in a `Message::assistant`, and lands
807    /// exactly one `llm.sampling_call` audit row with `result = 'ok'`.
808    #[tokio::test]
809    async fn sampling_complete_happy_path_returns_text() {
810        let h = harness().await;
811        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
812        let client = SamplingLlmClient::with_sampling_client(
813            fake.clone(),
814            h.write_handle.clone(),
815            Some("alice".into()),
816        );
817        let messages = vec![Message::user("summarise these episodes")];
818        let result = client.complete(&messages).await.expect("ok");
819        assert_eq!(result.role, Role::Assistant);
820        assert_eq!(result.content, "derived theme");
821
822        // Exactly one audit row landed.
823        assert_eq!(
824            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
825            1
826        );
827        let (result_str, principal, details) =
828            latest_sampling_audit_details(&h.db_path, &h.key);
829        assert_eq!(result_str, "ok");
830        assert_eq!(principal.as_deref(), Some("alice"));
831        assert_eq!(details["model_hint"], "claude");
832        assert_eq!(details["model"], "fake-claude");
833        assert_eq!(details["messages_count"], 1);
834        assert_eq!(details["max_tokens"], 512);
835    }
836
837    /// Privacy invariant: the audit row's `details_json` MUST NOT
838    /// contain the raw prompt content. Pinned by string inspection of
839    /// the persisted JSON.
840    #[tokio::test]
841    async fn audit_row_omits_raw_prompt_text() {
842        let h = harness().await;
843        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
844        let client = SamplingLlmClient::with_sampling_client(
845            fake,
846            h.write_handle.clone(),
847            None,
848        );
849        let secret = "THE-USER-ID-IS-bobby-1234";
850        let messages = vec![
851            Message::system("you are a friendly assistant"),
852            Message::user(secret),
853        ];
854        client.complete(&messages).await.expect("ok");
855
856        let (_, _, details) =
857            latest_sampling_audit_details(&h.db_path, &h.key);
858        let serialised =
859            serde_json::to_string(&details).expect("serialise details");
860        assert!(
861            !serialised.contains(secret),
862            "audit details must not carry raw prompt content; was: {serialised}"
863        );
864        assert!(
865            !serialised.contains("you are a friendly assistant"),
866            "audit details must not carry system prompt; was: {serialised}"
867        );
868        // Metadata IS present, even though the prompt is not.
869        assert_eq!(details["messages_count"], 1);
870        assert!(details["prompt_chars"].as_u64().unwrap() > 0);
871    }
872
873    /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the audit row's
874    /// `prompt_chars` MUST be the power-of-2 bucket, never the raw
875    /// character count. Pins the bucketing behavior end-to-end (raw
876    /// `audit_event` → SQLite → re-read).
877    ///
878    /// Test recipe: drive a prompt with a known raw length (6 chars
879    /// total, `"hello "` system + `"x"` user → 6+1 = 7) and assert the
880    /// audit row carries `8` (next pow2 ≥ 7), not 7.
881    #[tokio::test]
882    async fn audit_row_bucket_prompt_chars_to_pow2() {
883        let h = harness().await;
884        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
885        let client = SamplingLlmClient::with_sampling_client(
886            fake,
887            h.write_handle.clone(),
888            None,
889        );
890        // System: 6 chars + user: 1 char = 7 chars raw → bucket 8.
891        client
892            .complete(&[Message::system("hello "), Message::user("x")])
893            .await
894            .expect("ok");
895        let (_, _, details) =
896            latest_sampling_audit_details(&h.db_path, &h.key);
897        assert_eq!(
898            details["prompt_chars"].as_u64().unwrap(),
899            8,
900            "prompt_chars must be bucketed to next pow2 (7 → 8). \
901             raw count is a privacy side-channel; see Fix 4 F6 in \
902             v0.9.1 P1 dev log. got details={details}"
903        );
904    }
905
906    /// Stability invariant: two prompts that fall in the SAME bucket
907    /// must persist identical `prompt_chars`. Distinguishes "the
908    /// implementation buckets" from "the implementation hashes/leaks
909    /// raw values".
910    ///
911    /// 5 chars and 7 chars both round to 8 → must persist identically.
912    /// (Mirrors the brief's "test that bucketed values are stable
913    /// across exact-character variations within the same bucket".)
914    #[tokio::test]
915    async fn audit_row_bucket_prompt_chars_is_stable_within_bucket() {
916        let h = harness().await;
917        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
918        let client = SamplingLlmClient::with_sampling_client(
919            fake,
920            h.write_handle.clone(),
921            None,
922        );
923        // 5 chars raw → bucket 8.
924        client
925            .complete(&[Message::user("hello")])
926            .await
927            .expect("ok");
928        let (_, _, details_5) =
929            latest_sampling_audit_details(&h.db_path, &h.key);
930        // 7 chars raw → bucket 8.
931        client
932            .complete(&[Message::user("hellooo")])
933            .await
934            .expect("ok");
935        let (_, _, details_7) =
936            latest_sampling_audit_details(&h.db_path, &h.key);
937        assert_eq!(
938            details_5["prompt_chars"], details_7["prompt_chars"],
939            "5 chars and 7 chars must hash to the same bucket (8) — \
940             otherwise the bucketing is leaking raw fidelity. \
941             5-char details: {details_5}, 7-char details: {details_7}"
942        );
943        assert_eq!(details_5["prompt_chars"].as_u64().unwrap(), 8);
944    }
945
946    /// Unit-level pins for the bucketing helper. Catches a regression
947    /// where someone "simplifies" `next_pow2_bucket` into a no-op.
948    #[test]
949    fn next_pow2_bucket_table() {
950        assert_eq!(next_pow2_bucket(0), 0, "0 stays 0");
951        assert_eq!(next_pow2_bucket(1), 1, "1 stays 1");
952        assert_eq!(next_pow2_bucket(2), 2, "2 stays 2");
953        assert_eq!(next_pow2_bucket(3), 4, "3 rounds up to 4");
954        assert_eq!(next_pow2_bucket(4), 4, "4 stays 4");
955        assert_eq!(next_pow2_bucket(5), 8);
956        assert_eq!(next_pow2_bucket(6), 8, "6-char prompt (brief case) → 8");
957        assert_eq!(next_pow2_bucket(7), 8);
958        assert_eq!(next_pow2_bucket(8), 8);
959        assert_eq!(next_pow2_bucket(9), 16);
960        assert_eq!(next_pow2_bucket(1023), 1024);
961        assert_eq!(next_pow2_bucket(1024), 1024);
962        assert_eq!(next_pow2_bucket(1025), 2048);
963    }
964
965    /// Client refusal: maps to `CoreError::Forbidden` + audit row
966    /// `result = 'forbidden'` + `details_json.reason = 'client_refused'`.
967    #[tokio::test]
968    async fn client_refusal_returns_forbidden_and_audits() {
969        let h = harness().await;
970        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
971        fake.reject_with("user dismissed approval");
972        let client = SamplingLlmClient::with_sampling_client(
973            fake,
974            h.write_handle.clone(),
975            Some("alice".into()),
976        );
977        let err = client
978            .complete(&[Message::user("anything")])
979            .await
980            .unwrap_err();
981        match err {
982            CoreError::Forbidden(_) => {}
983            other => panic!("expected Forbidden, got {other:?}"),
984        }
985        let (result_str, _, details) =
986            latest_sampling_audit_details(&h.db_path, &h.key);
987        assert_eq!(result_str, "forbidden");
988        assert_eq!(details["reason"], "client_refused");
989    }
990
991    /// Timeout: tokio::time::timeout fires before the fake's `Slow`
992    /// response resolves; client returns `CoreError::Llm` + audit row
993    /// `result = 'error'` + `details_json.reason = 'timeout'`.
994    ///
995    /// Real wall-clock: 80ms slow response vs 30ms client timeout.
996    /// Margin is loose enough for slow CI without making the test
997    /// drag.
998    #[tokio::test]
999    async fn timeout_returns_error_with_timeout_reason() {
1000        let h = harness().await;
1001        let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
1002            "late",
1003            Duration::from_millis(800),
1004        )));
1005        let client = SamplingLlmClient::with_sampling_client(
1006            fake,
1007            h.write_handle.clone(),
1008            None,
1009        )
1010        .with_timeout(Duration::from_millis(30));
1011        let err = client
1012            .complete(&[Message::user("hello")])
1013            .await
1014            .unwrap_err();
1015        match err {
1016            CoreError::Llm(msg) => assert!(msg.contains("timeout")),
1017            other => panic!("expected Llm, got {other:?}"),
1018        }
1019        let (result_str, _, details) =
1020            latest_sampling_audit_details(&h.db_path, &h.key);
1021        assert_eq!(result_str, "error");
1022        assert_eq!(details["reason"], "timeout");
1023    }
1024
1025    /// Malformed response: the fake returns a result with zero text
1026    /// content blocks; client surfaces `CoreError::Llm` + audit row
1027    /// `result = 'error'` + `details_json.reason = 'malformed_response'`.
1028    #[tokio::test]
1029    async fn malformed_response_returns_error_with_reason() {
1030        let h = harness().await;
1031        let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
1032        let client = SamplingLlmClient::with_sampling_client(
1033            fake,
1034            h.write_handle.clone(),
1035            None,
1036        );
1037        let err = client
1038            .complete(&[Message::user("hi")])
1039            .await
1040            .unwrap_err();
1041        assert!(matches!(err, CoreError::Llm(_)));
1042        let (result_str, _, details) =
1043            latest_sampling_audit_details(&h.db_path, &h.key);
1044        assert_eq!(result_str, "error");
1045        assert_eq!(details["reason"], "malformed_response");
1046    }
1047
1048    /// `principal_subject = None` works — audit row still emits with
1049    /// NULL.
1050    #[tokio::test]
1051    async fn no_principal_emits_audit_with_null_principal() {
1052        let h = harness().await;
1053        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1054        let client = SamplingLlmClient::with_sampling_client(
1055            fake,
1056            h.write_handle.clone(),
1057            None,
1058        );
1059        client.complete(&[Message::user("hi")]).await.expect("ok");
1060        let (_, principal, _) =
1061            latest_sampling_audit_details(&h.db_path, &h.key);
1062        assert_eq!(principal, None);
1063    }
1064
1065    /// Concurrency: 8 parallel `complete()` calls land 8 audit rows.
1066    /// Audit IDs (autoincrement rowid) must be distinct — verifies the
1067    /// writer-actor serialises the per-call audit emit (no
1068    /// interleaving / dropped rows).
1069    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1070    async fn parallel_completes_serialise_audit_rows() {
1071        let h = harness().await;
1072        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1073        let client = SamplingLlmClient::with_sampling_client(
1074            fake.clone(),
1075            h.write_handle.clone(),
1076            Some("alice".into()),
1077        );
1078        let mut futs = Vec::new();
1079        for _ in 0..8 {
1080            let c = client.clone();
1081            futs.push(tokio::spawn(async move {
1082                c.complete(&[Message::user("hi")]).await
1083            }));
1084        }
1085        for f in futs {
1086            f.await.expect("join").expect("ok");
1087        }
1088        assert_eq!(
1089            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1090            8,
1091            "8 parallel calls must land 8 audit rows"
1092        );
1093
1094        // Each was a separate request to the fake.
1095        assert_eq!(fake.record_requests().len(), 8);
1096    }
1097
1098    /// `complete` translates the workspace's `Message::system` into the
1099    /// `system_prompt` top-level field; user/assistant roles map to
1100    /// rmcp's `SamplingMessage::user_text` / `assistant_text`.
1101    #[tokio::test]
1102    async fn build_request_splits_system_from_messages() {
1103        let h = harness().await;
1104        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1105        let client = SamplingLlmClient::with_sampling_client(
1106            fake.clone(),
1107            h.write_handle.clone(),
1108            None,
1109        );
1110        client
1111            .complete(&[
1112                Message::system("be terse"),
1113                Message::user("question"),
1114                Message::assistant("answer"),
1115            ])
1116            .await
1117            .expect("ok");
1118        let recorded = fake.record_requests();
1119        assert_eq!(recorded.len(), 1);
1120        let req = &recorded[0];
1121        assert_eq!(
1122            req.system_prompt.as_deref(),
1123            Some("be terse"),
1124            "Role::System must map to system_prompt"
1125        );
1126        assert_eq!(req.messages.len(), 2);
1127        // The remaining two messages are the user + assistant turns.
1128        assert_eq!(req.messages[0].role, RmcpRole::User);
1129        assert_eq!(req.messages[1].role, RmcpRole::Assistant);
1130    }
1131
1132    /// `model_preferences` carries the `claude` hint per plan §6.
1133    /// Pins the wire shape so a future change is a conscious decision.
1134    #[tokio::test]
1135    async fn build_request_includes_claude_model_hint() {
1136        let h = harness().await;
1137        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1138        let client = SamplingLlmClient::with_sampling_client(
1139            fake.clone(),
1140            h.write_handle.clone(),
1141            None,
1142        );
1143        client
1144            .complete(&[Message::user("hi")])
1145            .await
1146            .expect("ok");
1147        let recorded = fake.record_requests();
1148        let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
1149        let hint = prefs
1150            .hints
1151            .as_ref()
1152            .and_then(|h| h.first())
1153            .and_then(|h| h.name.clone())
1154            .expect("hint name");
1155        assert_eq!(hint, "claude");
1156    }
1157
1158    /// `with_max_tokens(n)` propagates to the request's
1159    /// `max_tokens` field.
1160    #[tokio::test]
1161    async fn with_max_tokens_overrides_default() {
1162        let h = harness().await;
1163        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1164        let client = SamplingLlmClient::with_sampling_client(
1165            fake.clone(),
1166            h.write_handle.clone(),
1167            None,
1168        )
1169        .with_max_tokens(2048);
1170        client
1171            .complete(&[Message::user("hi")])
1172            .await
1173            .expect("ok");
1174        let recorded = fake.record_requests();
1175        assert_eq!(recorded[0].max_tokens, 2048);
1176    }
1177
1178    /// Reconfiguring the fake mid-test produces distinct audit rows
1179    /// for each call (positive then negative).
1180    #[tokio::test]
1181    async fn reconfigurable_fake_distinguishes_audit_rows() {
1182        let h = harness().await;
1183        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1184        let client = SamplingLlmClient::with_sampling_client(
1185            fake.clone(),
1186            h.write_handle.clone(),
1187            Some("alice".into()),
1188        );
1189
1190        client.complete(&[Message::user("a")]).await.expect("ok");
1191        fake.reject_with("user said no");
1192        let _ = client.complete(&[Message::user("b")]).await;
1193
1194        let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1195        let mut stmt = conn
1196            .prepare(
1197                "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1198            )
1199            .expect("prepare");
1200        let rows: Vec<String> = stmt
1201            .query_map([], |r| r.get::<_, String>(0))
1202            .expect("query")
1203            .map(|r| r.expect("row"))
1204            .collect();
1205        assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1206    }
1207
1208    /// `extract_text` walks single-block content.
1209    #[test]
1210    fn extract_text_pulls_text_from_single_block() {
1211        let result = CreateMessageResult::new(
1212            SamplingMessage::assistant_text("hello"),
1213            "fake".into(),
1214        );
1215        assert_eq!(extract_text(&result).unwrap(), "hello");
1216    }
1217
1218    /// `extract_text` rejects an empty-content response.
1219    #[test]
1220    fn extract_text_rejects_empty_content() {
1221        let result = CreateMessageResult::new(
1222            SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new()),
1223            "fake".into(),
1224        );
1225        assert!(extract_text(&result).is_err());
1226    }
1227
1228    /// `extract_text` rejects a User-role response (impossible per
1229    /// spec but pinning the defensive check).
1230    #[test]
1231    fn extract_text_rejects_non_assistant_role() {
1232        let result = CreateMessageResult::new(
1233            SamplingMessage::user_text("hello"),
1234            "fake".into(),
1235        );
1236        assert!(extract_text(&result).is_err());
1237    }
1238
1239    /// `SamplingError::classify` maps each fake variant to the right
1240    /// audit category.
1241    #[test]
1242    fn sampling_error_classify_maps_fake_variants() {
1243        let refused = SamplingError::Fake(FakeSamplingError::Refused {
1244            reason: "x".into(),
1245        });
1246        let (cat, forb) = refused.classify();
1247        assert_eq!(cat, "client_refused");
1248        assert!(forb);
1249
1250        let transport = SamplingError::Fake(FakeSamplingError::Transport {
1251            message: "x".into(),
1252        });
1253        let (cat, forb) = transport.classify();
1254        assert_eq!(cat, "transport_error");
1255        assert!(!forb);
1256
1257        let malformed =
1258            SamplingError::Fake(FakeSamplingError::MalformedResponse {
1259                message: "x".into(),
1260            });
1261        let (cat, forb) = malformed.classify();
1262        assert_eq!(cat, "malformed_response");
1263        assert!(!forb);
1264    }
1265
1266    // -------- v0.9.0 P5a (M3 wiring) — SamplingCoordinator integration --------
1267    //
1268    // These tests pin the contract that `build_sampling_steward` wraps the
1269    // live peer in a `SamplingCoordinator` before handing it to
1270    // `SamplingLlmClient`. They cannot call `build_sampling_steward`
1271    // directly (it takes a real `Peer<RoleServer>` whose constructors are
1272    // private inside rmcp), but they exercise the **exact same wiring
1273    // shape** by substituting `FakeMcpClient` for `PeerSamplingClient`.
1274    // The production code path is:
1275    //
1276    //     PeerSamplingClient -> SamplingCoordinator -> SamplingLlmClient
1277    //
1278    // The tested shape is:
1279    //
1280    //     FakeMcpClient      -> SamplingCoordinator -> SamplingLlmClient
1281    //
1282    // Only the leaf `SamplingClient` impl differs; the
1283    // `SamplingClient` trait is the same Arc-of-dyn in both paths.
1284
1285    /// SamplingCoordinator wrapping a `FakeMcpClient` and feeding
1286    /// `SamplingLlmClient::with_sampling_client` is the same Arc-of-dyn
1287    /// shape `build_sampling_steward` constructs at MCP-initialize
1288    /// time. Single-element flushes pass through unwrapped, so a lone
1289    /// `complete()` call still emits one audit row and produces the
1290    /// expected text.
1291    #[tokio::test]
1292    async fn sampling_llm_client_uses_coordinator_in_production_path() {
1293        let h = harness().await;
1294        let fake: Arc<dyn SamplingClient> =
1295            Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1296        let coord: Arc<dyn SamplingClient> =
1297            super::super::SamplingCoordinator::with_settings(
1298                fake.clone(),
1299                Duration::from_millis(50),
1300                10,
1301            );
1302        let client = SamplingLlmClient::with_sampling_client(
1303            coord,
1304            h.write_handle.clone(),
1305            Some("alice".into()),
1306        );
1307        let result = client
1308            .complete(&[Message::user("test")])
1309            .await
1310            .expect("ok");
1311        assert_eq!(result.role, Role::Assistant);
1312        assert_eq!(result.content, "ok");
1313        // Single audit row landed — per-call audit semantics
1314        // unchanged by the coordinator wrap.
1315        assert_eq!(
1316            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1317            1,
1318            "one logical call → one audit row, even through coordinator"
1319        );
1320    }
1321
1322    /// End-to-end batching pin: N concurrent `complete()` calls within
1323    /// the coalesce window resolve as ONE inner `create_message` RPC
1324    /// on the underlying `FakeMcpClient`, but N audit rows still land
1325    /// (one per logical call — the privacy + audit invariants from P2
1326    /// hold).
1327    ///
1328    /// This is the v0.9.0 release notes' "⌈N/M⌉ peer.create_message
1329    /// calls per coalesce window" claim, exercised through the same
1330    /// trait-object chain that `build_sampling_steward` constructs in
1331    /// production.
1332    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1333    async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1334        // Coalesced JSON response for 5 tasks — matches the
1335        // `[{task_index, response}]` shape `flush_batch` demuxes
1336        // multi-element batches into.
1337        let response = serde_json::to_string(&(0..5)
1338            .map(|i| serde_json::json!({
1339                "task_index": i,
1340                "response": format!("response-{i}"),
1341            }))
1342            .collect::<Vec<_>>())
1343            .unwrap();
1344
1345        let h = harness().await;
1346        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1347        let coord: Arc<dyn SamplingClient> =
1348            super::super::SamplingCoordinator::with_settings(
1349                fake.clone(),
1350                // Wide window so all 5 submissions land in one batch.
1351                Duration::from_secs(5),
1352                10,
1353            );
1354        let client = SamplingLlmClient::with_sampling_client(
1355            coord,
1356            h.write_handle.clone(),
1357            Some("alice".into()),
1358        );
1359
1360        // Fire 5 concurrent `complete()` calls; the coordinator should
1361        // coalesce them into ONE `FakeMcpClient::create_message` call.
1362        let mut futs = Vec::new();
1363        for i in 0..5 {
1364            let c = client.clone();
1365            futs.push(tokio::spawn(async move {
1366                c.complete(&[Message::user(format!("task-{i}"))]).await
1367            }));
1368        }
1369        for f in futs {
1370            f.await.expect("join").expect("ok");
1371        }
1372
1373        // EXACTLY one inner RPC.
1374        assert_eq!(
1375            fake.record_requests().len(),
1376            1,
1377            "5 logical calls within window must coalesce to 1 inner RPC"
1378        );
1379        // BUT 5 audit rows — per-logical-call audit invariant preserved.
1380        assert_eq!(
1381            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1382            5,
1383            "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1384        );
1385    }
1386
1387    /// Edge case: `coalesce_max_requests = 1` reduces the coordinator
1388    /// to pass-through (each submit flushes a 1-element batch
1389    /// immediately). With max_batch=1 and a wide window, 3 concurrent
1390    /// calls land 3 inner RPCs — coordinator is operating as if no
1391    /// batching were configured.
1392    ///
1393    /// Pins the brief's documented edge-case: zero / one-valued config
1394    /// reduces to pass-through, never panics or deadlocks. Mirrors
1395    /// `SamplingCoordinator::with_settings`'s `max_batch.max(1)`
1396    /// clamping for the `coalesce_max_requests = 0` case.
1397    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1398    async fn coordinator_max_batch_one_acts_as_passthrough() {
1399        let h = harness().await;
1400        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1401        let coord: Arc<dyn SamplingClient> =
1402            super::super::SamplingCoordinator::with_settings(
1403                fake.clone(),
1404                Duration::from_secs(5),
1405                // max_batch=1 → every submission flushes immediately as
1406                // a 1-element batch; pass-through behaviour.
1407                1,
1408            );
1409        let client = SamplingLlmClient::with_sampling_client(
1410            coord,
1411            h.write_handle.clone(),
1412            None,
1413        );
1414        let mut futs = Vec::new();
1415        for _ in 0..3 {
1416            let c = client.clone();
1417            futs.push(tokio::spawn(async move {
1418                c.complete(&[Message::user("hi")]).await
1419            }));
1420        }
1421        for f in futs {
1422            f.await.expect("join").expect("ok");
1423        }
1424        // 3 logical calls → 3 inner RPCs (no coalescing).
1425        assert_eq!(
1426            fake.record_requests().len(),
1427            3,
1428            "max_batch=1 must pass through every submission as its own RPC"
1429        );
1430        assert_eq!(
1431            count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1432            3
1433        );
1434    }
1435}