Skip to main content

solo_api/llm/
sampling_coordinator.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingCoordinator`] — coalesces N concurrent
4//! [`SamplingClient::create_message`] calls into ⌈N/M⌉ calls within a
5//! configurable time window or batch-size limit.
6//!
7//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
8//! §4 P4 / Risk #7 / MAJOR 3 batching resolution):
9//!
10//! ## Why batch?
11//!
12//! Each `sampling/createMessage` call surfaces ONE approval prompt
13//! in the user's MCP client (Claude Desktop / Claude Code / future
14//! clients). When the daemon-side consolidate-timer fires
15//! [`solo_storage::triples_batch::run_triples_batch_tick`], it
16//! can produce N per-cluster sampling calls in quick succession —
17//! N separate approval prompts spam the user.
18//!
//! `SamplingCoordinator` collapses the N calls arriving within one coalesce `window`
20//! into ONE coalesced `peer.create_message` call. The user sees ONE
21//! approval per coalesce window; the per-cluster results are
22//! demultiplexed back to the individual callers via their oneshot
23//! reply channels.
24//!
25//! ## When NOT to batch
26//!
27//! Coordinator is bypassed for non-sampling backends (Anthropic /
28//! Ollama / None) — those don't surface approval prompts and have
29//! their own rate-limiting concerns. The coordinator inserts itself
30//! ONLY when wrapping [`PeerSamplingClient`] / a fake equivalent.
31//!
32//! ## Coalesce strategy
33//!
34//! - **Single-request batch**: passes through as a normal
35//!   `create_message` call, with no prompt rewriting. Zero behaviour
36//!   change from the v0.9.0 P2 path.
37//! - **Multi-request batch (N > 1)**: wraps each request in a
38//!   numbered JSON object, asks the LLM for a JSON array of
39//!   responses, parses the array, demultiplexes per-task. The
40//!   prompt template is documented in [`build_coalesced_request`].
41//!
42//! ## Privacy invariant
43//!
44//! The audit emit per logical request (one
45//! `AuditOperation::LlmSamplingCall` row per submitted
46//! [`SamplingLlmClient::complete`]) STAYS — the coordinator is an
47//! optimisation on the wire, NOT a change to the audit shape. See
48//! plan §11 Risk #8 — operators MUST be able to count per-logical-call
49//! audit rows, not per-coalesce.
50
51use std::sync::Arc;
52use std::time::Duration;
53
54use async_trait::async_trait;
55use rmcp::model::{
56    CreateMessageRequestParams, CreateMessageResult, Role as RmcpRole, SamplingMessage,
57    SamplingMessageContent,
58};
59use tokio::sync::{Mutex, mpsc, oneshot};
60
61use crate::llm::sampling::{SamplingClient, SamplingError};
62
/// How long the worker keeps collecting submissions before it
/// flushes a batch when nothing else is configured: five seconds
/// (plan §4 P4c default). Picked so the approval-prompt latency the
/// user sees stays under typical MCP-session "I'm doing work"
/// tolerance.
pub const DEFAULT_COALESCE_WINDOW: Duration = Duration::from_secs(5);

/// How many logical requests one coalesced `create_message` may
/// carry when nothing else is configured: ten (plan §4 P4c default).
/// Bounds the rendered prompt size and keeps a single slow batch
/// from holding the worker indefinitely.
pub const DEFAULT_COALESCE_MAX_BATCH: usize = 10;
72
/// Wrapper around a [`SamplingClient`] that coalesces concurrent
/// `create_message` calls into batched `create_message` calls
/// (within a configurable time window OR batch-size limit).
///
/// Construct via [`SamplingCoordinator::new`] or
/// [`SamplingCoordinator::with_settings`]; both return the
/// coordinator already wrapped in an `Arc`. Dropping the last `Arc`
/// drops the sender, which closes the channel and lets the worker
/// task exit.
///
/// **Sharing**: the coordinator itself is not `Clone`; share it by
/// cloning the `Arc`. Every `Arc` clone refers to the same mpsc
/// sender and the same worker task, and concurrent callers are
/// serialised by that worker.
pub struct SamplingCoordinator {
    /// Send-side of the worker mpsc. There is exactly one sender per
    /// coordinator; once the last `Arc<SamplingCoordinator>` is
    /// dropped the sender goes with it, the channel closes and the
    /// worker exits.
    tx: mpsc::Sender<Submission>,
    /// JoinHandle for the worker task. Stored in a Mutex so
    /// `shutdown` can `.take()` it on first call (later calls find
    /// `None` and do nothing).
    worker: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
93
94impl SamplingCoordinator {
95    /// Build a coordinator wrapping the supplied [`SamplingClient`]
96    /// with default settings (5s window, max-batch 10).
97    pub fn new(inner: Arc<dyn SamplingClient>) -> Arc<Self> {
98        Self::with_settings(inner, DEFAULT_COALESCE_WINDOW, DEFAULT_COALESCE_MAX_BATCH)
99    }
100
101    /// Build a coordinator with explicit settings. `window` is the
102    /// upper bound the worker waits before flushing a non-empty
103    /// buffer; `max_batch` is the buffer size that triggers an
104    /// immediate flush regardless of `window`.
105    pub fn with_settings(
106        inner: Arc<dyn SamplingClient>,
107        window: Duration,
108        max_batch: usize,
109    ) -> Arc<Self> {
110        let (tx, rx) = mpsc::channel::<Submission>(max_batch.max(1) * 2 + 16);
111        let worker = tokio::spawn(coordinator_worker(rx, inner, window, max_batch.max(1)));
112        Arc::new(Self {
113            tx,
114            worker: Mutex::new(Some(worker)),
115        })
116    }
117
118    /// Coalesced equivalent of `SamplingClient::create_message`. The
119    /// returned future resolves when the worker has demultiplexed the
120    /// coalesced batch's response back to this submission's slot.
121    pub async fn submit(
122        &self,
123        params: CreateMessageRequestParams,
124    ) -> Result<CreateMessageResult, SamplingError> {
125        let (reply_tx, reply_rx) = oneshot::channel();
126        self.tx
127            .send(Submission {
128                params,
129                reply: reply_tx,
130            })
131            .await
132            .map_err(|_| {
133                SamplingError::Service(rmcp::service::ServiceError::McpError(
134                    rmcp::model::ErrorData::internal_error(
135                        "sampling coordinator worker is gone (channel closed)",
136                        None,
137                    ),
138                ))
139            })?;
140        reply_rx.await.map_err(|_| {
141            SamplingError::Service(rmcp::service::ServiceError::McpError(
142                rmcp::model::ErrorData::internal_error(
143                    "sampling coordinator worker dropped reply channel",
144                    None,
145                ),
146            ))
147        })?
148    }
149
150    /// Drain the worker task. Called rarely in tests; production
151    /// drops the coordinator on daemon shutdown.
152    pub async fn shutdown(self: Arc<Self>) {
153        // Drop the send-side to close the channel.
154        let mut guard = self.worker.lock().await;
155        if let Some(join) = guard.take() {
156            join.abort();
157            let _ = join.await;
158        }
159    }
160}
161
/// One coordinator submission: the caller's `create_message` params
/// + the oneshot we send the demultiplexed reply back on.
///
/// Created by `SamplingCoordinator::submit` and consumed by the
/// worker when it flushes the batch containing it.
struct Submission {
    /// Request parameters exactly as the caller passed them to `submit`.
    params: CreateMessageRequestParams,
    /// Reply slot for this request. If it is dropped without a send,
    /// the waiting `submit` call reports an error instead.
    reply: oneshot::Sender<Result<CreateMessageResult, SamplingError>>,
}
168
169/// Worker task: loops, draining `rx` into batches bounded by
170/// `window` (time) or `max_batch` (count), and dispatches each
171/// batch to `inner` as ONE `create_message` call.
172async fn coordinator_worker(
173    mut rx: mpsc::Receiver<Submission>,
174    inner: Arc<dyn SamplingClient>,
175    window: Duration,
176    max_batch: usize,
177) {
178    loop {
179        // Block until at least one submission arrives or the
180        // channel is closed (last sender dropped).
181        let first = match rx.recv().await {
182            Some(s) => s,
183            None => return,
184        };
185        let mut buffer: Vec<Submission> = vec![first];
186
187        // Drain additional submissions for up to `window` ms, or
188        // until `max_batch` reached. `tokio::time::timeout(window,
189        // rx.recv())` returns Ok(Some(_)) on incoming submission,
190        // Ok(None) on channel close, Err(_) on timeout.
191        let deadline = tokio::time::Instant::now() + window;
192        while buffer.len() < max_batch {
193            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
194            if remaining.is_zero() {
195                break;
196            }
197            match tokio::time::timeout(remaining, rx.recv()).await {
198                Ok(Some(s)) => buffer.push(s),
199                Ok(None) => {
200                    // Channel closed; flush this last batch and exit.
201                    flush_batch(&inner, buffer).await;
202                    return;
203                }
204                Err(_) => break,
205            }
206        }
207
208        flush_batch(&inner, buffer).await;
209    }
210}
211
212/// Dispatch one batch. Single-element batches pass through as a
213/// normal `create_message`; multi-element batches go through the
214/// coalesced JSON-array prompt path.
215async fn flush_batch(inner: &Arc<dyn SamplingClient>, batch: Vec<Submission>) {
216    if batch.is_empty() {
217        return;
218    }
219    if batch.len() == 1 {
220        // Pass-through. Zero behaviour change from the
221        // unwrapped-coordinator path.
222        let mut iter = batch.into_iter();
223        let s = iter.next().unwrap();
224        let result = inner.create_message(s.params).await;
225        let _ = s.reply.send(result);
226        return;
227    }
228
229    let coalesced = build_coalesced_request(&batch);
230    let result = inner.create_message(coalesced).await;
231
232    match result {
233        Ok(rendered) => {
234            // Try to demultiplex the rendered JSON array back into
235            // per-task results. If parsing fails, surface the
236            // error to EVERY submission — the caller will see a
237            // structured `malformed_response` per audit row.
238            match demux_coalesced(&rendered, &batch) {
239                Ok(per_task) => {
240                    for (sub, task_result) in batch.into_iter().zip(per_task) {
241                        let _ = sub.reply.send(task_result);
242                    }
243                }
244                Err(parse_err) => {
245                    let err_msg = format!(
246                        "sampling coordinator: failed to parse coalesced response: {parse_err}"
247                    );
248                    for sub in batch {
249                        let _ = sub.reply.send(Err(SamplingError::Service(
250                            rmcp::service::ServiceError::McpError(
251                                rmcp::model::ErrorData::internal_error(err_msg.clone(), None),
252                            ),
253                        )));
254                    }
255                }
256            }
257        }
258        Err(e) => {
259            // The single coalesced RPC failed. Surface the failure
260            // to EVERY submission so per-logical-call audit rows
261            // correctly record the failure (lesson #30: every logical
262            // request needs its audit trail).
263            let err_msg = format!("{e}");
264            for sub in batch {
265                let _ = sub.reply.send(Err(SamplingError::Service(
266                    rmcp::service::ServiceError::McpError(rmcp::model::ErrorData::internal_error(
267                        format!("sampling coordinator: coalesced RPC failed: {err_msg}"),
268                        None,
269                    )),
270                )));
271            }
272        }
273    }
274}
275
276/// Build a coalesced request from N submissions.
277///
278/// Prompt template:
279///
280/// ```text
281/// System:
282///   You are a batch task processor. Process EVERY task listed
283///   in the user message and reply with a JSON array of objects
284///   where each object has shape:
285///   { "task_index": <int starting from 0>, "response": "<string>" }
286///   The array MUST have exactly N entries (one per task) in the
287///   SAME ORDER. Do NOT include any prose outside the JSON.
288///   [+ any system prompts from individual tasks, concatenated]
289///
290/// User:
291///   {
292///     "tasks": [
293///       { "task_index": 0, "messages": [...messages from request 0...] },
294///       { "task_index": 1, "messages": [...messages from request 1...] },
295///       ...
296///     ]
297///   }
298/// ```
299///
300/// The `max_tokens` of the coalesced request is the SUM of every
301/// submission's `max_tokens`, capped at u32::MAX / 2. This gives
302/// each per-task response room to be ~as long as its un-coalesced
303/// counterpart would have been.
304fn build_coalesced_request(batch: &[Submission]) -> CreateMessageRequestParams {
305    let mut tasks: Vec<serde_json::Value> = Vec::with_capacity(batch.len());
306    let mut system_parts: Vec<String> = vec![
307        "You are a batch task processor. Process EVERY task listed in the \
308         user message and reply with a JSON array of objects where each \
309         object has shape: { \"task_index\": <int starting from 0>, \
310         \"response\": \"<string>\" }. The array MUST have exactly N entries \
311         (one per task) in the SAME ORDER. Do NOT include any prose outside \
312         the JSON."
313            .to_string(),
314    ];
315
316    for (idx, sub) in batch.iter().enumerate() {
317        let mut task_messages: Vec<serde_json::Value> = Vec::new();
318        if let Some(sys) = sub.params.system_prompt.as_ref() {
319            // Hoist per-task system prompts so the LLM still sees
320            // them inside its task context.
321            system_parts.push(format!("Task-{idx} sub-system: {sys}"));
322        }
323        for sm in &sub.params.messages {
324            let role_str = match sm.role {
325                RmcpRole::User => "user",
326                RmcpRole::Assistant => "assistant",
327            };
328            let mut text_parts: Vec<String> = Vec::new();
329            for content in sm.content.iter() {
330                if let SamplingMessageContent::Text(t) = content {
331                    text_parts.push(t.text.clone());
332                }
333            }
334            task_messages.push(serde_json::json!({
335                "role": role_str,
336                "content": text_parts.join("\n"),
337            }));
338        }
339        tasks.push(serde_json::json!({
340            "task_index": idx,
341            "messages": task_messages,
342        }));
343    }
344
345    let user_payload = serde_json::json!({ "tasks": tasks }).to_string();
346
347    let max_tokens = batch
348        .iter()
349        .map(|s| s.params.max_tokens)
350        .fold(0u32, |acc, n| acc.saturating_add(n));
351
352    let mut params = CreateMessageRequestParams::new(
353        vec![SamplingMessage::user_text(&user_payload)],
354        max_tokens.max(1),
355    );
356    params = params.with_system_prompt(system_parts.join("\n\n"));
357    // Carry over model_preferences from the first submission (if any).
358    if let Some(prefs) = batch[0].params.model_preferences.as_ref() {
359        params = params.with_model_preferences(prefs.clone());
360    }
361    params
362}
363
364/// Parse a coalesced response back into per-task `CreateMessageResult`s.
365///
366/// Expects the rendered message text to be a JSON array of objects
367/// shaped `{ "task_index": <int>, "response": "<string>" }`. The
368/// `task_index` field is required and must match the submission
369/// order. Extra fields are ignored.
370///
371/// Returns one `Result<CreateMessageResult, SamplingError>` per
372/// submission, in the SAME ORDER as `batch`. Per-task entries
373/// missing from the response surface as
374/// `SamplingError::Service(...)` with a `malformed_response`-style
375/// message; the per-call audit row carries the failure.
376fn demux_coalesced(
377    rendered: &CreateMessageResult,
378    batch: &[Submission],
379) -> Result<Vec<Result<CreateMessageResult, SamplingError>>, String> {
380    let text = extract_text_from_result(rendered).map_err(|e| e.to_string())?;
381    let parsed: serde_json::Value = match serde_json::from_str(&text) {
382        Ok(v) => v,
383        Err(e) => {
384            // Tolerate fenced ```json ... ``` blocks the model may
385            // wrap the array in (matches `solo_steward::abstraction`
386            // tolerance).
387            match extract_fenced_json(&text) {
388                Some(inner) => {
389                    serde_json::from_str(inner).map_err(|fe| format!("fenced parse: {fe}"))?
390                }
391                None => return Err(format!("top-level JSON parse: {e}")),
392            }
393        }
394    };
395    let arr = parsed
396        .as_array()
397        .ok_or_else(|| "response root is not a JSON array".to_string())?;
398
399    let mut out: Vec<Result<CreateMessageResult, SamplingError>> = Vec::with_capacity(batch.len());
400    for (idx, _sub) in batch.iter().enumerate() {
401        let entry = arr.iter().find(|e| {
402            e.get("task_index")
403                .and_then(|v| v.as_i64())
404                .map(|i| i as usize == idx)
405                .unwrap_or(false)
406        });
407        match entry {
408            Some(e) => {
409                let response_text = e.get("response").and_then(|v| v.as_str()).unwrap_or("");
410                out.push(Ok(make_assistant_result(response_text, &rendered.model)));
411            }
412            None => out.push(Err(SamplingError::Service(
413                rmcp::service::ServiceError::McpError(rmcp::model::ErrorData::internal_error(
414                    format!("sampling coordinator: response missing task_index {idx}"),
415                    None,
416                )),
417            ))),
418        }
419    }
420    Ok(out)
421}
422
/// Return the trimmed contents of the first Markdown code fence in
/// `text` that plausibly holds JSON, or `None` if there is no
/// complete (opened AND closed) fence.
///
/// A fence opened with the exact tag "```json" is preferred. If no
/// such fence is complete, the first fence of any kind is used
/// instead; when its opening line is empty or a purely alphanumeric
/// language tag (e.g. `JSON`), that line is dropped. The result is
/// not validated as JSON — that is left to the caller.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    /// Raw text between the first `opener` and the next closing fence.
    fn between<'a>(text: &'a str, opener: &str) -> Option<&'a str> {
        let start = text.find(opener)?;
        let after = &text[start + opener.len()..];
        let end = after.find(FENCE)?;
        Some(&after[..end])
    }

    let body = match between(text, JSON_FENCE) {
        Some(tagged) => tagged,
        None => {
            let raw = between(text, FENCE)?;
            match raw.split_once('\n') {
                // Opening line is blank or a language tag: skip it.
                Some((tag, rest)) if tag.trim().chars().all(|c| c.is_ascii_alphanumeric()) => {
                    rest
                }
                _ => raw,
            }
        }
    };
    Some(body.trim())
}
430
431fn extract_text_from_result(result: &CreateMessageResult) -> Result<String, &'static str> {
432    if result.message.role != RmcpRole::Assistant {
433        return Err("response role was not Assistant");
434    }
435    let mut out = String::new();
436    for content in result.message.content.iter() {
437        if let SamplingMessageContent::Text(text) = content {
438            if !out.is_empty() {
439                out.push('\n');
440            }
441            out.push_str(&text.text);
442        }
443    }
444    if out.is_empty() {
445        Err("no text content blocks")
446    } else {
447        Ok(out)
448    }
449}
450
451fn make_assistant_result(text: &str, model: &str) -> CreateMessageResult {
452    CreateMessageResult::new(
453        SamplingMessage::assistant_text(text.to_string()),
454        model.to_string(),
455    )
456}
457
/// Lets the coordinator stand in wherever a [`SamplingClient`] is
/// expected: `create_message` simply forwards to
/// [`SamplingCoordinator::submit`], so the call may be coalesced with
/// other concurrent calls before it reaches the wrapped client.
#[async_trait]
impl SamplingClient for SamplingCoordinator {
    /// Submit `params` through the coalescing worker and wait for
    /// this request's own result.
    async fn create_message(
        &self,
        params: CreateMessageRequestParams,
    ) -> Result<CreateMessageResult, SamplingError> {
        self.submit(params).await
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{FakeMcpClient, FakeResponse};

    /// Minimal request: one user text message carrying `prompt`,
    /// with `max_tokens` fixed at 128.
    fn mk_params(prompt: &str) -> CreateMessageRequestParams {
        CreateMessageRequestParams::new(vec![SamplingMessage::user_text(prompt)], 128)
    }

    /// Render the JSON array a well-behaved model would return for a
    /// coalesced request of `n_tasks` tasks: entry `i` is
    /// `{ "task_index": i, "response": "response-<i>" }`.
    fn coalesced_response_for(n_tasks: usize) -> String {
        let mut arr = Vec::with_capacity(n_tasks);
        for i in 0..n_tasks {
            arr.push(serde_json::json!({
                "task_index": i,
                "response": format!("response-{i}"),
            }));
        }
        serde_json::to_string(&arr).unwrap()
    }

    /// P4d coalesce window: when N submissions arrive within the
    /// coalesce window, the underlying SamplingClient sees ONE
    /// `create_message` call carrying the coalesced prompt.
    ///
    /// NOTE(review): the index-to-submission mapping asserted at the
    /// end assumes the spawned tasks reach the channel in spawn
    /// order — confirm this holds for the runtime flavour
    /// `#[tokio::test]` uses here.
    #[tokio::test]
    async fn coalesces_n_concurrent_submissions_into_one_create_message_call() {
        let response = coalesced_response_for(3);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord =
            SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(100), 10);

        // Submit 3 concurrent requests within the window.
        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("task-C")).await });

        let r1 = h1.await.unwrap().expect("submission 1 ok");
        let r2 = h2.await.unwrap().expect("submission 2 ok");
        let r3 = h3.await.unwrap().expect("submission 3 ok");

        // Inner saw EXACTLY one call.
        let recorded = fake.record_requests();
        assert_eq!(
            recorded.len(),
            1,
            "coordinator must coalesce 3 submissions into 1 inner call"
        );

        // Each submission got its task's text demultiplexed.
        assert_eq!(
            extract_text_from_result(&r1).unwrap(),
            "response-0",
            "task-0 response routed to first submission"
        );
        assert_eq!(extract_text_from_result(&r2).unwrap(), "response-1");
        assert_eq!(extract_text_from_result(&r3).unwrap(), "response-2");
    }

    /// P4d max-batch trigger: when `max_batch` submissions arrive
    /// FASTER than the window, the coordinator flushes immediately
    /// without waiting out the window.
    ///
    /// Only timing and the inner call count are asserted; the two
    /// submission results themselves are discarded.
    #[tokio::test]
    async fn flushes_at_max_batch_before_window_expires() {
        let response = coalesced_response_for(2);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        // 5-second window, 2-request max-batch.
        let coord = SamplingCoordinator::with_settings(fake.clone(), Duration::from_secs(5), 2);

        let started = tokio::time::Instant::now();
        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });

        let _ = h1.await.unwrap();
        let _ = h2.await.unwrap();

        // Well under the 5 s window: the count trigger fired, not the timer.
        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "max_batch must flush before window expires; took {elapsed:?}"
        );
        assert_eq!(fake.record_requests().len(), 1);
    }

    /// P4d single-request pass-through: a lone submission within
    /// the window goes through unchanged (no coalesce wrapping).
    /// Caller's prompt reaches the inner client verbatim.
    #[tokio::test]
    async fn single_submission_passes_through_unwrapped() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("direct-response")));
        let coord = SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(50), 10);

        let result = coord
            .submit(mk_params("lonely-task"))
            .await
            .expect("submission ok");

        // Inner saw the SAME prompt the caller submitted — no
        // "tasks" JSON wrapper.
        let recorded = fake.record_requests();
        assert_eq!(recorded.len(), 1);
        let inner_text = extract_first_user_text(&recorded[0]);
        assert_eq!(
            inner_text, "lonely-task",
            "single-batch path must NOT wrap the prompt"
        );
        // Caller got the inner response verbatim.
        assert_eq!(
            extract_text_from_result(&result).unwrap(),
            "direct-response"
        );
    }

    /// P4d window expiry: submissions trickle in slower than
    /// `window`, so each flush carries exactly one task.
    /// Verifies that the window timer is the trigger (not the
    /// max_batch), and that each 1-task batch is answered through
    /// the single-request pass-through path of `flush_batch`.
    #[tokio::test]
    async fn window_expiry_flushes_each_submission_individually() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("r-first")));
        // Queue one canned response per expected inner call.
        fake.respond_each(vec![
            FakeResponse::text("r-first"),
            FakeResponse::text("r-second"),
        ]);
        let coord = SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(20), 10);

        let r1 = coord
            .submit(mk_params("first"))
            .await
            .expect("submission 1");
        // Wait past the window so the second submission lands in a
        // fresh batch.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let r2 = coord
            .submit(mk_params("second"))
            .await
            .expect("submission 2");

        assert_eq!(fake.record_requests().len(), 2);
        assert_eq!(extract_text_from_result(&r1).unwrap(), "r-first");
        assert_eq!(extract_text_from_result(&r2).unwrap(), "r-second");
    }

    /// P4d demux fault tolerance: when the LLM omits a task_index,
    /// only that submission gets the malformed-response error; the
    /// others land successfully.
    ///
    /// NOTE(review): like the coalescing test above, this assumes
    /// the spawned tasks are batched in spawn order — confirm.
    #[tokio::test]
    async fn demux_propagates_per_request_failures() {
        // Coalesced response with task 0 + task 2 only (task 1 missing).
        let response = serde_json::json!([
            { "task_index": 0, "response": "ok-0" },
            { "task_index": 2, "response": "ok-2" },
        ])
        .to_string();
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord =
            SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(100), 10);

        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("t0")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("t1")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("t2")).await });

        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();

        assert!(r1.is_ok());
        assert!(r2.is_err(), "missing task_index must surface as error");
        assert!(r3.is_ok());
    }

    /// P4d coalesced-RPC-failure surfaces to EVERY submission. If
    /// the inner peer.create_message call fails, all coalesced
    /// callers see an error (so each per-logical-call audit row
    /// records the failure).
    #[tokio::test]
    async fn coalesced_rpc_failure_surfaces_to_every_submission() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::Error(
            crate::test_support::FakeSamplingError::Transport {
                message: "simulated transport failure".into(),
            },
        )));
        let coord =
            SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(100), 10);

        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("a")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("b")).await });

        assert!(h1.await.unwrap().is_err());
        assert!(h2.await.unwrap().is_err());
    }

    /// Text of the first text block found in a user-role message of
    /// `params`, or an empty string when there is none.
    fn extract_first_user_text(params: &CreateMessageRequestParams) -> String {
        for m in &params.messages {
            if m.role == RmcpRole::User {
                for c in m.content.iter() {
                    if let SamplingMessageContent::Text(t) = c {
                        return t.text.clone();
                    }
                }
            }
        }
        String::new()
    }
}