1use std::sync::Arc;
52use std::time::Duration;
53
54use async_trait::async_trait;
55use rmcp::model::{
56 CreateMessageRequestParams, CreateMessageResult, Role as RmcpRole, SamplingMessage,
57 SamplingMessageContent,
58};
59use tokio::sync::{Mutex, mpsc, oneshot};
60
61use crate::llm::sampling::{SamplingClient, SamplingError};
62
/// Default coalescing window used by [`SamplingCoordinator::new`].
///
/// After the first submission of a batch arrives, the worker keeps collecting
/// further submissions until this much time has passed or the batch is full,
/// so a lone submission is dispatched only once the window has elapsed.
pub const DEFAULT_COALESCE_WINDOW: Duration = Duration::from_secs(5);

/// Default maximum number of submissions folded into one inner call; a batch
/// that reaches this size is flushed without waiting for the window to end.
pub const DEFAULT_COALESCE_MAX_BATCH: usize = 10;
72
73pub struct SamplingCoordinator {
85 tx: mpsc::Sender<Submission>,
89 worker: Mutex<Option<tokio::task::JoinHandle<()>>>,
92}
93
94impl SamplingCoordinator {
95 pub fn new(inner: Arc<dyn SamplingClient>) -> Arc<Self> {
98 Self::with_settings(inner, DEFAULT_COALESCE_WINDOW, DEFAULT_COALESCE_MAX_BATCH)
99 }
100
101 pub fn with_settings(
106 inner: Arc<dyn SamplingClient>,
107 window: Duration,
108 max_batch: usize,
109 ) -> Arc<Self> {
110 let (tx, rx) = mpsc::channel::<Submission>(max_batch.max(1) * 2 + 16);
111 let worker = tokio::spawn(coordinator_worker(rx, inner, window, max_batch.max(1)));
112 Arc::new(Self {
113 tx,
114 worker: Mutex::new(Some(worker)),
115 })
116 }
117
118 pub async fn submit(
122 &self,
123 params: CreateMessageRequestParams,
124 ) -> Result<CreateMessageResult, SamplingError> {
125 let (reply_tx, reply_rx) = oneshot::channel();
126 self.tx
127 .send(Submission {
128 params,
129 reply: reply_tx,
130 })
131 .await
132 .map_err(|_| {
133 SamplingError::Service(rmcp::service::ServiceError::McpError(
134 rmcp::model::ErrorData::internal_error(
135 "sampling coordinator worker is gone (channel closed)",
136 None,
137 ),
138 ))
139 })?;
140 reply_rx.await.map_err(|_| {
141 SamplingError::Service(rmcp::service::ServiceError::McpError(
142 rmcp::model::ErrorData::internal_error(
143 "sampling coordinator worker dropped reply channel",
144 None,
145 ),
146 ))
147 })?
148 }
149
150 pub async fn shutdown(self: Arc<Self>) {
153 let mut guard = self.worker.lock().await;
155 if let Some(join) = guard.take() {
156 join.abort();
157 let _ = join.await;
158 }
159 }
160}
161
162struct Submission {
165 params: CreateMessageRequestParams,
166 reply: oneshot::Sender<Result<CreateMessageResult, SamplingError>>,
167}
168
169async fn coordinator_worker(
173 mut rx: mpsc::Receiver<Submission>,
174 inner: Arc<dyn SamplingClient>,
175 window: Duration,
176 max_batch: usize,
177) {
178 loop {
179 let first = match rx.recv().await {
182 Some(s) => s,
183 None => return,
184 };
185 let mut buffer: Vec<Submission> = vec![first];
186
187 let deadline = tokio::time::Instant::now() + window;
192 while buffer.len() < max_batch {
193 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
194 if remaining.is_zero() {
195 break;
196 }
197 match tokio::time::timeout(remaining, rx.recv()).await {
198 Ok(Some(s)) => buffer.push(s),
199 Ok(None) => {
200 flush_batch(&inner, buffer).await;
202 return;
203 }
204 Err(_) => break,
205 }
206 }
207
208 flush_batch(&inner, buffer).await;
209 }
210}
211
212async fn flush_batch(inner: &Arc<dyn SamplingClient>, batch: Vec<Submission>) {
216 if batch.is_empty() {
217 return;
218 }
219 if batch.len() == 1 {
220 let mut iter = batch.into_iter();
223 let s = iter.next().unwrap();
224 let result = inner.create_message(s.params).await;
225 let _ = s.reply.send(result);
226 return;
227 }
228
229 let coalesced = build_coalesced_request(&batch);
230 let result = inner.create_message(coalesced).await;
231
232 match result {
233 Ok(rendered) => {
234 match demux_coalesced(&rendered, &batch) {
239 Ok(per_task) => {
240 for (sub, task_result) in batch.into_iter().zip(per_task) {
241 let _ = sub.reply.send(task_result);
242 }
243 }
244 Err(parse_err) => {
245 let err_msg = format!(
246 "sampling coordinator: failed to parse coalesced response: {parse_err}"
247 );
248 for sub in batch {
249 let _ = sub.reply.send(Err(SamplingError::Service(
250 rmcp::service::ServiceError::McpError(
251 rmcp::model::ErrorData::internal_error(err_msg.clone(), None),
252 ),
253 )));
254 }
255 }
256 }
257 }
258 Err(e) => {
259 let err_msg = format!("{e}");
264 for sub in batch {
265 let _ = sub.reply.send(Err(SamplingError::Service(
266 rmcp::service::ServiceError::McpError(rmcp::model::ErrorData::internal_error(
267 format!("sampling coordinator: coalesced RPC failed: {err_msg}"),
268 None,
269 )),
270 )));
271 }
272 }
273 }
274}
275
276fn build_coalesced_request(batch: &[Submission]) -> CreateMessageRequestParams {
305 let mut tasks: Vec<serde_json::Value> = Vec::with_capacity(batch.len());
306 let mut system_parts: Vec<String> = vec![
307 "You are a batch task processor. Process EVERY task listed in the \
308 user message and reply with a JSON array of objects where each \
309 object has shape: { \"task_index\": <int starting from 0>, \
310 \"response\": \"<string>\" }. The array MUST have exactly N entries \
311 (one per task) in the SAME ORDER. Do NOT include any prose outside \
312 the JSON."
313 .to_string(),
314 ];
315
316 for (idx, sub) in batch.iter().enumerate() {
317 let mut task_messages: Vec<serde_json::Value> = Vec::new();
318 if let Some(sys) = sub.params.system_prompt.as_ref() {
319 system_parts.push(format!("Task-{idx} sub-system: {sys}"));
322 }
323 for sm in &sub.params.messages {
324 let role_str = match sm.role {
325 RmcpRole::User => "user",
326 RmcpRole::Assistant => "assistant",
327 };
328 let mut text_parts: Vec<String> = Vec::new();
329 for content in sm.content.iter() {
330 if let SamplingMessageContent::Text(t) = content {
331 text_parts.push(t.text.clone());
332 }
333 }
334 task_messages.push(serde_json::json!({
335 "role": role_str,
336 "content": text_parts.join("\n"),
337 }));
338 }
339 tasks.push(serde_json::json!({
340 "task_index": idx,
341 "messages": task_messages,
342 }));
343 }
344
345 let user_payload = serde_json::json!({ "tasks": tasks }).to_string();
346
347 let max_tokens = batch
348 .iter()
349 .map(|s| s.params.max_tokens)
350 .fold(0u32, |acc, n| acc.saturating_add(n));
351
352 let mut params = CreateMessageRequestParams::new(
353 vec![SamplingMessage::user_text(&user_payload)],
354 max_tokens.max(1),
355 );
356 params = params.with_system_prompt(system_parts.join("\n\n"));
357 if let Some(prefs) = batch[0].params.model_preferences.as_ref() {
359 params = params.with_model_preferences(prefs.clone());
360 }
361 params
362}
363
364fn demux_coalesced(
377 rendered: &CreateMessageResult,
378 batch: &[Submission],
379) -> Result<Vec<Result<CreateMessageResult, SamplingError>>, String> {
380 let text = extract_text_from_result(rendered).map_err(|e| e.to_string())?;
381 let parsed: serde_json::Value = match serde_json::from_str(&text) {
382 Ok(v) => v,
383 Err(e) => {
384 match extract_fenced_json(&text) {
388 Some(inner) => {
389 serde_json::from_str(inner).map_err(|fe| format!("fenced parse: {fe}"))?
390 }
391 None => return Err(format!("top-level JSON parse: {e}")),
392 }
393 }
394 };
395 let arr = parsed
396 .as_array()
397 .ok_or_else(|| "response root is not a JSON array".to_string())?;
398
399 let mut out: Vec<Result<CreateMessageResult, SamplingError>> = Vec::with_capacity(batch.len());
400 for (idx, _sub) in batch.iter().enumerate() {
401 let entry = arr.iter().find(|e| {
402 e.get("task_index")
403 .and_then(|v| v.as_i64())
404 .map(|i| i as usize == idx)
405 .unwrap_or(false)
406 });
407 match entry {
408 Some(e) => {
409 let response_text = e.get("response").and_then(|v| v.as_str()).unwrap_or("");
410 out.push(Ok(make_assistant_result(response_text, &rendered.model)));
411 }
412 None => out.push(Err(SamplingError::Service(
413 rmcp::service::ServiceError::McpError(rmcp::model::ErrorData::internal_error(
414 format!("sampling coordinator: response missing task_index {idx}"),
415 None,
416 )),
417 ))),
418 }
419 }
420 Ok(out)
421}
422
/// Returns the trimmed contents of the first json-tagged code fence in `text`.
///
/// Yields `None` when there is no opening json fence or when no closing
/// fence follows it.
fn extract_fenced_json(text: &str) -> Option<&str> {
    const OPEN: &str = "```json";
    const CLOSE: &str = "```";

    let (_, tail) = text.split_once(OPEN)?;
    let (body, _) = tail.split_once(CLOSE)?;
    Some(body.trim())
}
430
431fn extract_text_from_result(result: &CreateMessageResult) -> Result<String, &'static str> {
432 if result.message.role != RmcpRole::Assistant {
433 return Err("response role was not Assistant");
434 }
435 let mut out = String::new();
436 for content in result.message.content.iter() {
437 if let SamplingMessageContent::Text(text) = content {
438 if !out.is_empty() {
439 out.push('\n');
440 }
441 out.push_str(&text.text);
442 }
443 }
444 if out.is_empty() {
445 Err("no text content blocks")
446 } else {
447 Ok(out)
448 }
449}
450
451fn make_assistant_result(text: &str, model: &str) -> CreateMessageResult {
452 CreateMessageResult::new(
453 SamplingMessage::assistant_text(text.to_string()),
454 model.to_string(),
455 )
456}
457
458#[async_trait]
459impl SamplingClient for SamplingCoordinator {
460 async fn create_message(
461 &self,
462 params: CreateMessageRequestParams,
463 ) -> Result<CreateMessageResult, SamplingError> {
464 self.submit(params).await
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{FakeMcpClient, FakeResponse};

    /// Builds a request with a single user text message and a 128-token budget.
    fn mk_params(prompt: &str) -> CreateMessageRequestParams {
        CreateMessageRequestParams::new(vec![SamplingMessage::user_text(prompt)], 128)
    }

    /// Renders the JSON array a well-formed coalesced reply would contain for
    /// `n_tasks` tasks: entry `i` is
    /// `{ "task_index": i, "response": "response-<i>" }`.
    fn coalesced_response_for(n_tasks: usize) -> String {
        let mut arr = Vec::with_capacity(n_tasks);
        for i in 0..n_tasks {
            arr.push(serde_json::json!({
                "task_index": i,
                "response": format!("response-{i}"),
            }));
        }
        serde_json::to_string(&arr).unwrap()
    }

    /// Three submissions made inside one window must produce exactly one
    /// inner call, and each submitter must get the response for its task.
    #[tokio::test]
    async fn coalesces_n_concurrent_submissions_into_one_create_message_call() {
        let response = coalesced_response_for(3);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord =
            SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(100), 10);

        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        // NOTE(review): the per-task assertions below assume the spawned
        // tasks enqueue in spawn order - confirm for the test runtime.
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("task-C")).await });

        let r1 = h1.await.unwrap().expect("submission 1 ok");
        let r2 = h2.await.unwrap().expect("submission 2 ok");
        let r3 = h3.await.unwrap().expect("submission 3 ok");

        let recorded = fake.record_requests();
        assert_eq!(
            recorded.len(),
            1,
            "coordinator must coalesce 3 submissions into 1 inner call"
        );

        assert_eq!(
            extract_text_from_result(&r1).unwrap(),
            "response-0",
            "task-0 response routed to first submission"
        );
        assert_eq!(extract_text_from_result(&r2).unwrap(), "response-1");
        assert_eq!(extract_text_from_result(&r3).unwrap(), "response-2");
    }

    /// With `max_batch == 2` and a 5 s window, two submissions must be
    /// flushed as soon as the batch is full rather than after the window.
    #[tokio::test]
    async fn flushes_at_max_batch_before_window_expires() {
        let response = coalesced_response_for(2);
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord = SamplingCoordinator::with_settings(fake.clone(), Duration::from_secs(5), 2);

        let started = tokio::time::Instant::now();
        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("task-A")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("task-B")).await });

        let _ = h1.await.unwrap();
        let _ = h2.await.unwrap();

        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "max_batch must flush before window expires; took {elapsed:?}"
        );
        assert_eq!(fake.record_requests().len(), 1);
    }

    /// A batch of one is forwarded untouched: the inner client sees the
    /// original prompt and the caller sees the inner response verbatim.
    #[tokio::test]
    async fn single_submission_passes_through_unwrapped() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("direct-response")));
        let coord = SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(50), 10);

        let result = coord
            .submit(mk_params("lonely-task"))
            .await
            .expect("submission ok");

        let recorded = fake.record_requests();
        assert_eq!(recorded.len(), 1);
        let inner_text = extract_first_user_text(&recorded[0]);
        assert_eq!(
            inner_text, "lonely-task",
            "single-batch path must NOT wrap the prompt"
        );
        assert_eq!(
            extract_text_from_result(&result).unwrap(),
            "direct-response"
        );
    }

    /// Two submissions made one after the other, each awaited before the
    /// next, must produce two separate inner calls.
    #[tokio::test]
    async fn window_expiry_flushes_each_submission_individually() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("r-first")));
        fake.respond_each(vec![
            FakeResponse::text("r-first"),
            FakeResponse::text("r-second"),
        ]);
        let coord = SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(20), 10);

        let r1 = coord
            .submit(mk_params("first"))
            .await
            .expect("submission 1");
        tokio::time::sleep(Duration::from_millis(50)).await;
        let r2 = coord
            .submit(mk_params("second"))
            .await
            .expect("submission 2");

        assert_eq!(fake.record_requests().len(), 2);
        assert_eq!(extract_text_from_result(&r1).unwrap(), "r-first");
        assert_eq!(extract_text_from_result(&r2).unwrap(), "r-second");
    }

    /// A coalesced reply that omits one `task_index` must fail only the
    /// submission at that position. The others still succeed.
    #[tokio::test]
    async fn demux_propagates_per_request_failures() {
        // The entry for task 1 is deliberately absent.
        let response = serde_json::json!([
            { "task_index": 0, "response": "ok-0" },
            { "task_index": 2, "response": "ok-2" },
        ])
        .to_string();
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
        let coord =
            SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(100), 10);

        let c1 = coord.clone();
        let c2 = coord.clone();
        let c3 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("t0")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("t1")).await });
        let h3 = tokio::spawn(async move { c3.submit(mk_params("t2")).await });

        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();

        assert!(r1.is_ok());
        assert!(r2.is_err(), "missing task_index must surface as error");
        assert!(r3.is_ok());
    }

    /// When the single coalesced inner call fails, every submission in the
    /// batch must receive an error.
    #[tokio::test]
    async fn coalesced_rpc_failure_surfaces_to_every_submission() {
        let fake = Arc::new(FakeMcpClient::new(FakeResponse::Error(
            crate::test_support::FakeSamplingError::Transport {
                message: "simulated transport failure".into(),
            },
        )));
        let coord =
            SamplingCoordinator::with_settings(fake.clone(), Duration::from_millis(100), 10);

        let c1 = coord.clone();
        let c2 = coord.clone();
        let h1 = tokio::spawn(async move { c1.submit(mk_params("a")).await });
        let h2 = tokio::spawn(async move { c2.submit(mk_params("b")).await });

        assert!(h1.await.unwrap().is_err());
        assert!(h2.await.unwrap().is_err());
    }

    /// Returns the first text block of the first user message that has one,
    /// or an empty string when there is none.
    fn extract_first_user_text(params: &CreateMessageRequestParams) -> String {
        for m in &params.messages {
            if m.role == RmcpRole::User {
                for c in m.content.iter() {
                    if let SamplingMessageContent::Text(t) = c {
                        return t.text.clone();
                    }
                }
            }
        }
        String::new()
    }
}