Skip to main content

solo_api/llm/
sampling_coordinator.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! [`SamplingCoordinator`] — coalesces N concurrent
4//! [`SamplingClient::create_message`] calls into ⌈N/M⌉ calls within a
5//! configurable time window or batch-size limit.
6//!
7//! Per the v0.9.0 design (`docs/dev-log/0098-v0.9.0-implementation-plan.md`
8//! §4 P4 / Risk #7 / MAJOR 3 batching resolution):
9//!
10//! ## Why batch?
11//!
12//! Each `sampling/createMessage` call surfaces ONE approval prompt
13//! in the user's MCP client (Claude Desktop / Claude Code / future
14//! clients). When the daemon-side consolidate-timer fires
15//! [`solo_storage::triples_batch::run_triples_batch_tick`], it
16//! can produce N per-cluster sampling calls in quick succession —
17//! N separate approval prompts spam the user.
18//!
//! `SamplingCoordinator` collapses the N calls that arrive within one
//! coalesce `window` (or until `max_batch` fills) into ONE coalesced
//! `peer.create_message` call. The user sees ONE
21//! approval per coalesce window; the per-cluster results are
22//! demultiplexed back to the individual callers via their oneshot
23//! reply channels.
24//!
25//! ## When NOT to batch
26//!
27//! Coordinator is bypassed for non-sampling backends (Anthropic /
28//! Ollama / None) — those don't surface approval prompts and have
29//! their own rate-limiting concerns. The coordinator inserts itself
30//! ONLY when wrapping [`PeerSamplingClient`] / a fake equivalent.
31//!
32//! ## Coalesce strategy
33//!
34//! - **Single-request batch**: passes through as a normal
35//!   `create_message` call, with no prompt rewriting. Zero behaviour
36//!   change from the v0.9.0 P2 path.
37//! - **Multi-request batch (N > 1)**: wraps each request in a
38//!   numbered JSON object, asks the LLM for a JSON array of
39//!   responses, parses the array, demultiplexes per-task. The
40//!   prompt template is documented in [`build_coalesced_request`].
41//!
42//! ## Privacy invariant
43//!
44//! The audit emit per logical request (one
45//! `AuditOperation::LlmSamplingCall` row per submitted
46//! [`SamplingLlmClient::complete`]) STAYS — the coordinator is an
47//! optimisation on the wire, NOT a change to the audit shape. See
48//! plan §11 Risk #8 — operators MUST be able to count per-logical-call
49//! audit rows, not per-coalesce.
50
51use std::sync::Arc;
52use std::time::Duration;
53
54use async_trait::async_trait;
55use rmcp::model::{
56    CreateMessageRequestParams, CreateMessageResult, Role as RmcpRole, SamplingMessage,
57    SamplingMessageContent,
58};
59use tokio::sync::{Mutex, mpsc, oneshot};
60
61use crate::llm::sampling::{SamplingClient, SamplingError};
62
/// Default coalesce window (5 s), per plan §4 P4c. Picked so the
/// approval-prompt latency the user experiences stays within what a
/// typical MCP session tolerates while "I'm doing work".
pub const DEFAULT_COALESCE_WINDOW: Duration = Duration::from_secs(5);

