1use std::sync::Arc;
52use std::time::Duration;
53
54use async_trait::async_trait;
55use rmcp::model::{
56 CreateMessageRequestParams, CreateMessageResult, Role as RmcpRole, SamplingMessage,
57 SamplingMessageContent,
58};
59use tokio::sync::{Mutex, mpsc, oneshot};
60
61use crate::llm::sampling::{SamplingClient, SamplingError};
62
/// Default coalescing window: once the first submission of a batch has been
/// received, the worker keeps collecting further submissions for at most this
/// long before flushing (5 seconds).
pub const DEFAULT_COALESCE_WINDOW: Duration = Duration::from_secs(5);

/// Default upper bound on how many submissions are merged into one inner call.
pub const DEFAULT_COALESCE_MAX_BATCH: usize = 10;
72
/// Batches concurrent sampling requests into fewer calls on an inner
/// [`SamplingClient`].
///
/// Callers hand their requests to a background worker task over a bounded
/// channel. The worker groups requests that arrive within a coalescing window
/// (up to a maximum batch size), issues one inner call per group, and answers
/// each caller through that caller's own oneshot reply channel.
pub struct SamplingCoordinator {
    /// Sending half of the submission queue drained by the worker task.
    tx: mpsc::Sender<Submission>,
    /// Handle to the worker task; becomes `None` once `shutdown` has taken it.
    worker: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
93
94impl SamplingCoordinator {
95 pub fn new(inner: Arc<dyn SamplingClient>) -> Arc<Self> {
98 Self::with_settings(inner, DEFAULT_COALESCE_WINDOW, DEFAULT_COALESCE_MAX_BATCH)
99 }
100
101 pub fn with_settings(
106 inner: Arc<dyn SamplingClient>,
107 window: Duration,
108 max_batch: usize,
109 ) -> Arc<Self> {
110 let (tx, rx) = mpsc::channel::<Submission>(max_batch.max(1) * 2 + 16);
111 let worker = tokio::spawn(coordinator_worker(rx, inner, window, max_batch.max(1)));
112 Arc::new(Self {
113 tx,
114 worker: Mutex::new(Some(worker)),
115 })
116 }
117
118 pub async fn submit(
122 &self,
123 params: CreateMessageRequestParams,
124 ) -> Result<CreateMessageResult, SamplingError> {
125 let (reply_tx, reply_rx) = oneshot::channel();
126 self.tx
127 .send(Submission {
128 params,
129 reply: reply_tx,
130 })
131 .await
132 .map_err(|_| {
133 SamplingError::Service(rmcp::service::ServiceError::McpError(
134 rmcp::model::ErrorData::internal_error(
135 "sampling coordinator worker is gone (channel closed)",
136 None,
137 ),
138 ))
139 })?;
140 reply_rx.await.map_err(|_| {
141 SamplingError::Service(rmcp::service::ServiceError::McpError(
142 rmcp::model::ErrorData::internal_error(
143 "sampling coordinator worker dropped reply channel",
144 None,
145 ),
146 ))
147 })?
148 }
149
150 pub async fn shutdown(self: Arc<Self>) {
153 let mut guard = self.worker.lock().await;
155 if let Some(join) = guard.take() {
156 join.abort();
157 let _ = join.await;
158 }
159 }
160}
161
/// One queued request together with the channel that delivers its result
/// back to the `submit` caller.
struct Submission {
    /// The caller's original request parameters.
    params: CreateMessageRequestParams,
    /// Receives the outcome; the worker ignores send failures (caller gone).
    reply: oneshot::Sender<Result<CreateMessageResult, SamplingError>>,
}
168
169async fn coordinator_worker(
173 mut rx: mpsc::Receiver<Submission>,
174 inner: Arc<dyn SamplingClient>,
175 window: Duration,
176 max_batch: usize,
177) {
178 loop {
179 let first = match rx.recv().await {
182 Some(s) => s,
183 None => return,
184 };
185 let mut buffer: Vec<Submission> = vec![first];
186
187 let deadline = tokio::time::Instant::now() + window;
192 while buffer.len() < max_batch {
193 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
194 if remaining.is_zero() {
195 break;
196 }
197 match tokio::time::timeout(remaining, rx.recv()).await {
198 Ok(Some(s)) => buffer.push(s),
199 Ok(None) => {
200 flush_batch(&inner, buffer).await;
202 return;
203 }
204 Err(_) => break,
205 }
206 }
207
208 flush_batch(&inner, buffer).await;
209 }
210}
211
212async fn flush_batch(inner: &Arc<dyn SamplingClient>, batch: Vec<Submission>) {
216 if batch.is_empty() {
217 return;
218 }
219 if batch.len() == 1 {
220 let mut iter = batch.into_iter();
223 let s = iter.next().unwrap();
224 let result = inner.create_message(s.params).await;
225 let _ = s.reply.send(result);
226 return;
227 }
228
229 let coalesced = build_coalesced_request(&batch);
230 let result = inner.create_message(coalesced).await;
231
232 match result {
233 Ok(rendered) => {
234 match demux_coalesced(&rendered, &batch) {
239 Ok(per_task) => {
240 for (sub, task_result) in batch.into_iter().zip(per_task) {
241 let _ = sub.reply.send(task_result);
242 }
243 }
244 Err(parse_err) => {
245 let err_msg = format!(
246 "sampling coordinator: failed to parse coalesced response: {parse_err}"
247 );
248 for sub in batch {
249 let _ = sub.reply.send(Err(SamplingError::Service(
250 rmcp::service::ServiceError::McpError(
251 rmcp::model::ErrorData::internal_error(err_msg.clone(), None),
252 ),
253 )));
254 }
255 }
256 }
257 }
258 Err(e) => {
259 let err_msg = format!("{e}");
264 for sub in batch {
265 let _ = sub.reply.send(Err(SamplingError::Service(
266 rmcp::service::ServiceError::McpError(
267 rmcp::model::ErrorData::internal_error(
268 format!("sampling coordinator: coalesced RPC failed: {err_msg}"),
269 None,
270 ),
271 ),
272 )));
273 }
274 }
275 }
276}
277
278fn build_coalesced_request(batch: &[Submission]) -> CreateMessageRequestParams {
307 let mut tasks: Vec<serde_json::Value> = Vec::with_capacity(batch.len());
308 let mut system_parts: Vec<String> = vec![
309 "You are a batch task processor. Process EVERY task listed in the \
310 user message and reply with a JSON array of objects where each \
311 object has shape: { \"task_index\": <int starting from 0>, \
312 \"response\": \"<string>\" }. The array MUST have exactly N entries \
313 (one per task) in the SAME ORDER. Do NOT include any prose outside \
314 the JSON."
315 .to_string(),
316 ];
317
318 for (idx, sub) in batch.iter().enumerate() {
319 let mut task_messages: Vec<serde_json::Value> = Vec::new();
320 if let Some(sys) = sub.params.system_prompt.as_ref() {
321 system_parts.push(format!("Task-{idx} sub-system: {sys}"));
324 }
325 for sm in &sub.params.messages {
326 let role_str = match sm.role {
327 RmcpRole::User => "user",
328 RmcpRole::Assistant => "assistant",
329 };
330 let mut text_parts: Vec<String> = Vec::new();
331 for content in sm.content.iter() {
332 if let SamplingMessageContent::Text(t) = content {
333 text_parts.push(t.text.clone());
334 }
335 }
336 task_messages.push(serde_json::json!({
337 "role": role_str,
338 "content": text_parts.join("\n"),
339 }));
340 }
341 tasks.push(serde_json::json!({
342 "task_index": idx,
343 "messages": task_messages,
344 }));
345 }
346
347 let user_payload =
348 serde_json::json!({ "tasks": tasks }).to_string();
349
350 let max_tokens = batch
351 .iter()
352 .map(|s| s.params.max_tokens)
353 .fold(0u32, |acc, n| acc.saturating_add(n));
354
355 let mut params = CreateMessageRequestParams::new(
356 vec![SamplingMessage::user_text(&user_payload)],
357 max_tokens.max(1),
358 );
359 params = params.with_system_prompt(system_parts.join("\n\n"));
360 if let Some(prefs) = batch[0].params.model_preferences.as_ref() {
362 params = params.with_model_preferences(prefs.clone());
363 }
364 params
365}
366
367fn demux_coalesced(
380 rendered: &CreateMessageResult,
381 batch: &[Submission],
382) -> Result<Vec<Result<CreateMessageResult, SamplingError>>, String> {
383 let text = extract_text_from_result(rendered).map_err(|e| e.to_string())?;
384 let parsed: serde_json::Value = match serde_json::from_str(&text) {
385 Ok(v) => v,
386 Err(e) => {
387 match extract_fenced_json(&text) {
391 Some(inner) => serde_json::from_str(inner)
392 .map_err(|fe| format!("fenced parse: {fe}"))?,
393 None => return Err(format!("top-level JSON parse: {e}")),
394 }
395 }
396 };
397 let arr = parsed
398 .as_array()
399 .ok_or_else(|| "response root is not a JSON array".to_string())?;
400
401 let mut out: Vec<Result<CreateMessageResult, SamplingError>> =
402 Vec::with_capacity(batch.len());
403 for (idx, _sub) in batch.iter().enumerate() {
404 let entry = arr.iter().find(|e| {
405 e.get("task_index")
406 .and_then(|v| v.as_i64())
407 .map(|i| i as usize == idx)
408 .unwrap_or(false)
409 });
410 match entry {
411 Some(e) => {
412 let response_text = e
413 .get("response")
414 .and_then(|v| v.as_str())
415 .unwrap_or("");
416 out.push(Ok(make_assistant_result(response_text, &rendered.model)));
417 }
418 None => out.push(Err(SamplingError::Service(
419 rmcp::service::ServiceError::McpError(
420 rmcp::model::ErrorData::internal_error(
421 format!(
422 "sampling coordinator: response missing task_index {idx}"
423 ),
424 None,
425 ),
426 ),
427 ))),
428 }
429 }
430 Ok(out)
431}
432
/// Finds a Markdown code fence in `text` and returns its trimmed contents.
///
/// A "```json" fence is preferred, and when one exists the result is exactly
/// what it was before plain-fence support was added. Otherwise the first
/// "```" fence is used, because models often omit or vary the language tag.
/// In that case an info string made only of ASCII alphanumerics (e.g. `JSON`)
/// on the opening line is skipped.
///
/// Returns `None` when there is no opening fence or no closing fence after it.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let body = if let Some(start) = text.find(JSON_FENCE) {
        &text[start + JSON_FENCE.len()..]
    } else {
        let start = text.find(FENCE)?;
        let after = &text[start + FENCE.len()..];
        match after.split_once('\n') {
            // Opening line holds only a tag (or nothing): content starts below it.
            Some((tag, rest)) if tag.trim().chars().all(|c| c.is_ascii_alphanumeric()) => rest,
            _ => after,
        }
    };
    let end = body.find(FENCE)?;
    Some(body[..end].trim())
}
440
441fn extract_text_from_result(result: &CreateMessageResult) -> Result<String, &'static str> {
442 if result.message.role != RmcpRole::Assistant {
443 return Err("response role was not Assistant");
444 }
445 let mut out = String::new();
446 for content in result.message.content.iter() {
447 if let SamplingMessageContent::Text(text) = content {
448 if !out.is_empty() {
449 out.push('\n');
450 }
451 out.push_str(&text.text);
452 }
453 }
454 if out.is_empty() {
455 Err("no text content blocks")
456 } else {
457 Ok(out)
458 }
459}
460
461fn make_assistant_result(text: &str, model: &str) -> CreateMessageResult {
462 CreateMessageResult::new(
463 SamplingMessage::assistant_text(text.to_string()),
464 model.to_string(),
465 )
466}
467
468#[async_trait]
469impl SamplingClient for SamplingCoordinator {
470 async fn create_message(
471 &self,
472 params: CreateMessageRequestParams,
473 ) -> Result<CreateMessageResult, SamplingError> {
474 self.submit(params).await
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{FakeMcpClient, FakeResponse};

    /// Minimal single-user-message request with a fixed token budget.
    fn mk_params(prompt: &str) -> CreateMessageRequestParams {
        CreateMessageRequestParams::new(vec![SamplingMessage::user_text(prompt)], 128)
    }

    /// JSON array answering `n_tasks` coalesced tasks with `response-<i>`.
    fn coalesced_response_for(n_tasks: usize) -> String {
        let mut arr = Vec::with_capacity(n_tasks);
        for i in 0..n_tasks {
            arr.push(serde_json::json!({
                "task_index": i,
                "response": format!("response-{i}"),
            }));
        }
        serde_json::to_string(&arr).unwrap()
    }

    #[tokio::test]
    async fn coalesces_n_concurrent_submissions_into_one_create_message_call() {
        let response = coalesced_response_for(3);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );

        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("task-C")).await });

        let r1 = h1.await.unwrap().expect("submission 1 ok");
        let r2 = h2.await.unwrap().expect("submission 2 ok");
        let r3 = h3.await.unwrap().expect("submission 3 ok");

        let recorded = fake.record_requests();
        assert_eq!(
            recorded.len(),
            1,
            "coordinator must coalesce 3 submissions into 1 inner call"
        );

        assert_eq!(
            extract_text_from_result(&r1).unwrap(),
            "response-0",
            "task-0 response routed to first submission"
        );
        assert_eq!(extract_text_from_result(&r2).unwrap(), "response-1");
        assert_eq!(extract_text_from_result(&r3).unwrap(), "response-2");
    }

    #[tokio::test]
    async fn flushes_at_max_batch_before_window_expires() {
        let response = coalesced_response_for(2);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_secs(5),
            2,
        );

        let started = tokio::time::Instant::now();
        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });

        let _ = h1.await.unwrap();
        let _ = h2.await.unwrap();

        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "max_batch must flush before window expires; took {elapsed:?}"
        );
        assert_eq!(fake.record_requests().len(), 1);
    }

    #[tokio::test]
    async fn single_submission_passes_through_unwrapped() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(
            "direct-response",
        )));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(50),
            10,
        );

        let result = coord
            .submit(mk_params("lonely-task"))
            .await
            .expect("submission ok");

        let recorded = fake.record_requests();
        assert_eq!(recorded.len(), 1);
        let inner_text = extract_first_user_text(&recorded[0]);
        assert_eq!(
            inner_text, "lonely-task",
            "single-batch path must NOT wrap the prompt"
        );
        assert_eq!(extract_text_from_result(&result).unwrap(), "direct-response");
    }

    #[tokio::test]
    async fn window_expiry_flushes_each_submission_individually() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("r-first")));
        fake.respond_each(vec![
            FakeResponse::text("r-first"),
            FakeResponse::text("r-second"),
        ]);
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(20),
            10,
        );

        let r1 = coord
            .submit(mk_params("first"))
            .await
            .expect("submission 1");
        tokio::time::sleep(Duration::from_millis(50)).await;
        let r2 = coord
            .submit(mk_params("second"))
            .await
            .expect("submission 2");

        assert_eq!(fake.record_requests().len(), 2);
        assert_eq!(extract_text_from_result(&r1).unwrap(), "r-first");
        assert_eq!(extract_text_from_result(&r2).unwrap(), "r-second");
    }

    #[tokio::test]
    async fn demux_propagates_per_request_failures() {
        let response = serde_json::json!([
            { "task_index": 0, "response": "ok-0" },
            { "task_index": 2, "response": "ok-2" },
        ])
        .to_string();
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );

        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("t0")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("t1")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("t2")).await });

        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();

        assert!(r1.is_ok());
        assert!(r2.is_err(), "missing task_index must surface as error");
        assert!(r3.is_ok());
    }

    #[tokio::test]
    async fn coalesced_rpc_failure_surfaces_to_every_submission() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::Error(
            crate::test_support::FakeSamplingError::Transport {
                message: "simulated transport failure".into(),
            },
        )));
        let coord = SamplingCoordinator::with_settings(
            fake.clone(),
            Duration::from_millis(100),
            10,
        );

        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("a")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("b")).await });

        assert!(h1.await.unwrap().is_err());
        assert!(h2.await.unwrap().is_err());
    }

    /// Text of the first text block of the first user message, or "" if none.
    fn extract_first_user_text(params: &CreateMessageRequestParams) -> String {
        // `&params` was previously mis-encoded as `¶ms`, which does not compile.
        for m in &params.messages {
            if m.role == RmcpRole::User {
                for c in m.content.iter() {
                    if let SamplingMessageContent::Text(t) = c {
                        return t.text.clone();
                    }
                }
            }
        }
        String::new()
    }
}