solo_api/llm/sampling.rs
1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingLlmClient`] — `LlmClient` impl backed by an MCP client's
4//! `sampling/createMessage` capability.
5//!
6//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
7//! §6 "Sampling-backed LLM client" / MAJOR 1 + MAJOR 3 resolutions):
8//!
9//! * Steward holds an `Arc<dyn LlmClient>`. When `LlmConfig::McpSampling`
10//! is configured, the Steward's `LlmClient` is a `SamplingLlmClient`
11//! constructed at MCP `initialize` time (when the live peer becomes
12//! available — the `TenantHandle::steward_slot` LATE-population path).
13//!
14//! * `SamplingLlmClient::complete()` translates the workspace's
15//! `Message` → `rmcp::SamplingMessage`, calls
16//! `peer.create_message(params).await`, extracts the assistant's
17//! text from the returned `CreateMessageResult`, and emits a
18//! per-call `AuditOperation::LlmSamplingCall` row through the
19//! tenant's `WriteHandle` (lesson #30: sync in writer-actor tx
20//! for ACID).
21//!
//! * **Privacy invariant**: the audit `details_json` carries metadata
//!   only — model hint, message count, max_tokens, duration_ms, the
//!   model name reported by the client (success only), and the prompt /
//!   output sizes (character counts and token estimates, each rounded up
//!   to the next power of two). **The raw prompt content MUST NOT appear
//!   in the audit row**. Pinned by
//!   [`tests::audit_row_omits_raw_prompt_text`].
27//!
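//! * Error paths land structured audit rows:
//!   - Client refusal → `result = "forbidden"`,
//!     `details_json.reason = "client_refused"`.
//!   - Timeout → `result = "error"`,
//!     `details_json.reason = "timeout"`.
//!   - Other transport / malformed-response → `result = "error"`,
//!     `details_json.reason = <category>` (`transport_error` or
//!     `malformed_response`).
//!
//!   Note: `SamplingError::classify` currently maps every rmcp
//!   `ServiceError` to `transport_error`, so `client_refused` is only
//!   produced by the test fixture today.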
35//!
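//! * `SamplingLlmClient` itself is strictly per-call: one `complete()`
//!   issues one `create_message` and does no rate limiting. Request
//!   coalescing lives in `SamplingCoordinator`, which
//!   [`build_sampling_steward`] wraps around the live peer (v0.9.0 P5).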
38
39use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44 CreateMessageRequestParams, CreateMessageResult, ModelHint,
45 ModelPreferences, Role as RmcpRole, SamplingMessage,
46 SamplingMessageContent,
47};
48use rmcp::service::{Peer, RoleServer, ServiceError};
49use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
50use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
51
/// Upper bound on how long a single `sampling/createMessage` round-trip
/// may take before [`SamplingLlmClient`] gives up and reports a structured
/// `timeout` error (instead of hanging if the MCP client stalls).
///
/// Rationale: 30 s sits inside the consolidate-timer's cadence margins — a
/// slower LLM call would already starve the Steward batch in P4's
/// coordinator. Override per client with
/// [`SamplingLlmClient::with_timeout`].
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_millis(30_000);

/// `max_tokens` sent on each sampling request unless overridden.
///
/// Kept equal to `solo-steward::StewardConfig::default().abstraction_max_tokens`
/// so the request shape matches what the Steward would ask of any other
/// backend.
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
67
68/// Error surface for [`SamplingClient::create_message`]. Combines the
69/// real rmcp `ServiceError` (when wrapping a live `Peer<RoleServer>`)
70/// with [`super::super::test_support::fake_mcp_client::FakeSamplingError`]
71/// (when driving the fixture from tests).
72#[derive(Debug)]
73pub enum SamplingError {
74 /// Forwarded from `rmcp::Peer::create_message`.
75 Service(ServiceError),
76 /// Routed from [`super::super::test_support::fake_mcp_client::
77 /// FakeSamplingError`] in test paths.
78 #[cfg(any(test, feature = "test-support"))]
79 Fake(crate::test_support::FakeSamplingError),
80}
81
82impl std::fmt::Display for SamplingError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 Self::Service(e) => write!(f, "{e}"),
86 #[cfg(any(test, feature = "test-support"))]
87 Self::Fake(e) => write!(f, "{e}"),
88 }
89 }
90}
91
92impl std::error::Error for SamplingError {}
93
94impl SamplingError {
95 /// Classifier used by [`SamplingLlmClient::complete`] to map the
96 /// transport-level error to an audit-row category + a Solo
97 /// [`CoreError`] variant.
98 ///
99 /// `(category_for_audit, treat_as_forbidden)` — `forbidden` becomes
100 /// `AuditResult::Forbidden` + `CoreError::Forbidden`; everything
101 /// else is `AuditResult::Error` + `CoreError::Llm`.
102 pub fn classify(&self) -> (&'static str, bool) {
103 match self {
104 Self::Service(_) => ("transport_error", false),
105 #[cfg(any(test, feature = "test-support"))]
106 Self::Fake(e) => match e {
107 crate::test_support::FakeSamplingError::Refused { .. } => {
108 ("client_refused", true)
109 }
110 crate::test_support::FakeSamplingError::Transport { .. } => {
111 ("transport_error", false)
112 }
113 crate::test_support::FakeSamplingError::MalformedResponse {
114 ..
115 } => ("malformed_response", false),
116 },
117 }
118 }
119}
120
121/// Trait abstracting the `sampling/createMessage` RPC. The production
122/// impl wraps `Arc<Peer<RoleServer>>`; the test impl is
123/// [`super::super::test_support::fake_mcp_client::FakeMcpClient`].
124///
125/// Separating the trait from the concrete `Peer<RoleServer>` is the
126/// way around rmcp's `Peer` having private constructors — we can't
127/// build a fake `Peer` for tests, so we inject behind a trait.
128#[async_trait]
129pub trait SamplingClient: Send + Sync {
130 async fn create_message(
131 &self,
132 params: CreateMessageRequestParams,
133 ) -> Result<CreateMessageResult, SamplingError>;
134}
135
136/// Production wrapper around `rmcp::Peer<RoleServer>`. The Peer is
137/// cheap to clone (internally `Arc`-backed) and stays valid for the
138/// lifetime of the MCP session.
139pub struct PeerSamplingClient {
140 peer: Peer<RoleServer>,
141}
142
143impl PeerSamplingClient {
144 pub fn new(peer: Peer<RoleServer>) -> Self {
145 Self { peer }
146 }
147}
148
149#[async_trait]
150impl SamplingClient for PeerSamplingClient {
151 async fn create_message(
152 &self,
153 params: CreateMessageRequestParams,
154 ) -> Result<CreateMessageResult, SamplingError> {
155 self.peer
156 .create_message(params)
157 .await
158 .map_err(SamplingError::Service)
159 }
160}
161
162/// `LlmClient` impl whose `complete()` calls back via the connected
163/// MCP client's sampling capability.
164///
165/// Construct via [`SamplingLlmClient::new`] (production path: wraps a
166/// real `Peer<RoleServer>`) or [`SamplingLlmClient::with_sampling_client`]
167/// (test path: takes the abstracted [`SamplingClient`] trait object
168/// directly so [`super::super::test_support::fake_mcp_client::
169/// FakeMcpClient`] can drive it).
170///
171/// Cheap to clone — every field is `Arc`-shared.
172#[derive(Clone)]
173pub struct SamplingLlmClient {
174 /// The RPC channel back to the MCP client.
175 sampling_client: Arc<dyn SamplingClient>,
176 /// Per-tenant `WriteHandle` for the synchronous audit emit. Routes
177 /// through the writer-actor's mpsc so the
178 /// `AuditOperation::LlmSamplingCall` INSERT lands in a dedicated
179 /// `BEGIN IMMEDIATE` transaction on the writer's connection.
180 write_handle: WriteHandle,
181 /// Cached audit `principal_subject` for the MCP session. Resolved
182 /// at session init time (see `mcp::resolve_mcp_principal`); `None`
183 /// for unauthenticated stdio sessions.
184 audit_principal: Option<String>,
185 /// `max_tokens` value sent on every `sampling/createMessage`.
186 /// Defaults to [`DEFAULT_SAMPLING_MAX_TOKENS`]; configurable via
187 /// [`Self::with_max_tokens`].
188 max_tokens: u32,
189 /// Bounded wait on `create_message`. See
190 /// [`DEFAULT_SAMPLING_TIMEOUT`].
191 timeout: Duration,
192}
193
194impl SamplingLlmClient {
195 /// Build a client wrapping a real `Peer<RoleServer>`. Production
196 /// path — called from
197 /// [`crate::mcp::SoloMcpServer::populate_sampling_steward`] when an MCP
198 /// session reaches `initialize` with a sampling-capable peer.
199 pub fn new(
200 peer: Peer<RoleServer>,
201 write_handle: WriteHandle,
202 audit_principal: Option<String>,
203 ) -> Self {
204 Self::with_sampling_client(
205 Arc::new(PeerSamplingClient::new(peer)),
206 write_handle,
207 audit_principal,
208 )
209 }
210
211 /// Test-friendly constructor accepting any [`SamplingClient`]
212 /// implementation. Pair with
213 /// [`super::super::test_support::fake_mcp_client::FakeMcpClient`]
214 /// in tests.
215 pub fn with_sampling_client(
216 sampling_client: Arc<dyn SamplingClient>,
217 write_handle: WriteHandle,
218 audit_principal: Option<String>,
219 ) -> Self {
220 Self {
221 sampling_client,
222 write_handle,
223 audit_principal,
224 max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
225 timeout: DEFAULT_SAMPLING_TIMEOUT,
226 }
227 }
228
229 /// Override the per-call `max_tokens` cap.
230 pub fn with_max_tokens(mut self, n: u32) -> Self {
231 self.max_tokens = n.max(1);
232 self
233 }
234
235 /// Override the per-call timeout.
236 pub fn with_timeout(mut self, t: Duration) -> Self {
237 self.timeout = t;
238 self
239 }
240
241 /// Build the `CreateMessageRequestParams` from Solo's `Message`
242 /// vec. Splits out `Role::System` into the `system_prompt` field
243 /// (rmcp's `SamplingMessage::role` is only User / Assistant) and
244 /// hints the user's MCP client toward a Claude-class model.
245 fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
246 // Split system messages out of the conversation history; the
247 // sampling protocol carries the system prompt as a top-level
248 // field rather than inline.
249 let mut system_parts: Vec<String> = Vec::new();
250 let mut samp_messages: Vec<SamplingMessage> = Vec::new();
251 for m in messages {
252 match m.role {
253 Role::System => system_parts.push(m.content.clone()),
254 Role::User => {
255 samp_messages.push(SamplingMessage::user_text(&m.content));
256 }
257 Role::Assistant => {
258 samp_messages
259 .push(SamplingMessage::assistant_text(&m.content));
260 }
261 }
262 }
263 // rmcp 1.7's struct literals are non-exhaustive across crate
264 // boundaries; build via the typed constructors + builders.
265 let preferences = ModelPreferences::new()
266 .with_hints(vec![ModelHint::new("claude")])
267 .with_intelligence_priority(0.7)
268 .with_speed_priority(0.3)
269 .with_cost_priority(0.4);
270 let mut params =
271 CreateMessageRequestParams::new(samp_messages, self.max_tokens)
272 .with_model_preferences(preferences);
273 if !system_parts.is_empty() {
274 params = params.with_system_prompt(system_parts.join("\n\n"));
275 }
276 params
277 }
278
279 /// Build the audit `AuditEvent` carrying ONLY metadata. No raw
280 /// prompt content lands in `details_json`.
281 ///
282 /// Pinned by [`tests::audit_row_omits_raw_prompt_text`].
283 ///
284 /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the raw character count
285 /// of the prompt is itself a side-channel — a 6-char prompt
286 /// uniquely identifies very-short refusal paths (e.g. a leaked
287 /// password length). `prompt_chars` and `input_tokens_est` are
288 /// rounded up to the next power of two before persistence. This
289 /// preserves operator capacity-planning (the bucket is within ~2x
290 /// of the real size for any sufficiently large prompt) while
291 /// removing the per-character precision.
292 ///
293 /// Buckets: `0, 1, 2, 4, 8, 16, 32, 64, ..., 1024, 2048, ...`
294 /// (next power of two `>= n`). 0 stays 0.
295 fn audit_event(
296 &self,
297 params: &CreateMessageRequestParams,
298 outcome: SamplingOutcome,
299 ) -> AuditEvent {
300 let raw_prompt_chars: usize = params
301 .messages
302 .iter()
303 .flat_map(|m| m.content.iter())
304 .filter_map(|c| c.as_text().map(|t| t.text.len()))
305 .sum::<usize>()
306 + params
307 .system_prompt
308 .as_ref()
309 .map(|s| s.len())
310 .unwrap_or(0);
311 // v0.9.1 P1 Fix 4: bucket the raw count to the next power of
312 // two. Pinned by `tests::audit_row_bucket_prompt_chars_to_pow2`.
313 let prompt_chars = next_pow2_bucket(raw_prompt_chars);
314 // ~4 chars per token for the rough English-text estimate used
315 // by `solo doctor --check-llm` and Anthropic's docs. Recorded
316 // for operator capacity-planning. Bucketed for the same
317 // privacy reason — and to stay consistent with `prompt_chars`.
318 let input_tokens_est = next_pow2_bucket(raw_prompt_chars / 4) as u64;
319 let model_hint = params
320 .model_preferences
321 .as_ref()
322 .and_then(|p| p.hints.as_ref())
323 .and_then(|h| h.first())
324 .and_then(|h| h.name.clone())
325 .unwrap_or_else(|| "(none)".to_string());
326
327 let mut details = serde_json::Map::new();
328 details.insert(
329 "model_hint".to_string(),
330 serde_json::Value::String(model_hint),
331 );
332 details.insert(
333 "messages_count".to_string(),
334 serde_json::Value::Number(params.messages.len().into()),
335 );
336 details.insert(
337 "max_tokens".to_string(),
338 serde_json::Value::Number(params.max_tokens.into()),
339 );
340 details.insert(
341 "prompt_chars".to_string(),
342 serde_json::Value::Number(prompt_chars.into()),
343 );
344 details.insert(
345 "input_tokens_est".to_string(),
346 serde_json::Value::Number(input_tokens_est.into()),
347 );
348
349 let result = match &outcome {
350 SamplingOutcome::Ok {
351 duration_ms,
352 model,
353 output_chars,
354 } => {
355 // v0.9.1 P1 Fix 4: same power-of-2 bucketing as
356 // `prompt_chars` for the output side. A model that
357 // always replies with a one-token refusal (e.g. an
358 // assistant trained to say "no.") would otherwise leak
359 // the response-length distribution; bucketing
360 // collapses 1-2-3-4 chars all into bucket 4.
361 let bucketed_output_chars = next_pow2_bucket(*output_chars);
362 let output_tokens_est = next_pow2_bucket(*output_chars / 4) as u64;
363 details.insert(
364 "duration_ms".to_string(),
365 serde_json::Value::Number((*duration_ms).into()),
366 );
367 details.insert(
368 "model".to_string(),
369 serde_json::Value::String(model.clone()),
370 );
371 details.insert(
372 "output_chars".to_string(),
373 serde_json::Value::Number(bucketed_output_chars.into()),
374 );
375 details.insert(
376 "output_tokens_est".to_string(),
377 serde_json::Value::Number(output_tokens_est.into()),
378 );
379 AuditResult::Ok
380 }
381 SamplingOutcome::Forbidden {
382 reason,
383 duration_ms,
384 } => {
385 details.insert(
386 "duration_ms".to_string(),
387 serde_json::Value::Number((*duration_ms).into()),
388 );
389 details.insert(
390 "reason".to_string(),
391 serde_json::Value::String(reason.to_string()),
392 );
393 AuditResult::Forbidden
394 }
395 SamplingOutcome::Error {
396 reason,
397 duration_ms,
398 } => {
399 details.insert(
400 "duration_ms".to_string(),
401 serde_json::Value::Number((*duration_ms).into()),
402 );
403 details.insert(
404 "reason".to_string(),
405 serde_json::Value::String(reason.to_string()),
406 );
407 AuditResult::Error
408 }
409 };
410
411 AuditEvent {
412 ts_ms: chrono::Utc::now().timestamp_millis(),
413 principal_subject: self.audit_principal.clone(),
414 operation: AuditOperation::LlmSamplingCall,
415 target_id: None,
416 result,
417 details: Some(serde_json::Value::Object(details)),
418 }
419 }
420}
421
/// Result category of one sampling call.
///
/// `SamplingLlmClient::audit_event` uses it to choose the audit `result`
/// and the outcome-specific keys written to `details_json`.
enum SamplingOutcome {
    /// The client answered with usable assistant text (`AuditResult::Ok`).
    Ok {
        /// Model name reported by the client in its response.
        model: String,
        /// Byte length (`str::len`) of the extracted text, before bucketing.
        output_chars: usize,
        /// Milliseconds spent waiting on the RPC.
        duration_ms: u64,
    },
    /// The error classified as forbidden (`AuditResult::Forbidden`).
    Forbidden {
        /// Milliseconds spent waiting on the RPC.
        duration_ms: u64,
        /// Audit category persisted as `details_json.reason`.
        reason: &'static str,
    },
    /// Any other failure: transport error, malformed response or timeout
    /// (`AuditResult::Error`).
    Error {
        /// Milliseconds spent waiting on the RPC.
        duration_ms: u64,
        /// Audit category persisted as `details_json.reason`.
        reason: &'static str,
    },
}
438
439#[async_trait]
440impl LlmClient for SamplingLlmClient {
441 fn name(&self) -> &str {
442 "mcp-sampling"
443 }
444
445 async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
446 let params = self.build_request(messages);
447 let start = Instant::now();
448
449 // Bounded wait on `peer.create_message`. The fold of (rmcp
450 // ServiceError | FakeError | tokio timeout) into the
451 // `Outcome` enum keeps the audit path single-sourced.
452 let rpc = tokio::time::timeout(
453 self.timeout,
454 self.sampling_client.create_message(params.clone()),
455 )
456 .await;
457 let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX))
458 as u64;
459
460 let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) =
461 match rpc {
462 Ok(Ok(result)) => {
463 match extract_text(&result) {
464 Ok(text) => {
465 let output_chars = text.len();
466 let outcome = SamplingOutcome::Ok {
467 duration_ms,
468 model: result.model.clone(),
469 output_chars,
470 };
471 (Ok(Message::assistant(text)), outcome)
472 }
473 Err(reason) => (
474 Err(CoreError::llm(format!(
475 "mcp sampling: malformed response: {reason}"
476 ))),
477 SamplingOutcome::Error {
478 reason: "malformed_response",
479 duration_ms,
480 },
481 ),
482 }
483 }
484 Ok(Err(e)) => {
485 let (category, is_forbidden) = e.classify();
486 let outcome = if is_forbidden {
487 SamplingOutcome::Forbidden {
488 reason: category,
489 duration_ms,
490 }
491 } else {
492 SamplingOutcome::Error {
493 reason: category,
494 duration_ms,
495 }
496 };
497 let err = if is_forbidden {
498 CoreError::forbidden(format!("mcp sampling: {e}"))
499 } else {
500 CoreError::llm(format!("mcp sampling: {e}"))
501 };
502 (Err(err), outcome)
503 }
504 Err(_elapsed) => (
505 Err(CoreError::llm(format!(
506 "mcp sampling: timeout after {}ms",
507 duration_ms
508 ))),
509 SamplingOutcome::Error {
510 reason: "timeout",
511 duration_ms,
512 },
513 ),
514 };
515
516 // Synchronous audit emit through the writer-actor (lesson #30).
517 // Failure to land the audit row is operator-visible: the
518 // sampling call's caller sees the storage error and can decide
519 // whether to abort (we DO abort here — without the audit row
520 // we have no record of the call).
521 //
522 // v0.9.1 P1 Fix 3 (F4 Result-shadowing): when BOTH `core_result`
523 // is `Err(..)` AND the audit emit also fails, return the
524 // ORIGINAL LLM-side error (more actionable for callers — they
525 // can retry the LLM call, or decide whether the upstream
526 // refusal/timeout is recoverable). Surface the audit failure
527 // via `tracing::error!` for operator visibility — operators
528 // alarming on storage errors see it; callers see the actionable
529 // error.
530 //
531 // Policy summary:
532 // * RPC Ok + audit Ok → return Ok(text)
533 // * RPC Ok + audit Err → return Err(storage) [audit failure
534 // wins — no undocumented sampling
535 // calls per lesson #30]
536 // * RPC Err + audit Ok → return Err(llm/forbidden) [unchanged]
537 // * RPC Err + audit Err → return Err(llm/forbidden) AND log
538 // audit failure at error level
539 // [v0.9.1 P1 Fix 3]
540 let event = self.audit_event(¶ms, outcome);
541 match (
542 core_result,
543 self.write_handle.emit_llm_sampling_audit(event).await,
544 ) {
545 (Ok(text), Ok(())) => Ok(text),
546 (Ok(_text), Err(audit_err)) => {
547 // RPC succeeded but the audit row didn't land. Drop
548 // the success — without a durable audit row we can't
549 // honor the "every sampling call leaves a trace"
550 // invariant.
551 Err(CoreError::storage(format!(
552 "mcp sampling: audit emit failed: {audit_err}"
553 )))
554 }
555 (Err(core_err), Ok(())) => Err(core_err),
556 (Err(core_err), Err(audit_err)) => {
557 // Both failed. Return the LLM-side error (the caller's
558 // most actionable signal); log the audit failure so an
559 // operator who alarms on storage errors still sees it.
560 tracing::error!(
561 audit_error = %audit_err,
562 core_error = %core_err,
563 "mcp sampling: audit emit failed alongside core \
564 error; surfacing core error to caller"
565 );
566 Err(core_err)
567 }
568 }
569 }
570}
571
/// Round `n` up to the next power of two, for the size fields of the
/// `LlmSamplingCall` audit row's `details_json` (v0.9.1 P1 Fix 4 "F6":
/// exact `prompt_chars` was a privacy side-channel for short prompts).
///
/// Buckets: `0 → 0`, `1 → 1`, `2 → 2`, `3 → 4`, `4 → 4`, `5..=8 → 8`,
/// `9..=16 → 16`, `17..=32 → 32`, ... All values in a bucket persist as
/// the same number. The worst-case fidelity loss is just under 2x (e.g. 9
/// persists as 16), which is still precise enough for capacity planning.
///
/// Inputs above the largest representable power of two
/// (`1 << (usize::BITS - 1)`) saturate to `usize::MAX`. Plain
/// `usize::next_power_of_two` would instead panic in debug builds and wrap
/// to `0` in release builds, silently filing an enormous prompt under the
/// "empty" bucket.
///
/// Pinned by `tests::next_pow2_bucket_table` and
/// [`tests::audit_row_bucket_prompt_chars_to_pow2`].
fn next_pow2_bucket(n: usize) -> usize {
    // `checked_next_power_of_two(0)` is `Some(1)`; keep 0 in its own bucket.
    if n == 0 {
        return 0;
    }
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
596
597/// Pull the assistant's text out of the rmcp result. Walks every text
598/// content block in the message (the spec allows either a single
599/// `SamplingContent::Single` or a `SamplingContent::Multiple`) and
600/// concatenates them with newlines. Returns `Err(reason)` if no text
601/// blocks were present — the malformed-response path.
602fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
603 if result.message.role != RmcpRole::Assistant {
604 return Err("response role was not Assistant");
605 }
606 let mut out = String::new();
607 for content in result.message.content.iter() {
608 if let SamplingMessageContent::Text(text) = content {
609 if !out.is_empty() {
610 out.push('\n');
611 }
612 out.push_str(&text.text);
613 }
614 }
615 if out.is_empty() {
616 Err("no text content blocks")
617 } else {
618 Ok(out)
619 }
620}
621
622/// v0.9.0 P2: build a sampling-backed `Arc<Steward>` for a tenant that
623/// has resolved `LlmConfig::McpSampling` and just attached an MCP
624/// session.
625///
626/// Called from [`crate::mcp::SoloMcpServer::populate_sampling_steward`] at
627/// MCP `initialize` time once the peer's sampling capability is
628/// confirmed. The returned `Arc<Steward>` is written into
629/// `tenant.steward_slot()` so the writer-actor + consolidate timer
630/// can read a populated slot on their next tick.
631///
632/// v0.9.0 P5 (M3 wiring): the live `PeerSamplingClient` is now wrapped
633/// in a [`super::SamplingCoordinator`] before being handed to
634/// `SamplingLlmClient`. Concurrent `complete()` calls within the
635/// coalesce window collapse into one `peer.create_message` RPC and the
636/// response is demultiplexed back per-task — matching the
637/// `[sampling] coalesce_window_ms` / `coalesce_max_requests` config the
638/// operator wrote in `solo.config.toml`. Per-call audit emit semantics
639/// are unchanged: every logical request still lands one
640/// `AuditOperation::LlmSamplingCall` row, no raw prompt content escapes
641/// to the audit row.
642///
643/// Edge case (clamping): the `[sampling]` block accepts values that
644/// effectively disable batching — `coalesce_max_requests = 1` and / or
645/// `coalesce_window_ms = 0` reduce the coordinator to pass-through (one
646/// inner call per submission). The coordinator's
647/// [`super::SamplingCoordinator::with_settings`] clamps `max_batch` to
648/// `max(1)` so a zero value still produces a single-element flush
649/// immediately rather than panicking or deadlocking.
650pub fn build_sampling_steward(
651 peer: Peer<RoleServer>,
652 write_handle: WriteHandle,
653 audit_principal: Option<String>,
654 steward_config: solo_steward::StewardConfig,
655 sampling_config: solo_storage::SamplingConfig,
656) -> Arc<solo_steward::Steward> {
657 let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
658 let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
659 inner,
660 std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
661 sampling_config.coalesce_max_requests as usize,
662 );
663 let client = SamplingLlmClient::with_sampling_client(
664 coordinator,
665 write_handle,
666 audit_principal,
667 )
668 .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
669 Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
670}
671
672#[cfg(test)]
673mod tests {
674 use super::*;
675 use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
676 use rmcp::model::CreateMessageResult;
677 use solo_core::TenantId;
678 use solo_storage::{
679 EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder,
680 TenantHandle, TenantRegistry, TenantRegistryParams, init,
681 open_sqlcipher,
682 };
683 use std::path::PathBuf;
684 use std::sync::Arc;
685 use tempfile::TempDir;
686 use zeroize::Zeroizing;
687
688 const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
689
690 /// Bootstrap a per-tenant `TenantHandle` whose writer-actor accepts
691 /// the new `WriteCommand::EmitLlmSamplingAudit` variant.
692 ///
693 /// Mirrors the v0.8.x test discipline (see
694 /// `crates/solo-storage/src/tenants/handle_registry_tests.rs`'s
695 /// `fresh_init_dir`): build a real tenant DB on disk via the same
696 /// `init()` helper users invoke, wrap in a `TenantRegistry`, and
697 /// surface the `WriteHandle` for direct `SamplingLlmClient`
698 /// wiring.
699 struct Harness {
700 _tmp: TempDir,
701 _registry: Arc<TenantRegistry>,
702 _tenant: Arc<TenantHandle>,
703 write_handle: solo_storage::WriteHandle,
704 db_path: PathBuf,
705 key: KeyMaterial,
706 }
707
708 async fn harness() -> Harness {
709 let tmp = TempDir::new().expect("tempdir");
710 let data_dir = tmp.path().to_path_buf();
711 let _ = init(InitParams {
712 data_dir: data_dir.clone(),
713 passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
714 force: false,
715 embedder: EmbedderConfig {
716 name: "stub".into(),
717 version: "v1".into(),
718 dim: 32,
719 dtype: "f32".into(),
720 },
721 })
722 .expect("init");
723
724 let cfg = solo_storage::SoloConfig::read(
725 &data_dir.join("solo.config.toml"),
726 )
727 .expect("read cfg");
728 let key = KeyMaterial::derive(
729 TEST_PASSPHRASE,
730 &cfg.salt_bytes().expect("salt"),
731 )
732 .expect("derive key");
733
734 let embedder: Arc<dyn solo_core::Embedder> =
735 Arc::new(StubEmbedder::new("stub", "v1", 32));
736 let registry = Arc::new(
737 TenantRegistry::open(TenantRegistryParams {
738 data_dir: data_dir.clone(),
739 key: key.clone(),
740 embedder: embedder.clone(),
741 hnsw_params: HnswParams::default(),
742 steward: None,
743 runtime_handle: Some(tokio::runtime::Handle::current()),
744 steward_factory: None,
745 triples_batch_signal: None,
746 })
747 .expect("open registry"),
748 );
749
750 let tenant_id = TenantId::default_tenant();
751 let tenant = registry
752 .get_or_open(&tenant_id)
753 .await
754 .expect("get_or_open default tenant");
755 let write_handle = tenant.write().clone();
756 let db_path = tenant.db_path().to_path_buf();
757
758 Harness {
759 _tmp: tmp,
760 _registry: registry,
761 _tenant: tenant,
762 write_handle,
763 db_path,
764 key,
765 }
766 }
767
768 /// Helper: count the `audit_events` rows whose `operation` is the
769 /// given string. Opens a fresh connection to the tenant DB so we
770 /// avoid contention with the writer-actor's own connection.
771 fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
772 let conn = open_sqlcipher(db_path, key).expect("open db");
773 conn.query_row(
774 "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
775 rusqlite::params![op],
776 |r| r.get(0),
777 )
778 .expect("count")
779 }
780
781 /// Helper: load the most-recent `llm.sampling_call` audit row and
782 /// return `(result, principal_subject, details_json)`.
783 fn latest_sampling_audit_details(
784 db_path: &std::path::Path,
785 key: &KeyMaterial,
786 ) -> (String, Option<String>, serde_json::Value) {
787 let conn = open_sqlcipher(db_path, key).expect("open db");
788 let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
789 .query_row(
790 "SELECT result, principal_subject, details_json
791 FROM audit_events
792 WHERE operation = 'llm.sampling_call'
793 ORDER BY ts_ms DESC, rowid DESC
794 LIMIT 1",
795 [],
796 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
797 )
798 .expect("query");
799 let details: serde_json::Value =
800 serde_json::from_str(&details_str.expect("details_json present"))
801 .expect("parse details");
802 (result, principal, details)
803 }
804
805 /// Happy path: a successful `create_message` round-trip returns
806 /// the assistant text wrapped in a `Message::assistant`, and lands
807 /// exactly one `llm.sampling_call` audit row with `result = 'ok'`.
808 #[tokio::test]
809 async fn sampling_complete_happy_path_returns_text() {
810 let h = harness().await;
811 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
812 let client = SamplingLlmClient::with_sampling_client(
813 fake.clone(),
814 h.write_handle.clone(),
815 Some("alice".into()),
816 );
817 let messages = vec![Message::user("summarise these episodes")];
818 let result = client.complete(&messages).await.expect("ok");
819 assert_eq!(result.role, Role::Assistant);
820 assert_eq!(result.content, "derived theme");
821
822 // Exactly one audit row landed.
823 assert_eq!(
824 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
825 1
826 );
827 let (result_str, principal, details) =
828 latest_sampling_audit_details(&h.db_path, &h.key);
829 assert_eq!(result_str, "ok");
830 assert_eq!(principal.as_deref(), Some("alice"));
831 assert_eq!(details["model_hint"], "claude");
832 assert_eq!(details["model"], "fake-claude");
833 assert_eq!(details["messages_count"], 1);
834 assert_eq!(details["max_tokens"], 512);
835 }
836
837 /// Privacy invariant: the audit row's `details_json` MUST NOT
838 /// contain the raw prompt content. Pinned by string inspection of
839 /// the persisted JSON.
840 #[tokio::test]
841 async fn audit_row_omits_raw_prompt_text() {
842 let h = harness().await;
843 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
844 let client = SamplingLlmClient::with_sampling_client(
845 fake,
846 h.write_handle.clone(),
847 None,
848 );
849 let secret = "THE-USER-ID-IS-bobby-1234";
850 let messages = vec![
851 Message::system("you are a friendly assistant"),
852 Message::user(secret),
853 ];
854 client.complete(&messages).await.expect("ok");
855
856 let (_, _, details) =
857 latest_sampling_audit_details(&h.db_path, &h.key);
858 let serialised =
859 serde_json::to_string(&details).expect("serialise details");
860 assert!(
861 !serialised.contains(secret),
862 "audit details must not carry raw prompt content; was: {serialised}"
863 );
864 assert!(
865 !serialised.contains("you are a friendly assistant"),
866 "audit details must not carry system prompt; was: {serialised}"
867 );
868 // Metadata IS present, even though the prompt is not.
869 assert_eq!(details["messages_count"], 1);
870 assert!(details["prompt_chars"].as_u64().unwrap() > 0);
871 }
872
873 /// v0.9.1 P1 Fix 4 (F6 privacy bucketing): the audit row's
874 /// `prompt_chars` MUST be the power-of-2 bucket, never the raw
875 /// character count. Pins the bucketing behavior end-to-end (raw
876 /// `audit_event` → SQLite → re-read).
877 ///
878 /// Test recipe: drive a prompt with a known raw length (6 chars
879 /// total, `"hello "` system + `"x"` user → 6+1 = 7) and assert the
880 /// audit row carries `8` (next pow2 ≥ 7), not 7.
881 #[tokio::test]
882 async fn audit_row_bucket_prompt_chars_to_pow2() {
883 let h = harness().await;
884 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
885 let client = SamplingLlmClient::with_sampling_client(
886 fake,
887 h.write_handle.clone(),
888 None,
889 );
890 // System: 6 chars + user: 1 char = 7 chars raw → bucket 8.
891 client
892 .complete(&[Message::system("hello "), Message::user("x")])
893 .await
894 .expect("ok");
895 let (_, _, details) =
896 latest_sampling_audit_details(&h.db_path, &h.key);
897 assert_eq!(
898 details["prompt_chars"].as_u64().unwrap(),
899 8,
900 "prompt_chars must be bucketed to next pow2 (7 → 8). \
901 raw count is a privacy side-channel; see Fix 4 F6 in \
902 v0.9.1 P1 dev log. got details={details}"
903 );
904 }
905
906 /// Stability invariant: two prompts that fall in the SAME bucket
907 /// must persist identical `prompt_chars`. Distinguishes "the
908 /// implementation buckets" from "the implementation hashes/leaks
909 /// raw values".
910 ///
911 /// 5 chars and 7 chars both round to 8 → must persist identically.
912 /// (Mirrors the brief's "test that bucketed values are stable
913 /// across exact-character variations within the same bucket".)
914 #[tokio::test]
915 async fn audit_row_bucket_prompt_chars_is_stable_within_bucket() {
916 let h = harness().await;
917 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
918 let client = SamplingLlmClient::with_sampling_client(
919 fake,
920 h.write_handle.clone(),
921 None,
922 );
923 // 5 chars raw → bucket 8.
924 client
925 .complete(&[Message::user("hello")])
926 .await
927 .expect("ok");
928 let (_, _, details_5) =
929 latest_sampling_audit_details(&h.db_path, &h.key);
930 // 7 chars raw → bucket 8.
931 client
932 .complete(&[Message::user("hellooo")])
933 .await
934 .expect("ok");
935 let (_, _, details_7) =
936 latest_sampling_audit_details(&h.db_path, &h.key);
937 assert_eq!(
938 details_5["prompt_chars"], details_7["prompt_chars"],
939 "5 chars and 7 chars must hash to the same bucket (8) — \
940 otherwise the bucketing is leaking raw fidelity. \
941 5-char details: {details_5}, 7-char details: {details_7}"
942 );
943 assert_eq!(details_5["prompt_chars"].as_u64().unwrap(), 8);
944 }
945
/// Table of expected outputs for the bucketing helper. If someone
/// "simplifies" `next_pow2_bucket` into an identity/no-op, the
/// round-up rows (3, 5, 6, 7, 9, 1023, 1025) fail immediately.
#[test]
fn next_pow2_bucket_table() {
    // Small inputs, each with an explanatory failure message.
    let annotated = [
        (0, 0, "0 stays 0"),
        (1, 1, "1 stays 1"),
        (2, 2, "2 stays 2"),
        (3, 4, "3 rounds up to 4"),
        (4, 4, "4 stays 4"),
    ];
    for (input, expected, why) in annotated {
        assert_eq!(next_pow2_bucket(input), expected, "{why}");
    }

    assert_eq!(next_pow2_bucket(5), 8);
    assert_eq!(next_pow2_bucket(6), 8, "6-char prompt (brief case) → 8");

    // Exact powers stay put; one past a power rounds to the next one.
    let plain = [
        (7, 8),
        (8, 8),
        (9, 16),
        (1023, 1024),
        (1024, 1024),
        (1025, 2048),
    ];
    for (input, expected) in plain {
        assert_eq!(next_pow2_bucket(input), expected);
    }
}
964
965 /// Client refusal: maps to `CoreError::Forbidden` + audit row
966 /// `result = 'forbidden'` + `details_json.reason = 'client_refused'`.
967 #[tokio::test]
968 async fn client_refusal_returns_forbidden_and_audits() {
969 let h = harness().await;
970 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
971 fake.reject_with("user dismissed approval");
972 let client = SamplingLlmClient::with_sampling_client(
973 fake,
974 h.write_handle.clone(),
975 Some("alice".into()),
976 );
977 let err = client
978 .complete(&[Message::user("anything")])
979 .await
980 .unwrap_err();
981 match err {
982 CoreError::Forbidden(_) => {}
983 other => panic!("expected Forbidden, got {other:?}"),
984 }
985 let (result_str, _, details) =
986 latest_sampling_audit_details(&h.db_path, &h.key);
987 assert_eq!(result_str, "forbidden");
988 assert_eq!(details["reason"], "client_refused");
989 }
990
991 /// Timeout: tokio::time::timeout fires before the fake's `Slow`
992 /// response resolves; client returns `CoreError::Llm` + audit row
993 /// `result = 'error'` + `details_json.reason = 'timeout'`.
994 ///
995 /// Real wall-clock: 80ms slow response vs 30ms client timeout.
996 /// Margin is loose enough for slow CI without making the test
997 /// drag.
998 #[tokio::test]
999 async fn timeout_returns_error_with_timeout_reason() {
1000 let h = harness().await;
1001 let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
1002 "late",
1003 Duration::from_millis(800),
1004 )));
1005 let client = SamplingLlmClient::with_sampling_client(
1006 fake,
1007 h.write_handle.clone(),
1008 None,
1009 )
1010 .with_timeout(Duration::from_millis(30));
1011 let err = client
1012 .complete(&[Message::user("hello")])
1013 .await
1014 .unwrap_err();
1015 match err {
1016 CoreError::Llm(msg) => assert!(msg.contains("timeout")),
1017 other => panic!("expected Llm, got {other:?}"),
1018 }
1019 let (result_str, _, details) =
1020 latest_sampling_audit_details(&h.db_path, &h.key);
1021 assert_eq!(result_str, "error");
1022 assert_eq!(details["reason"], "timeout");
1023 }
1024
1025 /// Malformed response: the fake returns a result with zero text
1026 /// content blocks; client surfaces `CoreError::Llm` + audit row
1027 /// `result = 'error'` + `details_json.reason = 'malformed_response'`.
1028 #[tokio::test]
1029 async fn malformed_response_returns_error_with_reason() {
1030 let h = harness().await;
1031 let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
1032 let client = SamplingLlmClient::with_sampling_client(
1033 fake,
1034 h.write_handle.clone(),
1035 None,
1036 );
1037 let err = client
1038 .complete(&[Message::user("hi")])
1039 .await
1040 .unwrap_err();
1041 assert!(matches!(err, CoreError::Llm(_)));
1042 let (result_str, _, details) =
1043 latest_sampling_audit_details(&h.db_path, &h.key);
1044 assert_eq!(result_str, "error");
1045 assert_eq!(details["reason"], "malformed_response");
1046 }
1047
1048 /// `principal_subject = None` works — audit row still emits with
1049 /// NULL.
1050 #[tokio::test]
1051 async fn no_principal_emits_audit_with_null_principal() {
1052 let h = harness().await;
1053 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1054 let client = SamplingLlmClient::with_sampling_client(
1055 fake,
1056 h.write_handle.clone(),
1057 None,
1058 );
1059 client.complete(&[Message::user("hi")]).await.expect("ok");
1060 let (_, principal, _) =
1061 latest_sampling_audit_details(&h.db_path, &h.key);
1062 assert_eq!(principal, None);
1063 }
1064
1065 /// Concurrency: 8 parallel `complete()` calls land 8 audit rows.
1066 /// Audit IDs (autoincrement rowid) must be distinct — verifies the
1067 /// writer-actor serialises the per-call audit emit (no
1068 /// interleaving / dropped rows).
1069 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1070 async fn parallel_completes_serialise_audit_rows() {
1071 let h = harness().await;
1072 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1073 let client = SamplingLlmClient::with_sampling_client(
1074 fake.clone(),
1075 h.write_handle.clone(),
1076 Some("alice".into()),
1077 );
1078 let mut futs = Vec::new();
1079 for _ in 0..8 {
1080 let c = client.clone();
1081 futs.push(tokio::spawn(async move {
1082 c.complete(&[Message::user("hi")]).await
1083 }));
1084 }
1085 for f in futs {
1086 f.await.expect("join").expect("ok");
1087 }
1088 assert_eq!(
1089 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1090 8,
1091 "8 parallel calls must land 8 audit rows"
1092 );
1093
1094 // Each was a separate request to the fake.
1095 assert_eq!(fake.record_requests().len(), 8);
1096 }
1097
1098 /// `complete` translates the workspace's `Message::system` into the
1099 /// `system_prompt` top-level field; user/assistant roles map to
1100 /// rmcp's `SamplingMessage::user_text` / `assistant_text`.
1101 #[tokio::test]
1102 async fn build_request_splits_system_from_messages() {
1103 let h = harness().await;
1104 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1105 let client = SamplingLlmClient::with_sampling_client(
1106 fake.clone(),
1107 h.write_handle.clone(),
1108 None,
1109 );
1110 client
1111 .complete(&[
1112 Message::system("be terse"),
1113 Message::user("question"),
1114 Message::assistant("answer"),
1115 ])
1116 .await
1117 .expect("ok");
1118 let recorded = fake.record_requests();
1119 assert_eq!(recorded.len(), 1);
1120 let req = &recorded[0];
1121 assert_eq!(
1122 req.system_prompt.as_deref(),
1123 Some("be terse"),
1124 "Role::System must map to system_prompt"
1125 );
1126 assert_eq!(req.messages.len(), 2);
1127 // The remaining two messages are the user + assistant turns.
1128 assert_eq!(req.messages[0].role, RmcpRole::User);
1129 assert_eq!(req.messages[1].role, RmcpRole::Assistant);
1130 }
1131
1132 /// `model_preferences` carries the `claude` hint per plan §6.
1133 /// Pins the wire shape so a future change is a conscious decision.
1134 #[tokio::test]
1135 async fn build_request_includes_claude_model_hint() {
1136 let h = harness().await;
1137 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1138 let client = SamplingLlmClient::with_sampling_client(
1139 fake.clone(),
1140 h.write_handle.clone(),
1141 None,
1142 );
1143 client
1144 .complete(&[Message::user("hi")])
1145 .await
1146 .expect("ok");
1147 let recorded = fake.record_requests();
1148 let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
1149 let hint = prefs
1150 .hints
1151 .as_ref()
1152 .and_then(|h| h.first())
1153 .and_then(|h| h.name.clone())
1154 .expect("hint name");
1155 assert_eq!(hint, "claude");
1156 }
1157
1158 /// `with_max_tokens(n)` propagates to the request's
1159 /// `max_tokens` field.
1160 #[tokio::test]
1161 async fn with_max_tokens_overrides_default() {
1162 let h = harness().await;
1163 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1164 let client = SamplingLlmClient::with_sampling_client(
1165 fake.clone(),
1166 h.write_handle.clone(),
1167 None,
1168 )
1169 .with_max_tokens(2048);
1170 client
1171 .complete(&[Message::user("hi")])
1172 .await
1173 .expect("ok");
1174 let recorded = fake.record_requests();
1175 assert_eq!(recorded[0].max_tokens, 2048);
1176 }
1177
1178 /// Reconfiguring the fake mid-test produces distinct audit rows
1179 /// for each call (positive then negative).
1180 #[tokio::test]
1181 async fn reconfigurable_fake_distinguishes_audit_rows() {
1182 let h = harness().await;
1183 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1184 let client = SamplingLlmClient::with_sampling_client(
1185 fake.clone(),
1186 h.write_handle.clone(),
1187 Some("alice".into()),
1188 );
1189
1190 client.complete(&[Message::user("a")]).await.expect("ok");
1191 fake.reject_with("user said no");
1192 let _ = client.complete(&[Message::user("b")]).await;
1193
1194 let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1195 let mut stmt = conn
1196 .prepare(
1197 "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1198 )
1199 .expect("prepare");
1200 let rows: Vec<String> = stmt
1201 .query_map([], |r| r.get::<_, String>(0))
1202 .expect("query")
1203 .map(|r| r.expect("row"))
1204 .collect();
1205 assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1206 }
1207
/// For an assistant message with one text block, `extract_text`
/// returns that block's text.
#[test]
fn extract_text_pulls_text_from_single_block() {
    let reply = SamplingMessage::assistant_text("hello");
    let result = CreateMessageResult::new(reply, "fake".into());
    let text = extract_text(&result).unwrap();
    assert_eq!(text, "hello");
}
1217
/// An assistant message carrying zero content blocks is an error for
/// `extract_text`.
#[test]
fn extract_text_rejects_empty_content() {
    let no_blocks = SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new());
    let result = CreateMessageResult::new(no_blocks, "fake".into());
    let outcome = extract_text(&result);
    assert!(outcome.is_err());
}
1227
/// A response whose message has the User role is rejected by
/// `extract_text`. The spec never produces this, but the defensive
/// check is pinned anyway.
#[test]
fn extract_text_rejects_non_assistant_role() {
    let wrong_role = SamplingMessage::user_text("hello");
    let result = CreateMessageResult::new(wrong_role, "fake".into());
    let outcome = extract_text(&result);
    assert!(outcome.is_err());
}
1238
1239 // ---- v0.10.1 F5 audit-minor closure: pin multi-block concat
1240 // semantics (deferred from v0.9.0 P2). ----
1241 //
1242 // `extract_text` walks `result.message.content` and pushes a `\n`
1243 // between successive `SamplingMessageContent::Text` blocks. The
1244 // exact semantics are wire-shape decisions that downstream Steward
1245 // parsers depend on; we pin them here so a future refactor can't
1246 // drift silently. The audit minor (v0.9.0 P2 §F5) flagged that the
1247 // existing tests only exercised single-block + empty + non-
1248 // assistant cases — not the multi-block join behavior. Closure
1249 // checklist:
1250 //
1251 // 1. join inserts ONE `\n` between non-trailing-newline blocks
1252 // 2. trailing `\n` in a block is preserved (so a `"abc\n"` block
1253 // + a `"def"` block produces `"abc\n\ndef"`)
1254 // 3. single block returns verbatim (no leading/trailing newline
1255 // added)
1256 //
// These tests pin exactly the output the code produces today. They're
// a regression net, not a behavior change.
1259 //
1260 // Grep terms: F5, extract_text_joins_multi_block, extract_text_preserves_trailing_newlines.
1261
/// F5 pin: when neither of two text blocks ends in a newline, they are
/// joined with exactly ONE `\n`.
#[test]
fn extract_text_joins_multi_block_with_newline_separator() {
    let message = SamplingMessage::new_multiple(
        RmcpRole::Assistant,
        vec![
            SamplingMessageContent::text("abc"),
            SamplingMessageContent::text("def"),
        ],
    );
    let result = CreateMessageResult::new(message, "fake".into());
    let joined = extract_text(&result).unwrap();
    // Exact value pinned: "abc" + "\n" + "def".
    assert_eq!(
        joined,
        "abc\ndef",
        "two non-newline-terminated blocks must join with a single newline"
    );
}
1281
/// F5 pin: a block ending in `\n` keeps that newline verbatim, and the
/// join newline is still inserted after it. Such a block is therefore
/// separated from the next one by `\n\n`. This pins the honest current
/// behavior; a refactor that strips trailing newlines must update this
/// test on purpose.
#[test]
fn extract_text_preserves_trailing_newlines_in_blocks() {
    let message = SamplingMessage::new_multiple(
        RmcpRole::Assistant,
        vec![
            SamplingMessageContent::text("abc\n"),
            SamplingMessageContent::text("def"),
        ],
    );
    let result = CreateMessageResult::new(message, "fake".into());
    let joined = extract_text(&result).unwrap();
    assert_eq!(
        joined,
        "abc\n\ndef",
        "trailing newline in block 1 + join newline => '\\n\\n' between blocks"
    );
}
1303
/// F5 pin: a lone block comes back verbatim. The helper adds no leading
/// or trailing newline and leaves internal newlines untouched.
/// (`extract_text_pulls_text_from_single_block` covers the plain
/// "hello" case; this one covers text with an embedded newline.)
#[test]
fn extract_text_single_block_returns_verbatim_including_inner_newlines() {
    let only_block = SamplingMessageContent::text("line1\nline2");
    let message = SamplingMessage::new_multiple(RmcpRole::Assistant, vec![only_block]);
    let result = CreateMessageResult::new(message, "fake".into());
    let text = extract_text(&result).unwrap();
    assert_eq!(
        text,
        "line1\nline2",
        "single block must return verbatim, no extra newlines added"
    );
}
1322
/// F5 pin: three blocks, the middle one an empty string.
///
/// Before appending each block, the helper pushes a `\n` whenever the
/// accumulated output is already non-empty, then appends the block's
/// text. So:
/// - the empty middle block still contributes a separator, and
/// - the third block adds another one.
///
/// Result: `"a\n\nb"`. This pins the "empty middle block adds a blank
/// line" semantic. It is surprising at first read, but it is just the
/// "separator whenever output is non-empty" rule applied uniformly to
/// every block, empty ones included.
#[test]
fn extract_text_empty_middle_block_inserts_blank_line() {
    let blocks = vec![
        SamplingMessageContent::text("a"),
        SamplingMessageContent::text(""),
        SamplingMessageContent::text("b"),
    ];
    let result = CreateMessageResult::new(
        SamplingMessage::new_multiple(RmcpRole::Assistant, blocks),
        "fake".into(),
    );
    // The implementation's exact behavior:
    // - After "a" is pushed, out = "a".
    // - Block 2 (empty): out.is_empty() is false, so push '\n' → out = "a\n";
    //   then push_str("") → out = "a\n".
    // - Block 3 ("b"): out.is_empty() is false, so push '\n' → out = "a\n\n";
    //   then push_str("b") → out = "a\n\nb".
    assert_eq!(
        extract_text(&result).unwrap(),
        "a\n\nb",
        "empty middle block leaves a blank line between non-empty blocks"
    );
}
1352
/// Each fake sampling error variant classifies into its audit category
/// string plus a flag saying whether it maps to "forbidden".
#[test]
fn sampling_error_classify_maps_fake_variants() {
    let cases = [
        (
            SamplingError::Fake(FakeSamplingError::Refused { reason: "x".into() }),
            "client_refused",
            true,
        ),
        (
            SamplingError::Fake(FakeSamplingError::Transport { message: "x".into() }),
            "transport_error",
            false,
        ),
        (
            SamplingError::Fake(FakeSamplingError::MalformedResponse { message: "x".into() }),
            "malformed_response",
            false,
        ),
    ];
    for (error, expected_category, expected_forbidden) in cases {
        let (category, forbidden) = error.classify();
        assert_eq!(category, expected_category);
        assert_eq!(forbidden, expected_forbidden);
    }
}
1379
1380 // -------- v0.9.0 P5a (M3 wiring) — SamplingCoordinator integration --------
1381 //
1382 // These tests pin the contract that `build_sampling_steward` wraps the
1383 // live peer in a `SamplingCoordinator` before handing it to
1384 // `SamplingLlmClient`. They cannot call `build_sampling_steward`
1385 // directly (it takes a real `Peer<RoleServer>` whose constructors are
1386 // private inside rmcp), but they exercise the **exact same wiring
1387 // shape** by substituting `FakeMcpClient` for `PeerSamplingClient`.
1388 // The production code path is:
1389 //
1390 // PeerSamplingClient -> SamplingCoordinator -> SamplingLlmClient
1391 //
1392 // The tested shape is:
1393 //
1394 // FakeMcpClient -> SamplingCoordinator -> SamplingLlmClient
1395 //
1396 // Only the leaf `SamplingClient` impl differs; the
1397 // `SamplingClient` trait is the same Arc-of-dyn in both paths.
1398
1399 /// SamplingCoordinator wrapping a `FakeMcpClient` and feeding
1400 /// `SamplingLlmClient::with_sampling_client` is the same Arc-of-dyn
1401 /// shape `build_sampling_steward` constructs at MCP-initialize
1402 /// time. Single-element flushes pass through unwrapped, so a lone
1403 /// `complete()` call still emits one audit row and produces the
1404 /// expected text.
1405 #[tokio::test]
1406 async fn sampling_llm_client_uses_coordinator_in_production_path() {
1407 let h = harness().await;
1408 let fake: Arc<dyn SamplingClient> =
1409 Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1410 let coord: Arc<dyn SamplingClient> =
1411 super::super::SamplingCoordinator::with_settings(
1412 fake.clone(),
1413 Duration::from_millis(50),
1414 10,
1415 );
1416 let client = SamplingLlmClient::with_sampling_client(
1417 coord,
1418 h.write_handle.clone(),
1419 Some("alice".into()),
1420 );
1421 let result = client
1422 .complete(&[Message::user("test")])
1423 .await
1424 .expect("ok");
1425 assert_eq!(result.role, Role::Assistant);
1426 assert_eq!(result.content, "ok");
1427 // Single audit row landed — per-call audit semantics
1428 // unchanged by the coordinator wrap.
1429 assert_eq!(
1430 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1431 1,
1432 "one logical call → one audit row, even through coordinator"
1433 );
1434 }
1435
1436 /// End-to-end batching pin: N concurrent `complete()` calls within
1437 /// the coalesce window resolve as ONE inner `create_message` RPC
1438 /// on the underlying `FakeMcpClient`, but N audit rows still land
1439 /// (one per logical call — the privacy + audit invariants from P2
1440 /// hold).
1441 ///
1442 /// This is the v0.9.0 release notes' "⌈N/M⌉ peer.create_message
1443 /// calls per coalesce window" claim, exercised through the same
1444 /// trait-object chain that `build_sampling_steward` constructs in
1445 /// production.
1446 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1447 async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1448 // Coalesced JSON response for 5 tasks — matches the
1449 // `[{task_index, response}]` shape `flush_batch` demuxes
1450 // multi-element batches into.
1451 let response = serde_json::to_string(&(0..5)
1452 .map(|i| serde_json::json!({
1453 "task_index": i,
1454 "response": format!("response-{i}"),
1455 }))
1456 .collect::<Vec<_>>())
1457 .unwrap();
1458
1459 let h = harness().await;
1460 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1461 let coord: Arc<dyn SamplingClient> =
1462 super::super::SamplingCoordinator::with_settings(
1463 fake.clone(),
1464 // Wide window so all 5 submissions land in one batch.
1465 Duration::from_secs(5),
1466 10,
1467 );
1468 let client = SamplingLlmClient::with_sampling_client(
1469 coord,
1470 h.write_handle.clone(),
1471 Some("alice".into()),
1472 );
1473
1474 // Fire 5 concurrent `complete()` calls; the coordinator should
1475 // coalesce them into ONE `FakeMcpClient::create_message` call.
1476 let mut futs = Vec::new();
1477 for i in 0..5 {
1478 let c = client.clone();
1479 futs.push(tokio::spawn(async move {
1480 c.complete(&[Message::user(format!("task-{i}"))]).await
1481 }));
1482 }
1483 for f in futs {
1484 f.await.expect("join").expect("ok");
1485 }
1486
1487 // EXACTLY one inner RPC.
1488 assert_eq!(
1489 fake.record_requests().len(),
1490 1,
1491 "5 logical calls within window must coalesce to 1 inner RPC"
1492 );
1493 // BUT 5 audit rows — per-logical-call audit invariant preserved.
1494 assert_eq!(
1495 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1496 5,
1497 "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1498 );
1499 }
1500
1501 /// Edge case: `coalesce_max_requests = 1` reduces the coordinator
1502 /// to pass-through (each submit flushes a 1-element batch
1503 /// immediately). With max_batch=1 and a wide window, 3 concurrent
1504 /// calls land 3 inner RPCs — coordinator is operating as if no
1505 /// batching were configured.
1506 ///
1507 /// Pins the brief's documented edge-case: zero / one-valued config
1508 /// reduces to pass-through, never panics or deadlocks. Mirrors
1509 /// `SamplingCoordinator::with_settings`'s `max_batch.max(1)`
1510 /// clamping for the `coalesce_max_requests = 0` case.
1511 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1512 async fn coordinator_max_batch_one_acts_as_passthrough() {
1513 let h = harness().await;
1514 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1515 let coord: Arc<dyn SamplingClient> =
1516 super::super::SamplingCoordinator::with_settings(
1517 fake.clone(),
1518 Duration::from_secs(5),
1519 // max_batch=1 → every submission flushes immediately as
1520 // a 1-element batch; pass-through behaviour.
1521 1,
1522 );
1523 let client = SamplingLlmClient::with_sampling_client(
1524 coord,
1525 h.write_handle.clone(),
1526 None,
1527 );
1528 let mut futs = Vec::new();
1529 for _ in 0..3 {
1530 let c = client.clone();
1531 futs.push(tokio::spawn(async move {
1532 c.complete(&[Message::user("hi")]).await
1533 }));
1534 }
1535 for f in futs {
1536 f.await.expect("join").expect("ok");
1537 }
1538 // 3 logical calls → 3 inner RPCs (no coalescing).
1539 assert_eq!(
1540 fake.record_requests().len(),
1541 3,
1542 "max_batch=1 must pass through every submission as its own RPC"
1543 );
1544 assert_eq!(
1545 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1546 3
1547 );
1548 }
1549}