/// Default maximum batch size (10 logical requests per coalesced
/// `create_message`), per plan §4 P4c. Caps the rendered prompt size and
/// lets a full buffer flush without waiting out the window.
pub const DEFAULT_COALESCE_MAX_BATCH: usize = 10;
72
73/// Wrapper around a [`SamplingClient`] that coalesces concurrent
74/// `create_message` calls into batched `create_message` calls
75/// (within a configurable time window OR batch-size limit).
76///
77/// Construct via [`SamplingCoordinator::new`] or
78/// [`SamplingCoordinator::with_settings`]. Drop the coordinator's
79/// last `Arc` clone to shut down the worker task.
80///
81/// **Thread safety**: cheap to clone; every clone shares the same
82/// underlying mpsc + worker task. Concurrent callers are serialised
83/// by the worker.
84pub struct SamplingCoordinator {
85    /// Send-side of the worker mpsc. Cloning the coordinator clones
86    /// this sender; dropping every clone closes the channel and the
87    /// worker exits.
88    tx: mpsc::Sender<Submission>,
89    /// JoinHandle for the worker task. Stored in a Mutex so
90    /// `shutdown_blocking` can `.take()` it on first call.
91    worker: Mutex<Option<tokio::task::JoinHandle<()>>>,
92}
93
94impl SamplingCoordinator {
95    /// Build a coordinator wrapping the supplied [`SamplingClient`]
96    /// with default settings (5s window, max-batch 10).
97    pub fn new(inner: Arc<dyn SamplingClient>) -> Arc<Self> {
98        Self::with_settings(inner, DEFAULT_COALESCE_WINDOW, DEFAULT_COALESCE_MAX_BATCH)
99    }
100
101    /// Build a coordinator with explicit settings. `window` is the
102    /// upper bound the worker waits before flushing a non-empty
103    /// buffer; `max_batch` is the buffer size that triggers an
104    /// immediate flush regardless of `window`.
105    pub fn with_settings(
106        inner: Arc<dyn SamplingClient>,
107        window: Duration,
108        max_batch: usize,
109    ) -> Arc<Self> {
110        let (tx, rx) = mpsc::channel::<Submission>(max_batch.max(1) * 2 + 16);
111        let worker = tokio::spawn(coordinator_worker(rx, inner, window, max_batch.max(1)));
112        Arc::new(Self {
113            tx,
114            worker: Mutex::new(Some(worker)),
115        })
116    }
117
118    /// Coalesced equivalent of `SamplingClient::create_message`. The
119    /// returned future resolves when the worker has demultiplexed the
120    /// coalesced batch's response back to this submission's slot.
121    pub async fn submit(
122        &self,
123        params: CreateMessageRequestParams,
124    ) -> Result<CreateMessageResult, SamplingError> {
125        let (reply_tx, reply_rx) = oneshot::channel();
126        self.tx
127            .send(Submission {
128                params,
129                reply: reply_tx,
130            })
131            .await
132            .map_err(|_| {
133                SamplingError::Service(rmcp::service::ServiceError::McpError(
134                    rmcp::model::ErrorData::internal_error(
135                        "sampling coordinator worker is gone (channel closed)",
136                        None,
137                    ),
138                ))
139            })?;
140        reply_rx.await.map_err(|_| {
141            SamplingError::Service(rmcp::service::ServiceError::McpError(
142                rmcp::model::ErrorData::internal_error(
143                    "sampling coordinator worker dropped reply channel",
144                    None,
145                ),
146            ))
147        })?
148    }
149
150    /// Drain the worker task. Called rarely in tests; production
151    /// drops the coordinator on daemon shutdown.
152    pub async fn shutdown(self: Arc<Self>) {
153        // Drop the send-side to close the channel.
154        let mut guard = self.worker.lock().await;
155        if let Some(join) = guard.take() {
156            join.abort();
157            let _ = join.await;
158        }
159    }
160}
161
162/// One coordinator submission: the caller's `create_message` params
163/// + the oneshot we send the demultiplexed reply back on.
164struct Submission {
165    params: CreateMessageRequestParams,
166    reply: oneshot::Sender<Result<CreateMessageResult, SamplingError>>,
167}
168
169/// Worker task: loops, draining `rx` into batches bounded by
170/// `window` (time) or `max_batch` (count), and dispatches each
171/// batch to `inner` as ONE `create_message` call.
172async fn coordinator_worker(
173    mut rx: mpsc::Receiver<Submission>,
174    inner: Arc<dyn SamplingClient>,
175    window: Duration,
176    max_batch: usize,
177) {
178    loop {
179        // Block until at least one submission arrives or the
180        // channel is closed (last sender dropped).
181        let first = match rx.recv().await {
182            Some(s) => s,
183            None => return,
184        };
185        let mut buffer: Vec<Submission> = vec![first];
186
187        // Drain additional submissions for up to `window` ms, or
188        // until `max_batch` reached. `tokio::time::timeout(window,
189        // rx.recv())` returns Ok(Some(_)) on incoming submission,
190        // Ok(None) on channel close, Err(_) on timeout.
191        let deadline = tokio::time::Instant::now() + window;
192        while buffer.len() < max_batch {
193            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
194            if remaining.is_zero() {
195                break;
196            }
197            match tokio::time::timeout(remaining, rx.recv()).await {
198                Ok(Some(s)) => buffer.push(s),
199                Ok(None) => {
200                    // Channel closed; flush this last batch and exit.
201                    flush_batch(&inner, buffer).await;
202                    return;
203                }
204                Err(_) => break,
205            }
206        }
207
208        flush_batch(&inner, buffer).await;
209    }
210}
211
212/// Dispatch one batch. Single-element batches pass through as a
213/// normal `create_message`; multi-element batches go through the
214/// coalesced JSON-array prompt path.
215async fn flush_batch(inner: &Arc<dyn SamplingClient>, batch: Vec<Submission>) {
216    if batch.is_empty() {
217        return;
218    }
219    if batch.len() == 1 {
220        // Pass-through. Zero behaviour change from the
221        // unwrapped-coordinator path.
222        let mut iter = batch.into_iter();
223        let s = iter.next().unwrap();
224        let result = inner.create_message(s.params).await;
225        let _ = s.reply.send(result);
226        return;
227    }
228
229    let coalesced = build_coalesced_request(&batch);
230    let result = inner.create_message(coalesced).await;
231
232    match result {
233        Ok(rendered) => {
234            // Try to demultiplex the rendered JSON array back into
235            // per-task results. If parsing fails, surface the
236            // error to EVERY submission — the caller will see a
237            // structured `malformed_response` per audit row.
238            match demux_coalesced(&rendered, &batch) {
239                Ok(per_task) => {
240                    for (sub, task_result) in batch.into_iter().zip(per_task) {
241                        let _ = sub.reply.send(task_result);
242                    }
243                }
244                Err(parse_err) => {
245                    let err_msg = format!(
246                        "sampling coordinator: failed to parse coalesced response: {parse_err}"
247                    );
248                    for sub in batch {
249                        let _ = sub.reply.send(Err(SamplingError::Service(
250                            rmcp::service::ServiceError::McpError(
251                                rmcp::model::ErrorData::internal_error(err_msg.clone(), None),
252                            ),
253                        )));
254                    }
255                }
256            }
257        }
258        Err(e) => {
259            // The single coalesced RPC failed. Surface the failure
260            // to EVERY submission so per-logical-call audit rows
261            // correctly record the failure (lesson #30: every logical
262            // request needs its audit trail).
263            let err_msg = format!("{e}");
264            for sub in batch {
265                let _ = sub.reply.send(Err(SamplingError::Service(
266                    rmcp::service::ServiceError::McpError(
267                        rmcp::model::ErrorData::internal_error(
268                            format!("sampling coordinator: coalesced RPC failed: {err_msg}"),
269                            None,
270                        ),
271                    ),
272                )));
273            }
274        }
275    }
276}
277
278/// Build a coalesced request from N submissions.
279///
280/// Prompt template:
281///
282/// ```text
283/// System:
284///   You are a batch task processor. Process EVERY task listed
285///   in the user message and reply with a JSON array of objects
286///   where each object has shape:
287///   { "task_index": <int starting from 0>, "response": "<string>" }
288///   The array MUST have exactly N entries (one per task) in the
289///   SAME ORDER. Do NOT include any prose outside the JSON.
290///   [+ any system prompts from individual tasks, concatenated]
291///
292/// User:
293///   {
294///     "tasks": [
295///       { "task_index": 0, "messages": [...messages from request 0...] },
296///       { "task_index": 1, "messages": [...messages from request 1...] },
297///       ...
298///     ]
299///   }
300/// ```
301///
302/// The `max_tokens` of the coalesced request is the SUM of every
303/// submission's `max_tokens`, capped at u32::MAX / 2. This gives
304/// each per-task response room to be ~as long as its un-coalesced
305/// counterpart would have been.
306fn build_coalesced_request(batch: &[Submission]) -> CreateMessageRequestParams {
307    let mut tasks: Vec<serde_json::Value> = Vec::with_capacity(batch.len());
308    let mut system_parts: Vec<String> = vec![
309        "You are a batch task processor. Process EVERY task listed in the \
310         user message and reply with a JSON array of objects where each \
311         object has shape: { \"task_index\": <int starting from 0>, \
312         \"response\": \"<string>\" }. The array MUST have exactly N entries \
313         (one per task) in the SAME ORDER. Do NOT include any prose outside \
314         the JSON."
315            .to_string(),
316    ];
317
318    for (idx, sub) in batch.iter().enumerate() {
319        let mut task_messages: Vec<serde_json::Value> = Vec::new();
320        if let Some(sys) = sub.params.system_prompt.as_ref() {
321            // Hoist per-task system prompts so the LLM still sees
322            // them inside its task context.
323            system_parts.push(format!("Task-{idx} sub-system: {sys}"));
324        }
325        for sm in &sub.params.messages {
326            let role_str = match sm.role {
327                RmcpRole::User => "user",
328                RmcpRole::Assistant => "assistant",
329            };
330            let mut text_parts: Vec<String> = Vec::new();
331            for content in sm.content.iter() {
332                if let SamplingMessageContent::Text(t) = content {
333                    text_parts.push(t.text.clone());
334                }
335            }
336            task_messages.push(serde_json::json!({
337                "role": role_str,
338                "content": text_parts.join("\n"),
339            }));
340        }
341        tasks.push(serde_json::json!({
342            "task_index": idx,
343            "messages": task_messages,
344        }));
345    }
346
347    let user_payload =
348        serde_json::json!({ "tasks": tasks }).to_string();
349
350    let max_tokens = batch
351        .iter()
352        .map(|s| s.params.max_tokens)
353        .fold(0u32, |acc, n| acc.saturating_add(n));
354
355    let mut params = CreateMessageRequestParams::new(
356        vec![SamplingMessage::user_text(&user_payload)],
357        max_tokens.max(1),
358    );
359    params = params.with_system_prompt(system_parts.join("\n\n"));
360    // Carry over model_preferences from the first submission (if any).
361    if let Some(prefs) = batch[0].params.model_preferences.as_ref() {
362        params = params.with_model_preferences(prefs.clone());
363    }
364    params
365}
366
367/// Parse a coalesced response back into per-task `CreateMessageResult`s.
368///
369/// Expects the rendered message text to be a JSON array of objects
370/// shaped `{ "task_index": <int>, "response": "<string>" }`. The
371/// `task_index` field is required and must match the submission
372/// order. Extra fields are ignored.
373///
374/// Returns one `Result<CreateMessageResult, SamplingError>` per
375/// submission, in the SAME ORDER as `batch`. Per-task entries
376/// missing from the response surface as
377/// `SamplingError::Service(...)` with a `malformed_response`-style
378/// message; the per-call audit row carries the failure.
379fn demux_coalesced(
380    rendered: &CreateMessageResult,
381    batch: &[Submission],
382) -> Result<Vec<Result<CreateMessageResult, SamplingError>>, String> {
383    let text = extract_text_from_result(rendered).map_err(|e| e.to_string())?;
384    let parsed: serde_json::Value = match serde_json::from_str(&text) {
385        Ok(v) => v,
386        Err(e) => {
387            // Tolerate fenced ```json ... ``` blocks the model may
388            // wrap the array in (matches `solo_steward::abstraction`
389            // tolerance).
390            match extract_fenced_json(&text) {
391                Some(inner) => serde_json::from_str(inner)
392                    .map_err(|fe| format!("fenced parse: {fe}"))?,
393                None => return Err(format!("top-level JSON parse: {e}")),
394            }
395        }
396    };
397    let arr = parsed
398        .as_array()
399        .ok_or_else(|| "response root is not a JSON array".to_string())?;
400
401    let mut out: Vec<Result<CreateMessageResult, SamplingError>> =
402        Vec::with_capacity(batch.len());
403    for (idx, _sub) in batch.iter().enumerate() {
404        let entry = arr.iter().find(|e| {
405            e.get("task_index")
406                .and_then(|v| v.as_i64())
407                .map(|i| i as usize == idx)
408                .unwrap_or(false)
409        });
410        match entry {
411            Some(e) => {
412                let response_text = e
413                    .get("response")
414                    .and_then(|v| v.as_str())
415                    .unwrap_or("");
416                out.push(Ok(make_assistant_result(response_text, &rendered.model)));
417            }
418            None => out.push(Err(SamplingError::Service(
419                rmcp::service::ServiceError::McpError(
420                    rmcp::model::ErrorData::internal_error(
421                        format!(
422                            "sampling coordinator: response missing task_index {idx}"
423                        ),
424                        None,
425                    ),
426                ),
427            ))),
428        }
429    }
430    Ok(out)
431}
432
/// Extract the trimmed body of a Markdown code fence in `text`.
///
/// A `` ```json `` fence anywhere in `text` takes priority, and its body
/// is returned. Otherwise the first bare `` ``` `` fence is used. Any
/// alphanumeric info string on its opening line (e.g. `JSON`) is skipped
/// so it is not treated as data. Returns `None` when there is no opening
/// fence, or no closing fence after it.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let body_start = match text.find(JSON_FENCE) {
        Some(pos) => pos + JSON_FENCE.len(),
        None => {
            let after_open = text.find(FENCE)? + FENCE.len();
            let info_len = text[after_open..]
                .find(|c: char| !c.is_ascii_alphanumeric())
                .unwrap_or(text.len() - after_open);
            after_open + info_len
        }
    };
    let body = &text[body_start..];
    let end = body.find(FENCE)?;
    Some(body[..end].trim())
}
440
441fn extract_text_from_result(result: &CreateMessageResult) -> Result<String, &'static str> {
442    if result.message.role != RmcpRole::Assistant {
443        return Err("response role was not Assistant");
444    }
445    let mut out = String::new();
446    for content in result.message.content.iter() {
447        if let SamplingMessageContent::Text(text) = content {
448            if !out.is_empty() {
449                out.push('\n');
450            }
451            out.push_str(&text.text);
452        }
453    }
454    if out.is_empty() {
455        Err("no text content blocks")
456    } else {
457        Ok(out)
458    }
459}
460
461fn make_assistant_result(text: &str, model: &str) -> CreateMessageResult {
462    CreateMessageResult::new(
463        SamplingMessage::assistant_text(text.to_string()),
464        model.to_string(),
465    )
466}
467
468#[async_trait]
469impl SamplingClient for SamplingCoordinator {
470    async fn create_message(
471        &self,
472        params: CreateMessageRequestParams,
473    ) -> Result<CreateMessageResult, SamplingError> {
474        self.submit(params).await
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{FakeMcpClient, FakeResponse};

    fn mk_params(prompt: &str) -> CreateMessageRequestParams {
        CreateMessageRequestParams::new(vec![SamplingMessage::user_text(prompt)], 128)
    }

    /// Render a well-formed coalesced response with `n_tasks` entries,
    /// each `{ "task_index": i, "response": "response-i" }`.
    fn coalesced_response_for(n_tasks: usize) -> String {
        let mut arr = Vec::with_capacity(n_tasks);
        for i in 0..n_tasks {
            arr.push(serde_json::json!({
                "task_index": i,
                "response": format!("response-{i}"),
            }));
        }
        serde_json::to_string(&arr).unwrap()
    }

    /// P4d coalesce window: when N submissions arrive within the
    /// coalesce window, the underlying SamplingClient sees ONE
    /// `create_message` call carrying the coalesced prompt.
    #[tokio::test]
    async fn coalesces_n_concurrent_submissions_into_one_create_message_call() {
        let response = coalesced_response_for(3);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );

        // Submit 3 concurrent requests within the window.
        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("task-C")).await });

        let r1 = h1.await.unwrap().expect("submission 1 ok");
        let r2 = h2.await.unwrap().expect("submission 2 ok");
        let r3 = h3.await.unwrap().expect("submission 3 ok");

        // Inner saw EXACTLY one call.
        let recorded = fake.record_requests();
        assert_eq!(
            recorded.len(),
            1,
            "coordinator must coalesce 3 submissions into 1 inner call"
        );

        // Each submission got its task's text demultiplexed.
        assert_eq!(
            extract_text_from_result(&r1).unwrap(),
            "response-0",
            "task-0 response routed to first submission"
        );
        assert_eq!(extract_text_from_result(&r2).unwrap(), "response-1");
        assert_eq!(extract_text_from_result(&r3).unwrap(), "response-2");
    }

    /// P4d max-batch trigger: when `max_batch` submissions arrive
    /// FASTER than the window, the coordinator flushes immediately
    /// without waiting out the window.
    #[tokio::test]
    async fn flushes_at_max_batch_before_window_expires() {
        let response = coalesced_response_for(2);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        // 5-second window, 2-request max-batch.
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_secs(5),
            2,
        );

        let started = tokio::time::Instant::now();
        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });

        let _ = h1.await.unwrap();
        let _ = h2.await.unwrap();

        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "max_batch must flush before window expires; took {elapsed:?}"
        );
        assert_eq!(fake.record_requests().len(), 1);
    }

    /// P4d single-request pass-through: a lone submission within
    /// the window goes through unchanged (no coalesce wrapping).
    /// Caller's prompt reaches the inner client verbatim.
    #[tokio::test]
    async fn single_submission_passes_through_unwrapped() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(
            "direct-response",
        )));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(50),
            10,
        );

        let result = coord
            .submit(mk_params("lonely-task"))
            .await
            .expect("submission ok");

        // Inner saw the SAME prompt the caller submitted — no
        // "tasks" JSON wrapper.
        let recorded = fake.record_requests();
        assert_eq!(recorded.len(), 1);
        let inner_text = extract_first_user_text(&recorded[0]);
        assert_eq!(
            inner_text, "lonely-task",
            "single-batch path must NOT wrap the prompt"
        );
        // Caller got the inner response verbatim.
        assert_eq!(extract_text_from_result(&result).unwrap(), "direct-response");
    }

    /// P4d window expiry: submissions trickle in slower than
    /// `window`, so each flush carries exactly one task.
    /// Verifies that the window timer is the trigger (not the
    /// max_batch), and that demux works correctly for 1-task
    /// batches.
    #[tokio::test]
    async fn window_expiry_flushes_each_submission_individually() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("r-first")));
        fake.respond_each(vec![
            FakeResponse::text("r-first"),
            FakeResponse::text("r-second"),
        ]);
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(20),
            10,
        );

        let r1 = coord
            .submit(mk_params("first"))
            .await
            .expect("submission 1");
        // Wait past the window so the second submission lands in a
        // fresh batch.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let r2 = coord
            .submit(mk_params("second"))
            .await
            .expect("submission 2");

        assert_eq!(fake.record_requests().len(), 2);
        assert_eq!(extract_text_from_result(&r1).unwrap(), "r-first");
        assert_eq!(extract_text_from_result(&r2).unwrap(), "r-second");
    }

    /// P4d demux fault tolerance: when the LLM omits a task_index,
    /// only that submission gets the malformed-response error; the
    /// others land successfully.
    #[tokio::test]
    async fn demux_propagates_per_request_failures() {
        // Coalesced response with task 0 + task 2 only (task 1 missing).
        let response = serde_json::json!([
            { "task_index": 0, "response": "ok-0" },
            { "task_index": 2, "response": "ok-2" },
        ])
        .to_string();
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );

        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("t0")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("t1")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("t2")).await });

        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();

        assert!(r1.is_ok());
        assert!(r2.is_err(), "missing task_index must surface as error");
        assert!(r3.is_ok());
    }

    /// P4d coalesced-RPC-failure surfaces to EVERY submission. If
    /// the inner peer.create_message call fails, all coalesced
    /// callers see an error (so each per-logical-call audit row
    /// records the failure).
    #[tokio::test]
    async fn coalesced_rpc_failure_surfaces_to_every_submission() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::Error(
            crate::test_support::FakeSamplingError::Transport {
                message: "simulated transport failure".into(),
            },
        )));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );

        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("a")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("b")).await });

        assert!(h1.await.unwrap().is_err());
        assert!(h2.await.unwrap().is_err());
    }

    /// Fence tolerance: a coalesced response wrapped in a
    /// ```` ```json ```` code fence is still demultiplexed per task.
    /// `max_batch = 2` with a long window guarantees both submissions
    /// share one batch.
    #[tokio::test]
    async fn demux_tolerates_json_code_fence() {
        let fenced = format!("```json\n{}\n```", coalesced_response_for(2));
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&fenced)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_secs(5),
            2,
        );

        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });

        let r1 = h1.await.unwrap().expect("fenced submission 1 ok");
        let r2 = h2.await.unwrap().expect("fenced submission 2 ok");

        assert_eq!(fake.record_requests().len(), 1);
        assert_eq!(extract_text_from_result(&r1).unwrap(), "response-0");
        assert_eq!(extract_text_from_result(&r2).unwrap(), "response-1");
    }

    /// Shutdown: once the worker is aborted, `submit` fails fast with
    /// an error instead of hanging, and the inner client is never called.
    #[tokio::test]
    async fn submit_after_shutdown_fails_instead_of_hanging() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("unused")));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(10),
            10,
        );

        coord.clone().shutdown().await;

        let result = coord.submit(mk_params("late")).await;
        assert!(result.is_err(), "submit after shutdown must return an error");
        assert_eq!(fake.record_requests().len(), 0);
    }

    fn extract_first_user_text(params: &CreateMessageRequestParams) -> String {
        for m in &params.messages {
            if m.role == RmcpRole::User {
                for c in m.content.iter() {
                    if let SamplingMessageContent::Text(t) = c {
                        return t.text.clone();
                    }
                }
            }
        }
        String::new()
    }
}