solo_api/llm/sampling.rs
1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingLlmClient`] — `LlmClient` impl backed by an MCP client's
4//! `sampling/createMessage` capability.
5//!
6//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
7//! §6 "Sampling-backed LLM client" / MAJOR 1 + MAJOR 3 resolutions):
8//!
9//! * Steward holds an `Arc<dyn LlmClient>`. When `LlmConfig::McpSampling`
10//! is configured, the Steward's `LlmClient` is a `SamplingLlmClient`
11//! constructed at MCP `initialize` time (when the live peer becomes
12//! available — the `TenantHandle::steward_slot` LATE-population path).
13//!
14//! * `SamplingLlmClient::complete()` translates the workspace's
15//! `Message` → `rmcp::SamplingMessage`, calls
16//! `peer.create_message(params).await`, extracts the assistant's
17//! text from the returned `CreateMessageResult`, and emits a
18//! per-call `AuditOperation::LlmSamplingCall` row through the
19//! tenant's `WriteHandle` (lesson #30: sync in writer-actor tx
20//! for ACID).
21//!
22//! * **Privacy invariant**: the audit `details_json` carries metadata
//!   only — model hint, message count, max_tokens, duration_ms, and
//!   power-of-two-bucketed prompt / output sizes (never the exact
//!   lengths; see `next_pow2_bucket`). **The raw
25//! prompt content MUST NOT appear in the audit row**. Pinned by
26//! [`tests::audit_row_omits_raw_prompt_text`].
27//!
28//! * Error paths land structured audit rows:
29//! - Client refusal → `result = "forbidden"`,
30//! `details_json.reason = "client_refused"`.
31//! - Timeout → `result = "error"`,
32//! `details_json.reason = "timeout"`.
33//! - Other transport / malformed-response → `result = "error"`,
34//! `details_json.reason = <category>`.
35//!
//! * Coalescing of concurrent calls is handled by `SamplingCoordinator`,
//!   which `build_sampling_steward` wraps around the peer client
//!   (v0.9.0 P5). `SamplingLlmClient` itself only implements the
//!   per-call path.
38
39use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44 CreateMessageRequestParams, CreateMessageResult, ModelHint, ModelPreferences, Role as RmcpRole,
45 SamplingMessage, SamplingMessageContent,
46};
47use rmcp::service::{Peer, RoleServer, ServiceError};
48use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
49use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
50
/// Default per-call timeout. Drives the bounded wait around
/// `peer.create_message`; if the client stalls past this deadline, the
/// caller sees a structured timeout error instead of an indefinite
/// hang. (An explicit refusal is reported by the RPC itself as an error
/// and is handled on a separate path; it does not wait for the
/// deadline.)
///
/// 30 seconds matches the consolidate-timer's cadence margins: an
/// LLM call slower than this would already starve the Steward batch
/// in P4's coordinator. Configurable per-construct via
/// [`SamplingLlmClient::with_timeout`].
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_secs(30);
60
/// Default `max_tokens` for sampling completions. Matches
/// `solo-steward::StewardConfig::default().abstraction_max_tokens`
/// so the wire shape is identical to what the Steward would have
/// requested from any other backend.
///
/// Used by [`SamplingLlmClient::with_sampling_client`] as the initial
/// value; override it per client with
/// [`SamplingLlmClient::with_max_tokens`].
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
66
/// Error surface for [`SamplingClient::create_message`]. Combines the
/// real rmcp `ServiceError` (when wrapping a live `Peer<RoleServer>`)
/// with `crate::test_support::FakeSamplingError` (when driving the
/// fixture from tests).
///
/// The `Fake` variant only exists when compiling tests or with the
/// `test-support` feature enabled, so production builds carry the
/// `Service` variant alone.
#[derive(Debug)]
pub enum SamplingError {
    /// Forwarded from `rmcp::Peer::create_message`.
    Service(ServiceError),
    /// Routed from `crate::test_support::FakeSamplingError` in test
    /// paths.
    #[cfg(any(test, feature = "test-support"))]
    Fake(crate::test_support::FakeSamplingError),
}
80
81impl std::fmt::Display for SamplingError {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 Self::Service(e) => write!(f, "{e}"),
85 #[cfg(any(test, feature = "test-support"))]
86 Self::Fake(e) => write!(f, "{e}"),
87 }
88 }
89}
90
// No `source()` override: the `Display` impl already forwards the
// wrapped error's message, so the default (no source) is used.
impl std::error::Error for SamplingError {}
92
93impl SamplingError {
94 /// Classifier used by [`SamplingLlmClient::complete`] to map the
95 /// transport-level error to an audit-row category + a Solo
96 /// [`CoreError`] variant.
97 ///
98 /// `(category_for_audit, treat_as_forbidden)` — `forbidden` becomes
99 /// `AuditResult::Forbidden` + `CoreError::Forbidden`; everything
100 /// else is `AuditResult::Error` + `CoreError::Llm`.
101 pub fn classify(&self) -> (&'static str, bool) {
102 match self {
103 Self::Service(_) => ("transport_error", false),
104 #[cfg(any(test, feature = "test-support"))]
105 Self::Fake(e) => match e {
106 crate::test_support::FakeSamplingError::Refused { .. } => ("client_refused", true),
107 crate::test_support::FakeSamplingError::Transport { .. } => {
108 ("transport_error", false)
109 }
110 crate::test_support::FakeSamplingError::MalformedResponse { .. } => {
111 ("malformed_response", false)
112 }
113 },
114 }
115 }
116}
117
/// Trait abstracting the `sampling/createMessage` RPC. The production
/// impl wraps `Peer<RoleServer>` (see [`PeerSamplingClient`]); the test
/// impl is `crate::test_support::FakeMcpClient`.
///
/// Separating the trait from the concrete `Peer<RoleServer>` is the
/// way around rmcp's `Peer` having private constructors — we can't
/// build a fake `Peer` for tests, so we inject behind a trait.
///
/// Implementations must be `Send + Sync` because the client is shared
/// as `Arc<dyn SamplingClient>`.
#[async_trait]
pub trait SamplingClient: Send + Sync {
    /// Issue one `sampling/createMessage` request with `params` and
    /// return the client's reply, or a [`SamplingError`] describing why
    /// the call failed.
    async fn create_message(
        &self,
        params: CreateMessageRequestParams,
    ) -> Result<CreateMessageResult, SamplingError>;
}
132
/// Production wrapper around `rmcp::Peer<RoleServer>`. The Peer is
/// cheap to clone (internally `Arc`-backed) and stays valid for the
/// lifetime of the MCP session.
///
/// Exists so the live peer can be used through the [`SamplingClient`]
/// trait alongside the test fake.
pub struct PeerSamplingClient {
    /// Handle to the connected MCP client that receives the sampling
    /// requests.
    peer: Peer<RoleServer>,
}
139
impl PeerSamplingClient {
    /// Wrap a live server-side peer handle so it can be driven through
    /// the [`SamplingClient`] trait.
    pub fn new(peer: Peer<RoleServer>) -> Self {
        Self { peer }
    }
}
145
146#[async_trait]
147impl SamplingClient for PeerSamplingClient {
148 async fn create_message(
149 &self,
150 params: CreateMessageRequestParams,
151 ) -> Result<CreateMessageResult, SamplingError> {
152 self.peer
153 .create_message(params)
154 .await
155 .map_err(SamplingError::Service)
156 }
157}
158
/// `LlmClient` impl whose `complete()` calls back via the connected
/// MCP client's sampling capability.
///
/// Construct via [`SamplingLlmClient::new`] (wraps a real
/// `Peer<RoleServer>`) or [`SamplingLlmClient::with_sampling_client`]
/// (takes the abstracted [`SamplingClient`] trait object directly, so
/// the test fake `crate::test_support::FakeMcpClient` — or any other
/// wrapper — can drive it).
///
/// The struct derives `Clone`; the RPC channel is shared through an
/// `Arc`, the remaining fields are cloned by value.
#[derive(Clone)]
pub struct SamplingLlmClient {
    /// The RPC channel back to the MCP client.
    sampling_client: Arc<dyn SamplingClient>,
    /// Per-tenant `WriteHandle` for the synchronous audit emit. Routes
    /// through the writer-actor's mpsc so the
    /// `AuditOperation::LlmSamplingCall` INSERT lands in a dedicated
    /// `BEGIN IMMEDIATE` transaction on the writer's connection.
    write_handle: WriteHandle,
    /// Cached audit `principal_subject` for the MCP session. Resolved
    /// at session init time (see `mcp::resolve_mcp_principal`); `None`
    /// for unauthenticated stdio sessions. Copied verbatim into every
    /// audit event this client emits.
    audit_principal: Option<String>,
    /// `max_tokens` value sent on every `sampling/createMessage`.
    /// Defaults to [`DEFAULT_SAMPLING_MAX_TOKENS`]; configurable via
    /// [`Self::with_max_tokens`].
    max_tokens: u32,
    /// Bounded wait on `create_message`. See
    /// [`DEFAULT_SAMPLING_TIMEOUT`]; configurable via
    /// [`Self::with_timeout`].
    timeout: Duration,
}
190
191impl SamplingLlmClient {
192 /// Build a client wrapping a real `Peer<RoleServer>`. Production
193 /// path — called from
194 /// [`crate::mcp::SoloMcpServer::populate_sampling_steward`] when an MCP
195 /// session reaches `initialize` with a sampling-capable peer.
196 pub fn new(
197 peer: Peer<RoleServer>,
198 write_handle: WriteHandle,
199 audit_principal: Option<String>,
200 ) -> Self {
201 Self::with_sampling_client(
202 Arc::new(PeerSamplingClient::new(peer)),
203 write_handle,
204 audit_principal,
205 )
206 }
207
208 /// Test-friendly constructor accepting any [`SamplingClient`]
209 /// implementation. Pair with
210 /// [`super::super::test_support::fake_mcp_client::FakeMcpClient`]
211 /// in tests.
212 pub fn with_sampling_client(
213 sampling_client: Arc<dyn SamplingClient>,
214 write_handle: WriteHandle,
215 audit_principal: Option<String>,
216 ) -> Self {
217 Self {
218 sampling_client,
219 write_handle,
220 audit_principal,
221 max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
222 timeout: DEFAULT_SAMPLING_TIMEOUT,
223 }
224 }
225
226 /// Override the per-call `max_tokens` cap.
227 pub fn with_max_tokens(mut self, n: u32) -> Self {
228 self.max_tokens = n.max(1);
229 self
230 }
231
232 /// Override the per-call timeout.
233 pub fn with_timeout(mut self, t: Duration) -> Self {
234 self.timeout = t;
235 self
236 }
237
238 /// Build the `CreateMessageRequestParams` from Solo's `Message`
239 /// vec. Splits out `Role::System` into the `system_prompt` field
240 /// (rmcp's `SamplingMessage::role` is only User / Assistant) and
241 /// hints the user's MCP client toward a Claude-class model.
242 fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
243 // Split system messages out of the conversation history; the
244 // sampling protocol carries the system prompt as a top-level
245 // field rather than inline.
246 let mut system_parts: Vec<String> = Vec::new();
247 let mut samp_messages: Vec<SamplingMessage> = Vec::new();
248 for m in messages {
249 match m.role {
250 Role::System => system_parts.push(m.content.clone()),
251 Role::User => {
252 samp_messages.push(SamplingMessage::user_text(&m.content));
253 }
254 Role::Assistant => {
255 samp_messages.push(SamplingMessage::assistant_text(&m.content));
256 }
257 }
258 }
259 // rmcp 1.7's struct literals are non-exhaustive across crate
260 // boundaries; build via the typed constructors + builders.
261 let preferences = ModelPreferences::new()
262 .with_hints(vec![ModelHint::new("claude")])
263 .with_intelligence_priority(0.7)
264 .with_speed_priority(0.3)
265 .with_cost_priority(0.4);
266 let mut params = CreateMessageRequestParams::new(samp_messages, self.max_tokens)
267 .with_model_preferences(preferences);
268 if !system_parts.is_empty() {
269 params = params.with_system_prompt(system_parts.join("\n\n"));
270 }
271 params
272 }
273
274 /// Build the audit `AuditEvent` carrying ONLY metadata. No raw
275 /// prompt content lands in `details_json`.
276 ///
277 /// Pinned by [`tests::audit_row_omits_raw_prompt_text`].
278 ///
279 /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the raw character count
280 /// of the prompt is itself a side-channel — a 6-char prompt
281 /// uniquely identifies very-short refusal paths (e.g. a leaked
282 /// password length). `prompt_chars` and `input_tokens_est` are
283 /// rounded up to the next power of two before persistence. This
284 /// preserves operator capacity-planning (the bucket is within ~2x
285 /// of the real size for any sufficiently large prompt) while
286 /// removing the per-character precision.
287 ///
288 /// Buckets: `0, 1, 2, 4, 8, 16, 32, 64, ..., 1024, 2048, ...`
289 /// (next power of two `>= n`). 0 stays 0.
290 fn audit_event(
291 &self,
292 params: &CreateMessageRequestParams,
293 outcome: SamplingOutcome,
294 ) -> AuditEvent {
295 let raw_prompt_chars: usize = params
296 .messages
297 .iter()
298 .flat_map(|m| m.content.iter())
299 .filter_map(|c| c.as_text().map(|t| t.text.len()))
300 .sum::<usize>()
301 + params.system_prompt.as_ref().map(|s| s.len()).unwrap_or(0);
302 // v0.9.1 P1 Fix 4: bucket the raw count to the next power of
303 // two. Pinned by `tests::audit_row_bucket_prompt_chars_to_pow2`.
304 let prompt_chars = next_pow2_bucket(raw_prompt_chars);
305 // ~4 chars per token for the rough English-text estimate used
306 // by `solo doctor --check-llm` and Anthropic's docs. Recorded
307 // for operator capacity-planning. Bucketed for the same
308 // privacy reason — and to stay consistent with `prompt_chars`.
309 let input_tokens_est = next_pow2_bucket(raw_prompt_chars / 4) as u64;
310 let model_hint = params
311 .model_preferences
312 .as_ref()
313 .and_then(|p| p.hints.as_ref())
314 .and_then(|h| h.first())
315 .and_then(|h| h.name.clone())
316 .unwrap_or_else(|| "(none)".to_string());
317
318 let mut details = serde_json::Map::new();
319 details.insert(
320 "model_hint".to_string(),
321 serde_json::Value::String(model_hint),
322 );
323 details.insert(
324 "messages_count".to_string(),
325 serde_json::Value::Number(params.messages.len().into()),
326 );
327 details.insert(
328 "max_tokens".to_string(),
329 serde_json::Value::Number(params.max_tokens.into()),
330 );
331 details.insert(
332 "prompt_chars".to_string(),
333 serde_json::Value::Number(prompt_chars.into()),
334 );
335 details.insert(
336 "input_tokens_est".to_string(),
337 serde_json::Value::Number(input_tokens_est.into()),
338 );
339
340 let result = match &outcome {
341 SamplingOutcome::Ok {
342 duration_ms,
343 model,
344 output_chars,
345 } => {
346 // v0.9.1 P1 Fix 4: same power-of-2 bucketing as
347 // `prompt_chars` for the output side. A model that
348 // always replies with a one-token refusal (e.g. an
349 // assistant trained to say "no.") would otherwise leak
350 // the response-length distribution; bucketing
351 // collapses 1-2-3-4 chars all into bucket 4.
352 let bucketed_output_chars = next_pow2_bucket(*output_chars);
353 let output_tokens_est = next_pow2_bucket(*output_chars / 4) as u64;
354 details.insert(
355 "duration_ms".to_string(),
356 serde_json::Value::Number((*duration_ms).into()),
357 );
358 details.insert(
359 "model".to_string(),
360 serde_json::Value::String(model.clone()),
361 );
362 details.insert(
363 "output_chars".to_string(),
364 serde_json::Value::Number(bucketed_output_chars.into()),
365 );
366 details.insert(
367 "output_tokens_est".to_string(),
368 serde_json::Value::Number(output_tokens_est.into()),
369 );
370 AuditResult::Ok
371 }
372 SamplingOutcome::Forbidden {
373 reason,
374 duration_ms,
375 } => {
376 details.insert(
377 "duration_ms".to_string(),
378 serde_json::Value::Number((*duration_ms).into()),
379 );
380 details.insert(
381 "reason".to_string(),
382 serde_json::Value::String(reason.to_string()),
383 );
384 AuditResult::Forbidden
385 }
386 SamplingOutcome::Error {
387 reason,
388 duration_ms,
389 } => {
390 details.insert(
391 "duration_ms".to_string(),
392 serde_json::Value::Number((*duration_ms).into()),
393 );
394 details.insert(
395 "reason".to_string(),
396 serde_json::Value::String(reason.to_string()),
397 );
398 AuditResult::Error
399 }
400 };
401
402 AuditEvent {
403 ts_ms: chrono::Utc::now().timestamp_millis(),
404 principal_subject: self.audit_principal.clone(),
405 operation: AuditOperation::LlmSamplingCall,
406 target_id: None,
407 result,
408 details: Some(serde_json::Value::Object(details)),
409 }
410 }
411}
412
/// Internal outcome category for the audit-row builder.
///
/// Produced by `SamplingLlmClient::complete` once the RPC attempt has
/// finished (or timed out) and consumed by
/// `SamplingLlmClient::audit_event`, which turns it into the audit
/// result plus the outcome-specific `details_json` fields.
enum SamplingOutcome {
    /// The RPC returned a usable assistant text.
    Ok {
        /// Elapsed wall-clock time of the RPC attempt, in milliseconds.
        duration_ms: u64,
        /// Model name reported in the sampling response.
        model: String,
        /// Length of the extracted text as given by `String::len`
        /// (UTF-8 bytes); bucketed before it is persisted.
        output_chars: usize,
    },
    /// The error was classified as a refusal.
    Forbidden {
        /// Static category string stored as `details_json.reason`.
        reason: &'static str,
        /// Elapsed wall-clock time of the RPC attempt, in milliseconds.
        duration_ms: u64,
    },
    /// Any other failure: transport error, malformed response, timeout.
    Error {
        /// Static category string stored as `details_json.reason`.
        reason: &'static str,
        /// Elapsed wall-clock time of the RPC attempt, in milliseconds.
        duration_ms: u64,
    },
}
429
430#[async_trait]
431impl LlmClient for SamplingLlmClient {
432 fn name(&self) -> &str {
433 "mcp-sampling"
434 }
435
436 async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
437 let params = self.build_request(messages);
438 let start = Instant::now();
439
440 // Bounded wait on `peer.create_message`. The fold of (rmcp
441 // ServiceError | FakeError | tokio timeout) into the
442 // `Outcome` enum keeps the audit path single-sourced.
443 let rpc = tokio::time::timeout(
444 self.timeout,
445 self.sampling_client.create_message(params.clone()),
446 )
447 .await;
448 let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
449
450 let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) = match rpc {
451 Ok(Ok(result)) => match extract_text(&result) {
452 Ok(text) => {
453 let output_chars = text.len();
454 let outcome = SamplingOutcome::Ok {
455 duration_ms,
456 model: result.model.clone(),
457 output_chars,
458 };
459 (Ok(Message::assistant(text)), outcome)
460 }
461 Err(reason) => (
462 Err(CoreError::llm(format!(
463 "mcp sampling: malformed response: {reason}"
464 ))),
465 SamplingOutcome::Error {
466 reason: "malformed_response",
467 duration_ms,
468 },
469 ),
470 },
471 Ok(Err(e)) => {
472 let (category, is_forbidden) = e.classify();
473 let outcome = if is_forbidden {
474 SamplingOutcome::Forbidden {
475 reason: category,
476 duration_ms,
477 }
478 } else {
479 SamplingOutcome::Error {
480 reason: category,
481 duration_ms,
482 }
483 };
484 let err = if is_forbidden {
485 CoreError::forbidden(format!("mcp sampling: {e}"))
486 } else {
487 CoreError::llm(format!("mcp sampling: {e}"))
488 };
489 (Err(err), outcome)
490 }
491 Err(_elapsed) => (
492 Err(CoreError::llm(format!(
493 "mcp sampling: timeout after {}ms",
494 duration_ms
495 ))),
496 SamplingOutcome::Error {
497 reason: "timeout",
498 duration_ms,
499 },
500 ),
501 };
502
503 // Synchronous audit emit through the writer-actor (lesson #30).
504 // Failure to land the audit row is operator-visible: the
505 // sampling call's caller sees the storage error and can decide
506 // whether to abort (we DO abort here — without the audit row
507 // we have no record of the call).
508 //
509 // v0.9.1 P1 Fix 3 (F4 Result-shadowing): when BOTH `core_result`
510 // is `Err(..)` AND the audit emit also fails, return the
511 // ORIGINAL LLM-side error (more actionable for callers — they
512 // can retry the LLM call, or decide whether the upstream
513 // refusal/timeout is recoverable). Surface the audit failure
514 // via `tracing::error!` for operator visibility — operators
515 // alarming on storage errors see it; callers see the actionable
516 // error.
517 //
518 // Policy summary:
519 // * RPC Ok + audit Ok → return Ok(text)
520 // * RPC Ok + audit Err → return Err(storage) [audit failure
521 // wins — no undocumented sampling
522 // calls per lesson #30]
523 // * RPC Err + audit Ok → return Err(llm/forbidden) [unchanged]
524 // * RPC Err + audit Err → return Err(llm/forbidden) AND log
525 // audit failure at error level
526 // [v0.9.1 P1 Fix 3]
527 let event = self.audit_event(¶ms, outcome);
528 match (
529 core_result,
530 self.write_handle.emit_llm_sampling_audit(event).await,
531 ) {
532 (Ok(text), Ok(())) => Ok(text),
533 (Ok(_text), Err(audit_err)) => {
534 // RPC succeeded but the audit row didn't land. Drop
535 // the success — without a durable audit row we can't
536 // honor the "every sampling call leaves a trace"
537 // invariant.
538 Err(CoreError::storage(format!(
539 "mcp sampling: audit emit failed: {audit_err}"
540 )))
541 }
542 (Err(core_err), Ok(())) => Err(core_err),
543 (Err(core_err), Err(audit_err)) => {
544 // Both failed. Return the LLM-side error (the caller's
545 // most actionable signal); log the audit failure so an
546 // operator who alarms on storage errors still sees it.
547 tracing::error!(
548 audit_error = %audit_err,
549 core_error = %core_err,
550 "mcp sampling: audit emit failed alongside core \
551 error; surfacing core error to caller"
552 );
553 Err(core_err)
554 }
555 }
556 }
557}
558
/// Round `n` up to the next power of two, saturating at `usize::MAX`.
/// Used to bucket `prompt_chars` / `output_chars` / `*_tokens_est` in
/// the `LlmSamplingCall` audit row's `details_json` (v0.9.1 P1 Fix 4
/// "F6" — `prompt_chars` was a privacy side-channel for short
/// prompts).
///
/// Buckets: `0 → 0`, `1 → 1`, `2 → 2`, `3 → 4`, `4 → 4`, `5..=8 → 8`,
/// `9..=16 → 16`, `17..=32 → 32`, ... — within a bucket all values
/// collapse to the same persisted number. The worst-case fidelity
/// loss is just under 2x (e.g. 9 chars persists as 16) which is well
/// within the precision capacity-planning needs.
///
/// Values above the largest representable power of two
/// (`1 << (usize::BITS - 1)`) map to `usize::MAX`. The plain
/// `usize::next_power_of_two` would panic in debug builds and wrap to
/// `0` in release builds for those inputs, so the checked variant is
/// used instead.
///
/// Pinned by [`tests::next_pow2_bucket_*`] and
/// [`tests::audit_row_bucket_prompt_chars_to_pow2`].
fn next_pow2_bucket(n: usize) -> usize {
    // `checked_next_power_of_two` maps 0 to 1; the audit contract wants
    // an empty prompt to stay 0, so handle it up front.
    if n == 0 {
        return 0;
    }
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
583
584/// Pull the assistant's text out of the rmcp result. Walks every text
585/// content block in the message (the spec allows either a single
586/// `SamplingContent::Single` or a `SamplingContent::Multiple`) and
587/// concatenates them with newlines. Returns `Err(reason)` if no text
588/// blocks were present — the malformed-response path.
589fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
590 if result.message.role != RmcpRole::Assistant {
591 return Err("response role was not Assistant");
592 }
593 let mut out = String::new();
594 for content in result.message.content.iter() {
595 if let SamplingMessageContent::Text(text) = content {
596 if !out.is_empty() {
597 out.push('\n');
598 }
599 out.push_str(&text.text);
600 }
601 }
602 if out.is_empty() {
603 Err("no text content blocks")
604 } else {
605 Ok(out)
606 }
607}
608
609/// v0.9.0 P2: build a sampling-backed `Arc<Steward>` for a tenant that
610/// has resolved `LlmConfig::McpSampling` and just attached an MCP
611/// session.
612///
613/// Called from [`crate::mcp::SoloMcpServer::populate_sampling_steward`] at
614/// MCP `initialize` time once the peer's sampling capability is
615/// confirmed. The returned `Arc<Steward>` is written into
616/// `tenant.steward_slot()` so the writer-actor + consolidate timer
617/// can read a populated slot on their next tick.
618///
619/// v0.9.0 P5 (M3 wiring): the live `PeerSamplingClient` is now wrapped
620/// in a [`super::SamplingCoordinator`] before being handed to
621/// `SamplingLlmClient`. Concurrent `complete()` calls within the
622/// coalesce window collapse into one `peer.create_message` RPC and the
623/// response is demultiplexed back per-task — matching the
624/// `[sampling] coalesce_window_ms` / `coalesce_max_requests` config the
625/// operator wrote in `solo.config.toml`. Per-call audit emit semantics
626/// are unchanged: every logical request still lands one
627/// `AuditOperation::LlmSamplingCall` row, no raw prompt content escapes
628/// to the audit row.
629///
630/// Edge case (clamping): the `[sampling]` block accepts values that
631/// effectively disable batching — `coalesce_max_requests = 1` and / or
632/// `coalesce_window_ms = 0` reduce the coordinator to pass-through (one
633/// inner call per submission). The coordinator's
634/// [`super::SamplingCoordinator::with_settings`] clamps `max_batch` to
635/// `max(1)` so a zero value still produces a single-element flush
636/// immediately rather than panicking or deadlocking.
637pub fn build_sampling_steward(
638 peer: Peer<RoleServer>,
639 write_handle: WriteHandle,
640 audit_principal: Option<String>,
641 steward_config: solo_steward::StewardConfig,
642 sampling_config: solo_storage::SamplingConfig,
643) -> Arc<solo_steward::Steward> {
644 let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
645 let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
646 inner,
647 std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
648 sampling_config.coalesce_max_requests as usize,
649 );
650 let client =
651 SamplingLlmClient::with_sampling_client(coordinator, write_handle, audit_principal)
652 .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
653 Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
654}
655
656#[cfg(test)]
657mod tests {
658 use super::*;
659 use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
660 use rmcp::model::CreateMessageResult;
661 use solo_core::TenantId;
662 use solo_storage::{
663 EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder, TenantHandle,
664 TenantRegistry, TenantRegistryParams, init, open_sqlcipher,
665 };
666 use std::path::PathBuf;
667 use std::sync::Arc;
668 use tempfile::TempDir;
669 use zeroize::Zeroizing;
670
671 const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
672
673 /// Bootstrap a per-tenant `TenantHandle` whose writer-actor accepts
674 /// the new `WriteCommand::EmitLlmSamplingAudit` variant.
675 ///
676 /// Mirrors the v0.8.x test discipline (see
677 /// `crates/solo-storage/src/tenants/handle_registry_tests.rs`'s
678 /// `fresh_init_dir`): build a real tenant DB on disk via the same
679 /// `init()` helper users invoke, wrap in a `TenantRegistry`, and
680 /// surface the `WriteHandle` for direct `SamplingLlmClient`
681 /// wiring.
682 struct Harness {
683 _tmp: TempDir,
684 _registry: Arc<TenantRegistry>,
685 _tenant: Arc<TenantHandle>,
686 write_handle: solo_storage::WriteHandle,
687 db_path: PathBuf,
688 key: KeyMaterial,
689 }
690
691 async fn harness() -> Harness {
692 let tmp = TempDir::new().expect("tempdir");
693 let data_dir = tmp.path().to_path_buf();
694 let _ = init(InitParams {
695 data_dir: data_dir.clone(),
696 passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
697 force: false,
698 embedder: EmbedderConfig {
699 name: "stub".into(),
700 version: "v1".into(),
701 dim: 32,
702 dtype: "f32".into(),
703 },
704 })
705 .expect("init");
706
707 let cfg =
708 solo_storage::SoloConfig::read(&data_dir.join("solo.config.toml")).expect("read cfg");
709 let key = KeyMaterial::derive(TEST_PASSPHRASE, &cfg.salt_bytes().expect("salt"))
710 .expect("derive key");
711
712 let embedder: Arc<dyn solo_core::Embedder> = Arc::new(StubEmbedder::new("stub", "v1", 32));
713 let registry = Arc::new(
714 TenantRegistry::open(TenantRegistryParams {
715 data_dir: data_dir.clone(),
716 key: key.clone(),
717 embedder: embedder.clone(),
718 hnsw_params: HnswParams::default(),
719 steward: None,
720 runtime_handle: Some(tokio::runtime::Handle::current()),
721 steward_factory: None,
722 triples_batch_signal: None,
723 })
724 .expect("open registry"),
725 );
726
727 let tenant_id = TenantId::default_tenant();
728 let tenant = registry
729 .get_or_open(&tenant_id)
730 .await
731 .expect("get_or_open default tenant");
732 let write_handle = tenant.write().clone();
733 let db_path = tenant.db_path().to_path_buf();
734
735 Harness {
736 _tmp: tmp,
737 _registry: registry,
738 _tenant: tenant,
739 write_handle,
740 db_path,
741 key,
742 }
743 }
744
745 /// Helper: count the `audit_events` rows whose `operation` is the
746 /// given string. Opens a fresh connection to the tenant DB so we
747 /// avoid contention with the writer-actor's own connection.
748 fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
749 let conn = open_sqlcipher(db_path, key).expect("open db");
750 conn.query_row(
751 "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
752 rusqlite::params![op],
753 |r| r.get(0),
754 )
755 .expect("count")
756 }
757
758 /// Helper: load the most-recent `llm.sampling_call` audit row and
759 /// return `(result, principal_subject, details_json)`.
760 fn latest_sampling_audit_details(
761 db_path: &std::path::Path,
762 key: &KeyMaterial,
763 ) -> (String, Option<String>, serde_json::Value) {
764 let conn = open_sqlcipher(db_path, key).expect("open db");
765 let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
766 .query_row(
767 "SELECT result, principal_subject, details_json
768 FROM audit_events
769 WHERE operation = 'llm.sampling_call'
770 ORDER BY ts_ms DESC, rowid DESC
771 LIMIT 1",
772 [],
773 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
774 )
775 .expect("query");
776 let details: serde_json::Value =
777 serde_json::from_str(&details_str.expect("details_json present"))
778 .expect("parse details");
779 (result, principal, details)
780 }
781
782 /// Happy path: a successful `create_message` round-trip returns
783 /// the assistant text wrapped in a `Message::assistant`, and lands
784 /// exactly one `llm.sampling_call` audit row with `result = 'ok'`.
785 #[tokio::test]
786 async fn sampling_complete_happy_path_returns_text() {
787 let h = harness().await;
788 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
789 let client = SamplingLlmClient::with_sampling_client(
790 fake.clone(),
791 h.write_handle.clone(),
792 Some("alice".into()),
793 );
794 let messages = vec![Message::user("summarise these episodes")];
795 let result = client.complete(&messages).await.expect("ok");
796 assert_eq!(result.role, Role::Assistant);
797 assert_eq!(result.content, "derived theme");
798
799 // Exactly one audit row landed.
800 assert_eq!(count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"), 1);
801 let (result_str, principal, details) = latest_sampling_audit_details(&h.db_path, &h.key);
802 assert_eq!(result_str, "ok");
803 assert_eq!(principal.as_deref(), Some("alice"));
804 assert_eq!(details["model_hint"], "claude");
805 assert_eq!(details["model"], "fake-claude");
806 assert_eq!(details["messages_count"], 1);
807 assert_eq!(details["max_tokens"], 512);
808 }
809
810 /// Privacy invariant: the audit row's `details_json` MUST NOT
811 /// contain the raw prompt content. Pinned by string inspection of
812 /// the persisted JSON.
813 #[tokio::test]
814 async fn audit_row_omits_raw_prompt_text() {
815 let h = harness().await;
816 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
817 let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
818 let secret = "THE-USER-ID-IS-bobby-1234";
819 let messages = vec![
820 Message::system("you are a friendly assistant"),
821 Message::user(secret),
822 ];
823 client.complete(&messages).await.expect("ok");
824
825 let (_, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
826 let serialised = serde_json::to_string(&details).expect("serialise details");
827 assert!(
828 !serialised.contains(secret),
829 "audit details must not carry raw prompt content; was: {serialised}"
830 );
831 assert!(
832 !serialised.contains("you are a friendly assistant"),
833 "audit details must not carry system prompt; was: {serialised}"
834 );
835 // Metadata IS present, even though the prompt is not.
836 assert_eq!(details["messages_count"], 1);
837 assert!(details["prompt_chars"].as_u64().unwrap() > 0);
838 }
839
840 /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the audit row's
841 /// `prompt_chars` MUST be the power-of-2 bucket, never the raw
842 /// character count. Pins the bucketing behavior end-to-end (raw
843 /// `audit_event` → SQLite → re-read).
844 ///
845 /// Test recipe: drive a prompt with a known raw length (6 chars
846 /// total, `"hello "` system + `"x"` user → 6+1 = 7) and assert the
847 /// audit row carries `8` (next pow2 ≥ 7), not 7.
848 #[tokio::test]
849 async fn audit_row_bucket_prompt_chars_to_pow2() {
850 let h = harness().await;
851 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
852 let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
853 // System: 6 chars + user: 1 char = 7 chars raw → bucket 8.
854 client
855 .complete(&[Message::system("hello "), Message::user("x")])
856 .await
857 .expect("ok");
858 let (_, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
859 assert_eq!(
860 details["prompt_chars"].as_u64().unwrap(),
861 8,
862 "prompt_chars must be bucketed to next pow2 (7 → 8). \
863 raw count is a privacy side-channel; see Fix 4 F6 in \
864 v0.9.1 P1 dev log. got details={details}"
865 );
866 }
867
868 /// Stability invariant: two prompts that fall in the SAME bucket
869 /// must persist identical `prompt_chars`. Distinguishes "the
870 /// implementation buckets" from "the implementation hashes/leaks
871 /// raw values".
872 ///
873 /// 5 chars and 7 chars both round to 8 → must persist identically.
874 /// (Mirrors the brief's "test that bucketed values are stable
875 /// across exact-character variations within the same bucket".)
876 #[tokio::test]
877 async fn audit_row_bucket_prompt_chars_is_stable_within_bucket() {
878 let h = harness().await;
879 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
880 let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
881 // 5 chars raw → bucket 8.
882 client
883 .complete(&[Message::user("hello")])
884 .await
885 .expect("ok");
886 let (_, _, details_5) = latest_sampling_audit_details(&h.db_path, &h.key);
887 // 7 chars raw → bucket 8.
888 client
889 .complete(&[Message::user("hellooo")])
890 .await
891 .expect("ok");
892 let (_, _, details_7) = latest_sampling_audit_details(&h.db_path, &h.key);
893 assert_eq!(
894 details_5["prompt_chars"], details_7["prompt_chars"],
895 "5 chars and 7 chars must hash to the same bucket (8) — \
896 otherwise the bucketing is leaking raw fidelity. \
897 5-char details: {details_5}, 7-char details: {details_7}"
898 );
899 assert_eq!(details_5["prompt_chars"].as_u64().unwrap(), 8);
900 }
901
/// Table-driven pins for the bucketing helper. Guards against a
/// regression where `next_pow2_bucket` gets "simplified" into a no-op.
#[test]
fn next_pow2_bucket_table() {
    // (raw input, expected bucket, optional note shown on failure)
    let cases = [
        (0, 0, Some("0 stays 0")),
        (1, 1, Some("1 stays 1")),
        (2, 2, Some("2 stays 2")),
        (3, 4, Some("3 rounds up to 4")),
        (4, 4, Some("4 stays 4")),
        (5, 8, None),
        (6, 8, Some("6-char prompt (brief case) → 8")),
        (7, 8, None),
        (8, 8, None),
        (9, 16, None),
        (1023, 1024, None),
        (1024, 1024, None),
        (1025, 2048, None),
    ];
    for &(raw, expected, note) in cases.iter() {
        match note {
            Some(note) => assert_eq!(next_pow2_bucket(raw), expected, "{note}"),
            None => assert_eq!(next_pow2_bucket(raw), expected),
        }
    }
}
920
921 /// Client refusal: maps to `CoreError::Forbidden` + audit row
922 /// `result = 'forbidden'` + `details_json.reason = 'client_refused'`.
923 #[tokio::test]
924 async fn client_refusal_returns_forbidden_and_audits() {
925 let h = harness().await;
926 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
927 fake.reject_with("user dismissed approval");
928 let client = SamplingLlmClient::with_sampling_client(
929 fake,
930 h.write_handle.clone(),
931 Some("alice".into()),
932 );
933 let err = client
934 .complete(&[Message::user("anything")])
935 .await
936 .unwrap_err();
937 match err {
938 CoreError::Forbidden(_) => {}
939 other => panic!("expected Forbidden, got {other:?}"),
940 }
941 let (result_str, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
942 assert_eq!(result_str, "forbidden");
943 assert_eq!(details["reason"], "client_refused");
944 }
945
946 /// Timeout: tokio::time::timeout fires before the fake's `Slow`
947 /// response resolves; client returns `CoreError::Llm` + audit row
948 /// `result = 'error'` + `details_json.reason = 'timeout'`.
949 ///
950 /// Real wall-clock: 80ms slow response vs 30ms client timeout.
951 /// Margin is loose enough for slow CI without making the test
952 /// drag.
953 #[tokio::test]
954 async fn timeout_returns_error_with_timeout_reason() {
955 let h = harness().await;
956 let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
957 "late",
958 Duration::from_millis(800),
959 )));
960 let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None)
961 .with_timeout(Duration::from_millis(30));
962 let err = client
963 .complete(&[Message::user("hello")])
964 .await
965 .unwrap_err();
966 match err {
967 CoreError::Llm(msg) => assert!(msg.contains("timeout")),
968 other => panic!("expected Llm, got {other:?}"),
969 }
970 let (result_str, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
971 assert_eq!(result_str, "error");
972 assert_eq!(details["reason"], "timeout");
973 }
974
975 /// Malformed response: the fake returns a result with zero text
976 /// content blocks; client surfaces `CoreError::Llm` + audit row
977 /// `result = 'error'` + `details_json.reason = 'malformed_response'`.
978 #[tokio::test]
979 async fn malformed_response_returns_error_with_reason() {
980 let h = harness().await;
981 let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
982 let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
983 let err = client.complete(&[Message::user("hi")]).await.unwrap_err();
984 assert!(matches!(err, CoreError::Llm(_)));
985 let (result_str, _, details) = latest_sampling_audit_details(&h.db_path, &h.key);
986 assert_eq!(result_str, "error");
987 assert_eq!(details["reason"], "malformed_response");
988 }
989
990 /// `principal_subject = None` works — audit row still emits with
991 /// NULL.
992 #[tokio::test]
993 async fn no_principal_emits_audit_with_null_principal() {
994 let h = harness().await;
995 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
996 let client = SamplingLlmClient::with_sampling_client(fake, h.write_handle.clone(), None);
997 client.complete(&[Message::user("hi")]).await.expect("ok");
998 let (_, principal, _) = latest_sampling_audit_details(&h.db_path, &h.key);
999 assert_eq!(principal, None);
1000 }
1001
1002 /// Concurrency: 8 parallel `complete()` calls land 8 audit rows.
1003 /// Audit IDs (autoincrement rowid) must be distinct — verifies the
1004 /// writer-actor serialises the per-call audit emit (no
1005 /// interleaving / dropped rows).
1006 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1007 async fn parallel_completes_serialise_audit_rows() {
1008 let h = harness().await;
1009 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1010 let client = SamplingLlmClient::with_sampling_client(
1011 fake.clone(),
1012 h.write_handle.clone(),
1013 Some("alice".into()),
1014 );
1015 let mut futs = Vec::new();
1016 for _ in 0..8 {
1017 let c = client.clone();
1018 futs.push(tokio::spawn(async move {
1019 c.complete(&[Message::user("hi")]).await
1020 }));
1021 }
1022 for f in futs {
1023 f.await.expect("join").expect("ok");
1024 }
1025 assert_eq!(
1026 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1027 8,
1028 "8 parallel calls must land 8 audit rows"
1029 );
1030
1031 // Each was a separate request to the fake.
1032 assert_eq!(fake.record_requests().len(), 8);
1033 }
1034
1035 /// `complete` translates the workspace's `Message::system` into the
1036 /// `system_prompt` top-level field; user/assistant roles map to
1037 /// rmcp's `SamplingMessage::user_text` / `assistant_text`.
1038 #[tokio::test]
1039 async fn build_request_splits_system_from_messages() {
1040 let h = harness().await;
1041 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1042 let client =
1043 SamplingLlmClient::with_sampling_client(fake.clone(), h.write_handle.clone(), None);
1044 client
1045 .complete(&[
1046 Message::system("be terse"),
1047 Message::user("question"),
1048 Message::assistant("answer"),
1049 ])
1050 .await
1051 .expect("ok");
1052 let recorded = fake.record_requests();
1053 assert_eq!(recorded.len(), 1);
1054 let req = &recorded[0];
1055 assert_eq!(
1056 req.system_prompt.as_deref(),
1057 Some("be terse"),
1058 "Role::System must map to system_prompt"
1059 );
1060 assert_eq!(req.messages.len(), 2);
1061 // The remaining two messages are the user + assistant turns.
1062 assert_eq!(req.messages[0].role, RmcpRole::User);
1063 assert_eq!(req.messages[1].role, RmcpRole::Assistant);
1064 }
1065
1066 /// `model_preferences` carries the `claude` hint per plan §6.
1067 /// Pins the wire shape so a future change is a conscious decision.
1068 #[tokio::test]
1069 async fn build_request_includes_claude_model_hint() {
1070 let h = harness().await;
1071 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1072 let client =
1073 SamplingLlmClient::with_sampling_client(fake.clone(), h.write_handle.clone(), None);
1074 client.complete(&[Message::user("hi")]).await.expect("ok");
1075 let recorded = fake.record_requests();
1076 let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
1077 let hint = prefs
1078 .hints
1079 .as_ref()
1080 .and_then(|h| h.first())
1081 .and_then(|h| h.name.clone())
1082 .expect("hint name");
1083 assert_eq!(hint, "claude");
1084 }
1085
1086 /// `with_max_tokens(n)` propagates to the request's
1087 /// `max_tokens` field.
1088 #[tokio::test]
1089 async fn with_max_tokens_overrides_default() {
1090 let h = harness().await;
1091 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1092 let client =
1093 SamplingLlmClient::with_sampling_client(fake.clone(), h.write_handle.clone(), None)
1094 .with_max_tokens(2048);
1095 client.complete(&[Message::user("hi")]).await.expect("ok");
1096 let recorded = fake.record_requests();
1097 assert_eq!(recorded[0].max_tokens, 2048);
1098 }
1099
1100 /// Reconfiguring the fake mid-test produces distinct audit rows
1101 /// for each call (positive then negative).
1102 #[tokio::test]
1103 async fn reconfigurable_fake_distinguishes_audit_rows() {
1104 let h = harness().await;
1105 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1106 let client = SamplingLlmClient::with_sampling_client(
1107 fake.clone(),
1108 h.write_handle.clone(),
1109 Some("alice".into()),
1110 );
1111
1112 client.complete(&[Message::user("a")]).await.expect("ok");
1113 fake.reject_with("user said no");
1114 let _ = client.complete(&[Message::user("b")]).await;
1115
1116 let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1117 let mut stmt = conn
1118 .prepare(
1119 "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1120 )
1121 .expect("prepare");
1122 let rows: Vec<String> = stmt
1123 .query_map([], |r| r.get::<_, String>(0))
1124 .expect("query")
1125 .map(|r| r.expect("row"))
1126 .collect();
1127 assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1128 }
1129
/// `extract_text` returns the text of a single-block assistant reply.
#[test]
fn extract_text_pulls_text_from_single_block() {
    let reply = SamplingMessage::assistant_text("hello");
    let result = CreateMessageResult::new(reply, "fake".into());
    let text = extract_text(&result).unwrap();
    assert_eq!(text, "hello");
}
1137
/// `extract_text` returns an error for an assistant reply that
/// carries no content blocks at all.
#[test]
fn extract_text_rejects_empty_content() {
    let no_blocks = Vec::new();
    let reply = SamplingMessage::new_multiple(RmcpRole::Assistant, no_blocks);
    let result = CreateMessageResult::new(reply, "fake".into());
    assert!(extract_text(&result).is_err());
}
1147
/// `extract_text` returns an error when the reply has the User role
/// (should not happen per spec; this pins the defensive check).
#[test]
fn extract_text_rejects_non_assistant_role() {
    let reply = SamplingMessage::user_text("hello");
    let result = CreateMessageResult::new(reply, "fake".into());
    assert!(extract_text(&result).is_err());
}
1155
1156 // ---- v0.10.1 F5 audit-minor closure: pin multi-block concat
1157 // semantics (deferred from v0.9.0 P2). ----
1158 //
1159 // `extract_text` walks `result.message.content` and pushes a `\n`
1160 // between successive `SamplingMessageContent::Text` blocks. The
1161 // exact semantics are wire-shape decisions that downstream Steward
1162 // parsers depend on; we pin them here so a future refactor can't
1163 // drift silently. The audit minor (v0.9.0 P2 §F5) flagged that the
1164 // existing tests only exercised single-block + empty + non-
1165 // assistant cases — not the multi-block join behavior. Closure
1166 // checklist:
1167 //
1168 // 1. join inserts ONE `\n` between non-trailing-newline blocks
1169 // 2. trailing `\n` in a block is preserved (so a `"abc\n"` block
1170 // + a `"def"` block produces `"abc\n\ndef"`)
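// 3. single block returns verbatim (no leading/trailing newline
//    added)
// 4. an empty-string block between two non-empty blocks still gets
//    a join newline, so `"a"`, `""`, `"b"` produces `"a\n\nb"`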
1173 //
1174 // These tests are intentionally identical-output to what the code
1175 // produces today. They're a regression net, not a behavior change.
1176 //
1177 // Grep terms: F5, extract_text_joins_multi_block, extract_text_preserves_trailing_newlines.
1178
/// F5 pin: two text blocks, neither ending in a newline, are joined
/// with exactly ONE `\n`.
#[test]
fn extract_text_joins_multi_block_with_newline_separator() {
    let reply = SamplingMessage::new_multiple(
        RmcpRole::Assistant,
        vec![
            SamplingMessageContent::text("abc"),
            SamplingMessageContent::text("def"),
        ],
    );
    let result = CreateMessageResult::new(reply, "fake".into());

    // Exact value pinned: "abc" + "\n" + "def".
    let joined = extract_text(&result).unwrap();
    assert_eq!(
        joined, "abc\ndef",
        "two non-newline-terminated blocks must join with a single newline"
    );
}
1198
/// F5 pin: a block's own trailing newline is kept verbatim AND the
/// join newline is still inserted, giving `\n\n` between that block
/// and the next. This records current behavior; a refactor that
/// strips trailing newlines has to update this test deliberately.
#[test]
fn extract_text_preserves_trailing_newlines_in_blocks() {
    let reply = SamplingMessage::new_multiple(
        RmcpRole::Assistant,
        vec![
            SamplingMessageContent::text("abc\n"),
            SamplingMessageContent::text("def"),
        ],
    );
    let result = CreateMessageResult::new(reply, "fake".into());

    let joined = extract_text(&result).unwrap();
    assert_eq!(
        joined, "abc\n\ndef",
        "trailing newline in block 1 + join newline => '\\n\\n' between blocks"
    );
}
1220
/// F5 pin: a lone block comes back verbatim, with no leading or
/// trailing newline added. `extract_text_pulls_text_from_single_block`
/// covers the plain `"hello"` case; this variant uses text with an
/// inner newline to pin that inner whitespace is left untouched.
#[test]
fn extract_text_single_block_returns_verbatim_including_inner_newlines() {
    let only_block = SamplingMessageContent::text("line1\nline2");
    let reply = SamplingMessage::new_multiple(RmcpRole::Assistant, vec![only_block]);
    let result = CreateMessageResult::new(reply, "fake".into());

    let text = extract_text(&result).unwrap();
    assert_eq!(
        text, "line1\nline2",
        "single block must return verbatim, no extra newlines added"
    );
}
1239
/// F5 pin: three blocks where the middle one is the empty string.
/// The pinned output is `"a\n\nb"`: the empty block contributes no
/// text of its own but still receives a join newline, and so does the
/// block after it. The net effect is a blank line between the two
/// non-empty blocks, which is surprising at first read but is what
/// the helper produces today.
#[test]
fn extract_text_empty_middle_block_inserts_blank_line() {
    let reply = SamplingMessage::new_multiple(
        RmcpRole::Assistant,
        vec![
            SamplingMessageContent::text("a"),
            SamplingMessageContent::text(""),
            SamplingMessageContent::text("b"),
        ],
    );
    let result = CreateMessageResult::new(reply, "fake".into());

    // Expected accumulation, block by block:
    //   "a"  → "a"
    //   ""   → "a\n"      (separator, then nothing)
    //   "b"  → "a\n\nb"   (separator, then "b")
    let joined = extract_text(&result).unwrap();
    assert_eq!(
        joined, "a\n\nb",
        "empty middle block leaves a blank line between non-empty blocks"
    );
}
1269
/// `SamplingError::classify` maps each fake error variant to its
/// audit category string and "forbidden" flag: only `Refused` is
/// flagged as forbidden.
#[test]
fn sampling_error_classify_maps_fake_variants() {
    // Refused → client_refused, forbidden.
    let (category, is_forbidden) =
        SamplingError::Fake(FakeSamplingError::Refused { reason: "x".into() }).classify();
    assert_eq!(category, "client_refused");
    assert!(is_forbidden);

    // Transport → transport_error, not forbidden.
    let (category, is_forbidden) = SamplingError::Fake(FakeSamplingError::Transport {
        message: "x".into(),
    })
    .classify();
    assert_eq!(category, "transport_error");
    assert!(!is_forbidden);

    // MalformedResponse → malformed_response, not forbidden.
    let (category, is_forbidden) = SamplingError::Fake(FakeSamplingError::MalformedResponse {
        message: "x".into(),
    })
    .classify();
    assert_eq!(category, "malformed_response");
    assert!(!is_forbidden);
}
1293
1294 // -------- v0.9.0 P5a (M3 wiring) — SamplingCoordinator integration --------
1295 //
1296 // These tests pin the contract that `build_sampling_steward` wraps the
1297 // live peer in a `SamplingCoordinator` before handing it to
1298 // `SamplingLlmClient`. They cannot call `build_sampling_steward`
1299 // directly (it takes a real `Peer<RoleServer>` whose constructors are
1300 // private inside rmcp), but they exercise the **exact same wiring
1301 // shape** by substituting `FakeMcpClient` for `PeerSamplingClient`.
1302 // The production code path is:
1303 //
1304 // PeerSamplingClient -> SamplingCoordinator -> SamplingLlmClient
1305 //
1306 // The tested shape is:
1307 //
1308 // FakeMcpClient -> SamplingCoordinator -> SamplingLlmClient
1309 //
1310 // Only the leaf `SamplingClient` impl differs; the
1311 // `SamplingClient` trait is the same Arc-of-dyn in both paths.
1312
1313 /// SamplingCoordinator wrapping a `FakeMcpClient` and feeding
1314 /// `SamplingLlmClient::with_sampling_client` is the same Arc-of-dyn
1315 /// shape `build_sampling_steward` constructs at MCP-initialize
1316 /// time. Single-element flushes pass through unwrapped, so a lone
1317 /// `complete()` call still emits one audit row and produces the
1318 /// expected text.
1319 #[tokio::test]
1320 async fn sampling_llm_client_uses_coordinator_in_production_path() {
1321 let h = harness().await;
1322 let fake: Arc<dyn SamplingClient> = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1323 let coord: Arc<dyn SamplingClient> = super::super::SamplingCoordinator::with_settings(
1324 fake.clone(),
1325 Duration::from_millis(50),
1326 10,
1327 );
1328 let client = SamplingLlmClient::with_sampling_client(
1329 coord,
1330 h.write_handle.clone(),
1331 Some("alice".into()),
1332 );
1333 let result = client.complete(&[Message::user("test")]).await.expect("ok");
1334 assert_eq!(result.role, Role::Assistant);
1335 assert_eq!(result.content, "ok");
1336 // Single audit row landed — per-call audit semantics
1337 // unchanged by the coordinator wrap.
1338 assert_eq!(
1339 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1340 1,
1341 "one logical call → one audit row, even through coordinator"
1342 );
1343 }
1344
1345 /// End-to-end batching pin: N concurrent `complete()` calls within
1346 /// the coalesce window resolve as ONE inner `create_message` RPC
1347 /// on the underlying `FakeMcpClient`, but N audit rows still land
1348 /// (one per logical call — the privacy + audit invariants from P2
1349 /// hold).
1350 ///
1351 /// This is the v0.9.0 release notes' "⌈N/M⌉ peer.create_message
1352 /// calls per coalesce window" claim, exercised through the same
1353 /// trait-object chain that `build_sampling_steward` constructs in
1354 /// production.
1355 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1356 async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1357 // Coalesced JSON response for 5 tasks — matches the
1358 // `[{task_index, response}]` shape `flush_batch` demuxes
1359 // multi-element batches into.
1360 let response = serde_json::to_string(
1361 &(0..5)
1362 .map(|i| {
1363 serde_json::json!({
1364 "task_index": i,
1365 "response": format!("response-{i}"),
1366 })
1367 })
1368 .collect::<Vec<_>>(),
1369 )
1370 .unwrap();
1371
1372 let h = harness().await;
1373 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1374 let coord: Arc<dyn SamplingClient> = super::super::SamplingCoordinator::with_settings(
1375 fake.clone(),
1376 // Wide window so all 5 submissions land in one batch.
1377 Duration::from_secs(5),
1378 10,
1379 );
1380 let client = SamplingLlmClient::with_sampling_client(
1381 coord,
1382 h.write_handle.clone(),
1383 Some("alice".into()),
1384 );
1385
1386 // Fire 5 concurrent `complete()` calls; the coordinator should
1387 // coalesce them into ONE `FakeMcpClient::create_message` call.
1388 let mut futs = Vec::new();
1389 for i in 0..5 {
1390 let c = client.clone();
1391 futs.push(tokio::spawn(async move {
1392 c.complete(&[Message::user(format!("task-{i}"))]).await
1393 }));
1394 }
1395 for f in futs {
1396 f.await.expect("join").expect("ok");
1397 }
1398
1399 // EXACTLY one inner RPC.
1400 assert_eq!(
1401 fake.record_requests().len(),
1402 1,
1403 "5 logical calls within window must coalesce to 1 inner RPC"
1404 );
1405 // BUT 5 audit rows — per-logical-call audit invariant preserved.
1406 assert_eq!(
1407 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1408 5,
1409 "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1410 );
1411 }
1412
1413 /// Edge case: `coalesce_max_requests = 1` reduces the coordinator
1414 /// to pass-through (each submit flushes a 1-element batch
1415 /// immediately). With max_batch=1 and a wide window, 3 concurrent
1416 /// calls land 3 inner RPCs — coordinator is operating as if no
1417 /// batching were configured.
1418 ///
1419 /// Pins the brief's documented edge-case: zero / one-valued config
1420 /// reduces to pass-through, never panics or deadlocks. Mirrors
1421 /// `SamplingCoordinator::with_settings`'s `max_batch.max(1)`
1422 /// clamping for the `coalesce_max_requests = 0` case.
1423 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1424 async fn coordinator_max_batch_one_acts_as_passthrough() {
1425 let h = harness().await;
1426 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1427 let coord: Arc<dyn SamplingClient> = super::super::SamplingCoordinator::with_settings(
1428 fake.clone(),
1429 Duration::from_secs(5),
1430 // max_batch=1 → every submission flushes immediately as
1431 // a 1-element batch; pass-through behaviour.
1432 1,
1433 );
1434 let client = SamplingLlmClient::with_sampling_client(coord, h.write_handle.clone(), None);
1435 let mut futs = Vec::new();
1436 for _ in 0..3 {
1437 let c = client.clone();
1438 futs.push(tokio::spawn(async move {
1439 c.complete(&[Message::user("hi")]).await
1440 }));
1441 }
1442 for f in futs {
1443 f.await.expect("join").expect("ok");
1444 }
1445 // 3 logical calls → 3 inner RPCs (no coalescing).
1446 assert_eq!(
1447 fake.record_requests().len(),
1448 3,
1449 "max_batch=1 must pass through every submission as its own RPC"
1450 );
1451 assert_eq!(count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"), 3);
1452 }
1453}