1use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44 CreateMessageRequestParams, CreateMessageResult, ModelHint,
45 ModelPreferences, Role as RmcpRole, SamplingMessage,
46 SamplingMessageContent,
47};
48use rmcp::service::{Peer, RoleServer, ServiceError};
49use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
50use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
51
/// Timeout a [`SamplingLlmClient`] starts with; `with_timeout` replaces it.
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_secs(30);

/// `max_tokens` a [`SamplingLlmClient`] starts with; `with_max_tokens` replaces it.
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
67
/// Failure returned by [`SamplingClient::create_message`].
#[derive(Debug)]
pub enum SamplingError {
    /// Error reported by the rmcp peer while performing the request.
    Service(ServiceError),
    /// Error produced by the crate's `test_support` fake; only compiled for
    /// tests or when the `test-support` feature is enabled.
    #[cfg(any(test, feature = "test-support"))]
    Fake(crate::test_support::FakeSamplingError),
}
81
82impl std::fmt::Display for SamplingError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 Self::Service(e) => write!(f, "{e}"),
86 #[cfg(any(test, feature = "test-support"))]
87 Self::Fake(e) => write!(f, "{e}"),
88 }
89 }
90}
91
// No methods are overridden, so `source()` keeps its default (`None`); the
// wrapped error is only visible through `Display`/`Debug`.
impl std::error::Error for SamplingError {}
93
94impl SamplingError {
95 pub fn classify(&self) -> (&'static str, bool) {
103 match self {
104 Self::Service(_) => ("transport_error", false),
105 #[cfg(any(test, feature = "test-support"))]
106 Self::Fake(e) => match e {
107 crate::test_support::FakeSamplingError::Refused { .. } => {
108 ("client_refused", true)
109 }
110 crate::test_support::FakeSamplingError::Transport { .. } => {
111 ("transport_error", false)
112 }
113 crate::test_support::FakeSamplingError::MalformedResponse {
114 ..
115 } => ("malformed_response", false),
116 },
117 }
118 }
119}
120
/// Something that can carry out one create-message (sampling) request.
///
/// The `Send + Sync` bound lets implementations be shared as
/// `Arc<dyn SamplingClient>`, which is how [`SamplingLlmClient`] stores one.
#[async_trait]
pub trait SamplingClient: Send + Sync {
    /// Sends `params` and returns the result, or a [`SamplingError`] on failure.
    async fn create_message(
        &self,
        params: CreateMessageRequestParams,
    ) -> Result<CreateMessageResult, SamplingError>;
}
135
/// [`SamplingClient`] that forwards requests to an rmcp server-role peer.
pub struct PeerSamplingClient {
    /// Peer handle the requests are sent through.
    peer: Peer<RoleServer>,
}

impl PeerSamplingClient {
    /// Wraps `peer` without performing any I/O.
    pub fn new(peer: Peer<RoleServer>) -> Self {
        Self { peer }
    }
}
148
149#[async_trait]
150impl SamplingClient for PeerSamplingClient {
151 async fn create_message(
152 &self,
153 params: CreateMessageRequestParams,
154 ) -> Result<CreateMessageResult, SamplingError> {
155 self.peer
156 .create_message(params)
157 .await
158 .map_err(SamplingError::Service)
159 }
160}
161
/// `LlmClient` implementation that answers completions via a
/// [`SamplingClient`] and emits one audit event per call.
#[derive(Clone)]
pub struct SamplingLlmClient {
    /// Transport used to send the create-message request.
    sampling_client: Arc<dyn SamplingClient>,
    /// Handle the per-call audit event is emitted through.
    write_handle: WriteHandle,
    /// Copied into each audit event as `principal_subject`; may be absent.
    audit_principal: Option<String>,
    /// `max_tokens` value placed on every request.
    max_tokens: u32,
    /// Upper bound on how long a single request is awaited.
    timeout: Duration,
}
193
194impl SamplingLlmClient {
195 pub fn new(
200 peer: Peer<RoleServer>,
201 write_handle: WriteHandle,
202 audit_principal: Option<String>,
203 ) -> Self {
204 Self::with_sampling_client(
205 Arc::new(PeerSamplingClient::new(peer)),
206 write_handle,
207 audit_principal,
208 )
209 }
210
211 pub fn with_sampling_client(
216 sampling_client: Arc<dyn SamplingClient>,
217 write_handle: WriteHandle,
218 audit_principal: Option<String>,
219 ) -> Self {
220 Self {
221 sampling_client,
222 write_handle,
223 audit_principal,
224 max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
225 timeout: DEFAULT_SAMPLING_TIMEOUT,
226 }
227 }
228
229 pub fn with_max_tokens(mut self, n: u32) -> Self {
231 self.max_tokens = n.max(1);
232 self
233 }
234
235 pub fn with_timeout(mut self, t: Duration) -> Self {
237 self.timeout = t;
238 self
239 }
240
241 fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
246 let mut system_parts: Vec<String> = Vec::new();
250 let mut samp_messages: Vec<SamplingMessage> = Vec::new();
251 for m in messages {
252 match m.role {
253 Role::System => system_parts.push(m.content.clone()),
254 Role::User => {
255 samp_messages.push(SamplingMessage::user_text(&m.content));
256 }
257 Role::Assistant => {
258 samp_messages
259 .push(SamplingMessage::assistant_text(&m.content));
260 }
261 }
262 }
263 let preferences = ModelPreferences::new()
266 .with_hints(vec![ModelHint::new("claude")])
267 .with_intelligence_priority(0.7)
268 .with_speed_priority(0.3)
269 .with_cost_priority(0.4);
270 let mut params =
271 CreateMessageRequestParams::new(samp_messages, self.max_tokens)
272 .with_model_preferences(preferences);
273 if !system_parts.is_empty() {
274 params = params.with_system_prompt(system_parts.join("\n\n"));
275 }
276 params
277 }
278
279 fn audit_event(
284 &self,
285 params: &CreateMessageRequestParams,
286 outcome: SamplingOutcome,
287 ) -> AuditEvent {
288 let prompt_chars: usize = params
289 .messages
290 .iter()
291 .flat_map(|m| m.content.iter())
292 .filter_map(|c| c.as_text().map(|t| t.text.len()))
293 .sum::<usize>()
294 + params
295 .system_prompt
296 .as_ref()
297 .map(|s| s.len())
298 .unwrap_or(0);
299 let input_tokens_est = (prompt_chars / 4) as u64;
303 let model_hint = params
304 .model_preferences
305 .as_ref()
306 .and_then(|p| p.hints.as_ref())
307 .and_then(|h| h.first())
308 .and_then(|h| h.name.clone())
309 .unwrap_or_else(|| "(none)".to_string());
310
311 let mut details = serde_json::Map::new();
312 details.insert(
313 "model_hint".to_string(),
314 serde_json::Value::String(model_hint),
315 );
316 details.insert(
317 "messages_count".to_string(),
318 serde_json::Value::Number(params.messages.len().into()),
319 );
320 details.insert(
321 "max_tokens".to_string(),
322 serde_json::Value::Number(params.max_tokens.into()),
323 );
324 details.insert(
325 "prompt_chars".to_string(),
326 serde_json::Value::Number(prompt_chars.into()),
327 );
328 details.insert(
329 "input_tokens_est".to_string(),
330 serde_json::Value::Number(input_tokens_est.into()),
331 );
332
333 let result = match &outcome {
334 SamplingOutcome::Ok {
335 duration_ms,
336 model,
337 output_chars,
338 } => {
339 let output_tokens_est = (*output_chars / 4) as u64;
340 details.insert(
341 "duration_ms".to_string(),
342 serde_json::Value::Number((*duration_ms).into()),
343 );
344 details.insert(
345 "model".to_string(),
346 serde_json::Value::String(model.clone()),
347 );
348 details.insert(
349 "output_chars".to_string(),
350 serde_json::Value::Number((*output_chars).into()),
351 );
352 details.insert(
353 "output_tokens_est".to_string(),
354 serde_json::Value::Number(output_tokens_est.into()),
355 );
356 AuditResult::Ok
357 }
358 SamplingOutcome::Forbidden {
359 reason,
360 duration_ms,
361 } => {
362 details.insert(
363 "duration_ms".to_string(),
364 serde_json::Value::Number((*duration_ms).into()),
365 );
366 details.insert(
367 "reason".to_string(),
368 serde_json::Value::String(reason.to_string()),
369 );
370 AuditResult::Forbidden
371 }
372 SamplingOutcome::Error {
373 reason,
374 duration_ms,
375 } => {
376 details.insert(
377 "duration_ms".to_string(),
378 serde_json::Value::Number((*duration_ms).into()),
379 );
380 details.insert(
381 "reason".to_string(),
382 serde_json::Value::String(reason.to_string()),
383 );
384 AuditResult::Error
385 }
386 };
387
388 AuditEvent {
389 ts_ms: chrono::Utc::now().timestamp_millis(),
390 principal_subject: self.audit_principal.clone(),
391 operation: AuditOperation::LlmSamplingCall,
392 target_id: None,
393 result,
394 details: Some(serde_json::Value::Object(details)),
395 }
396 }
397}
398
/// Result of one sampling call, in the shape the audit event needs.
enum SamplingOutcome {
    /// The call returned usable text.
    Ok {
        /// Elapsed time of the call in milliseconds.
        duration_ms: u64,
        /// Model name reported in the response.
        model: String,
        /// Size of the extracted text, as reported by `len()`.
        output_chars: usize,
    },
    /// The call was refused; audited with a "forbidden" result.
    Forbidden {
        /// Static category label for the refusal.
        reason: &'static str,
        /// Elapsed time of the call in milliseconds.
        duration_ms: u64,
    },
    /// The call failed (transport error, malformed response or timeout).
    Error {
        /// Static category label for the failure.
        reason: &'static str,
        /// Elapsed time of the call in milliseconds.
        duration_ms: u64,
    },
}
415
416#[async_trait]
417impl LlmClient for SamplingLlmClient {
418 fn name(&self) -> &str {
419 "mcp-sampling"
420 }
421
422 async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
423 let params = self.build_request(messages);
424 let start = Instant::now();
425
426 let rpc = tokio::time::timeout(
430 self.timeout,
431 self.sampling_client.create_message(params.clone()),
432 )
433 .await;
434 let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX))
435 as u64;
436
437 let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) =
438 match rpc {
439 Ok(Ok(result)) => {
440 match extract_text(&result) {
441 Ok(text) => {
442 let output_chars = text.len();
443 let outcome = SamplingOutcome::Ok {
444 duration_ms,
445 model: result.model.clone(),
446 output_chars,
447 };
448 (Ok(Message::assistant(text)), outcome)
449 }
450 Err(reason) => (
451 Err(CoreError::llm(format!(
452 "mcp sampling: malformed response: {reason}"
453 ))),
454 SamplingOutcome::Error {
455 reason: "malformed_response",
456 duration_ms,
457 },
458 ),
459 }
460 }
461 Ok(Err(e)) => {
462 let (category, is_forbidden) = e.classify();
463 let outcome = if is_forbidden {
464 SamplingOutcome::Forbidden {
465 reason: category,
466 duration_ms,
467 }
468 } else {
469 SamplingOutcome::Error {
470 reason: category,
471 duration_ms,
472 }
473 };
474 let err = if is_forbidden {
475 CoreError::forbidden(format!("mcp sampling: {e}"))
476 } else {
477 CoreError::llm(format!("mcp sampling: {e}"))
478 };
479 (Err(err), outcome)
480 }
481 Err(_elapsed) => (
482 Err(CoreError::llm(format!(
483 "mcp sampling: timeout after {}ms",
484 duration_ms
485 ))),
486 SamplingOutcome::Error {
487 reason: "timeout",
488 duration_ms,
489 },
490 ),
491 };
492
493 let event = self.audit_event(¶ms, outcome);
499 if let Err(audit_err) = self.write_handle.emit_llm_sampling_audit(event).await
500 {
501 return Err(CoreError::storage(format!(
506 "mcp sampling: audit emit failed: {audit_err}"
507 )));
508 }
509
510 core_result
511 }
512}
513
514fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
520 if result.message.role != RmcpRole::Assistant {
521 return Err("response role was not Assistant");
522 }
523 let mut out = String::new();
524 for content in result.message.content.iter() {
525 if let SamplingMessageContent::Text(text) = content {
526 if !out.is_empty() {
527 out.push('\n');
528 }
529 out.push_str(&text.text);
530 }
531 }
532 if out.is_empty() {
533 Err("no text content blocks")
534 } else {
535 Ok(out)
536 }
537}
538
539pub fn build_sampling_steward(
568 peer: Peer<RoleServer>,
569 write_handle: WriteHandle,
570 audit_principal: Option<String>,
571 steward_config: solo_steward::StewardConfig,
572 sampling_config: solo_storage::SamplingConfig,
573) -> Arc<solo_steward::Steward> {
574 let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
575 let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
576 inner,
577 std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
578 sampling_config.coalesce_max_requests as usize,
579 );
580 let client = SamplingLlmClient::with_sampling_client(
581 coordinator,
582 write_handle,
583 audit_principal,
584 )
585 .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
586 Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
587}
588
589#[cfg(test)]
590mod tests {
591 use super::*;
592 use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
593 use rmcp::model::CreateMessageResult;
594 use solo_core::TenantId;
595 use solo_storage::{
596 EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder,
597 TenantHandle, TenantRegistry, TenantRegistryParams, init,
598 open_sqlcipher,
599 };
600 use std::path::PathBuf;
601 use std::sync::Arc;
602 use tempfile::TempDir;
603 use zeroize::Zeroizing;
604
/// Passphrase used both to initialise the test store and to derive its key.
const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";

/// Everything a test needs to run a client and inspect the audit table.
struct Harness {
    /// Temp directory backing the store; held so it outlives the test body.
    _tmp: TempDir,
    /// Registry that opened the tenant; unused directly.
    /// NOTE(review): presumably held so the tenant stays open — confirm.
    _registry: Arc<TenantRegistry>,
    /// Default tenant handle; unused directly, held alongside the registry.
    _tenant: Arc<TenantHandle>,
    /// Write handle cloned from the tenant, passed to the client under test.
    write_handle: solo_storage::WriteHandle,
    /// Path of the tenant database, for opening a separate read connection.
    db_path: PathBuf,
    /// Key material needed to open the database.
    key: KeyMaterial,
}
624
625 async fn harness() -> Harness {
626 let tmp = TempDir::new().expect("tempdir");
627 let data_dir = tmp.path().to_path_buf();
628 let _ = init(InitParams {
629 data_dir: data_dir.clone(),
630 passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
631 force: false,
632 embedder: EmbedderConfig {
633 name: "stub".into(),
634 version: "v1".into(),
635 dim: 32,
636 dtype: "f32".into(),
637 },
638 })
639 .expect("init");
640
641 let cfg = solo_storage::SoloConfig::read(
642 &data_dir.join("solo.config.toml"),
643 )
644 .expect("read cfg");
645 let key = KeyMaterial::derive(
646 TEST_PASSPHRASE,
647 &cfg.salt_bytes().expect("salt"),
648 )
649 .expect("derive key");
650
651 let embedder: Arc<dyn solo_core::Embedder> =
652 Arc::new(StubEmbedder::new("stub", "v1", 32));
653 let registry = Arc::new(
654 TenantRegistry::open(TenantRegistryParams {
655 data_dir: data_dir.clone(),
656 key: key.clone(),
657 embedder: embedder.clone(),
658 hnsw_params: HnswParams::default(),
659 steward: None,
660 runtime_handle: Some(tokio::runtime::Handle::current()),
661 steward_factory: None,
662 triples_batch_signal: None,
663 })
664 .expect("open registry"),
665 );
666
667 let tenant_id = TenantId::default_tenant();
668 let tenant = registry
669 .get_or_open(&tenant_id)
670 .await
671 .expect("get_or_open default tenant");
672 let write_handle = tenant.write().clone();
673 let db_path = tenant.db_path().to_path_buf();
674
675 Harness {
676 _tmp: tmp,
677 _registry: registry,
678 _tenant: tenant,
679 write_handle,
680 db_path,
681 key,
682 }
683 }
684
685 fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
689 let conn = open_sqlcipher(db_path, key).expect("open db");
690 conn.query_row(
691 "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
692 rusqlite::params![op],
693 |r| r.get(0),
694 )
695 .expect("count")
696 }
697
698 fn latest_sampling_audit_details(
701 db_path: &std::path::Path,
702 key: &KeyMaterial,
703 ) -> (String, Option<String>, serde_json::Value) {
704 let conn = open_sqlcipher(db_path, key).expect("open db");
705 let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
706 .query_row(
707 "SELECT result, principal_subject, details_json
708 FROM audit_events
709 WHERE operation = 'llm.sampling_call'
710 ORDER BY ts_ms DESC, rowid DESC
711 LIMIT 1",
712 [],
713 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
714 )
715 .expect("query");
716 let details: serde_json::Value =
717 serde_json::from_str(&details_str.expect("details_json present"))
718 .expect("parse details");
719 (result, principal, details)
720 }
721
722 #[tokio::test]
726 async fn sampling_complete_happy_path_returns_text() {
727 let h = harness().await;
728 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
729 let client = SamplingLlmClient::with_sampling_client(
730 fake.clone(),
731 h.write_handle.clone(),
732 Some("alice".into()),
733 );
734 let messages = vec![Message::user("summarise these episodes")];
735 let result = client.complete(&messages).await.expect("ok");
736 assert_eq!(result.role, Role::Assistant);
737 assert_eq!(result.content, "derived theme");
738
739 assert_eq!(
741 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
742 1
743 );
744 let (result_str, principal, details) =
745 latest_sampling_audit_details(&h.db_path, &h.key);
746 assert_eq!(result_str, "ok");
747 assert_eq!(principal.as_deref(), Some("alice"));
748 assert_eq!(details["model_hint"], "claude");
749 assert_eq!(details["model"], "fake-claude");
750 assert_eq!(details["messages_count"], 1);
751 assert_eq!(details["max_tokens"], 512);
752 }
753
754 #[tokio::test]
758 async fn audit_row_omits_raw_prompt_text() {
759 let h = harness().await;
760 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
761 let client = SamplingLlmClient::with_sampling_client(
762 fake,
763 h.write_handle.clone(),
764 None,
765 );
766 let secret = "THE-USER-ID-IS-bobby-1234";
767 let messages = vec![
768 Message::system("you are a friendly assistant"),
769 Message::user(secret),
770 ];
771 client.complete(&messages).await.expect("ok");
772
773 let (_, _, details) =
774 latest_sampling_audit_details(&h.db_path, &h.key);
775 let serialised =
776 serde_json::to_string(&details).expect("serialise details");
777 assert!(
778 !serialised.contains(secret),
779 "audit details must not carry raw prompt content; was: {serialised}"
780 );
781 assert!(
782 !serialised.contains("you are a friendly assistant"),
783 "audit details must not carry system prompt; was: {serialised}"
784 );
785 assert_eq!(details["messages_count"], 1);
787 assert!(details["prompt_chars"].as_u64().unwrap() > 0);
788 }
789
790 #[tokio::test]
793 async fn client_refusal_returns_forbidden_and_audits() {
794 let h = harness().await;
795 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
796 fake.reject_with("user dismissed approval");
797 let client = SamplingLlmClient::with_sampling_client(
798 fake,
799 h.write_handle.clone(),
800 Some("alice".into()),
801 );
802 let err = client
803 .complete(&[Message::user("anything")])
804 .await
805 .unwrap_err();
806 match err {
807 CoreError::Forbidden(_) => {}
808 other => panic!("expected Forbidden, got {other:?}"),
809 }
810 let (result_str, _, details) =
811 latest_sampling_audit_details(&h.db_path, &h.key);
812 assert_eq!(result_str, "forbidden");
813 assert_eq!(details["reason"], "client_refused");
814 }
815
816 #[tokio::test]
824 async fn timeout_returns_error_with_timeout_reason() {
825 let h = harness().await;
826 let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
827 "late",
828 Duration::from_millis(800),
829 )));
830 let client = SamplingLlmClient::with_sampling_client(
831 fake,
832 h.write_handle.clone(),
833 None,
834 )
835 .with_timeout(Duration::from_millis(30));
836 let err = client
837 .complete(&[Message::user("hello")])
838 .await
839 .unwrap_err();
840 match err {
841 CoreError::Llm(msg) => assert!(msg.contains("timeout")),
842 other => panic!("expected Llm, got {other:?}"),
843 }
844 let (result_str, _, details) =
845 latest_sampling_audit_details(&h.db_path, &h.key);
846 assert_eq!(result_str, "error");
847 assert_eq!(details["reason"], "timeout");
848 }
849
850 #[tokio::test]
854 async fn malformed_response_returns_error_with_reason() {
855 let h = harness().await;
856 let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
857 let client = SamplingLlmClient::with_sampling_client(
858 fake,
859 h.write_handle.clone(),
860 None,
861 );
862 let err = client
863 .complete(&[Message::user("hi")])
864 .await
865 .unwrap_err();
866 assert!(matches!(err, CoreError::Llm(_)));
867 let (result_str, _, details) =
868 latest_sampling_audit_details(&h.db_path, &h.key);
869 assert_eq!(result_str, "error");
870 assert_eq!(details["reason"], "malformed_response");
871 }
872
873 #[tokio::test]
876 async fn no_principal_emits_audit_with_null_principal() {
877 let h = harness().await;
878 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
879 let client = SamplingLlmClient::with_sampling_client(
880 fake,
881 h.write_handle.clone(),
882 None,
883 );
884 client.complete(&[Message::user("hi")]).await.expect("ok");
885 let (_, principal, _) =
886 latest_sampling_audit_details(&h.db_path, &h.key);
887 assert_eq!(principal, None);
888 }
889
890 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
895 async fn parallel_completes_serialise_audit_rows() {
896 let h = harness().await;
897 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
898 let client = SamplingLlmClient::with_sampling_client(
899 fake.clone(),
900 h.write_handle.clone(),
901 Some("alice".into()),
902 );
903 let mut futs = Vec::new();
904 for _ in 0..8 {
905 let c = client.clone();
906 futs.push(tokio::spawn(async move {
907 c.complete(&[Message::user("hi")]).await
908 }));
909 }
910 for f in futs {
911 f.await.expect("join").expect("ok");
912 }
913 assert_eq!(
914 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
915 8,
916 "8 parallel calls must land 8 audit rows"
917 );
918
919 assert_eq!(fake.record_requests().len(), 8);
921 }
922
923 #[tokio::test]
927 async fn build_request_splits_system_from_messages() {
928 let h = harness().await;
929 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
930 let client = SamplingLlmClient::with_sampling_client(
931 fake.clone(),
932 h.write_handle.clone(),
933 None,
934 );
935 client
936 .complete(&[
937 Message::system("be terse"),
938 Message::user("question"),
939 Message::assistant("answer"),
940 ])
941 .await
942 .expect("ok");
943 let recorded = fake.record_requests();
944 assert_eq!(recorded.len(), 1);
945 let req = &recorded[0];
946 assert_eq!(
947 req.system_prompt.as_deref(),
948 Some("be terse"),
949 "Role::System must map to system_prompt"
950 );
951 assert_eq!(req.messages.len(), 2);
952 assert_eq!(req.messages[0].role, RmcpRole::User);
954 assert_eq!(req.messages[1].role, RmcpRole::Assistant);
955 }
956
957 #[tokio::test]
960 async fn build_request_includes_claude_model_hint() {
961 let h = harness().await;
962 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
963 let client = SamplingLlmClient::with_sampling_client(
964 fake.clone(),
965 h.write_handle.clone(),
966 None,
967 );
968 client
969 .complete(&[Message::user("hi")])
970 .await
971 .expect("ok");
972 let recorded = fake.record_requests();
973 let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
974 let hint = prefs
975 .hints
976 .as_ref()
977 .and_then(|h| h.first())
978 .and_then(|h| h.name.clone())
979 .expect("hint name");
980 assert_eq!(hint, "claude");
981 }
982
983 #[tokio::test]
986 async fn with_max_tokens_overrides_default() {
987 let h = harness().await;
988 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
989 let client = SamplingLlmClient::with_sampling_client(
990 fake.clone(),
991 h.write_handle.clone(),
992 None,
993 )
994 .with_max_tokens(2048);
995 client
996 .complete(&[Message::user("hi")])
997 .await
998 .expect("ok");
999 let recorded = fake.record_requests();
1000 assert_eq!(recorded[0].max_tokens, 2048);
1001 }
1002
1003 #[tokio::test]
1006 async fn reconfigurable_fake_distinguishes_audit_rows() {
1007 let h = harness().await;
1008 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1009 let client = SamplingLlmClient::with_sampling_client(
1010 fake.clone(),
1011 h.write_handle.clone(),
1012 Some("alice".into()),
1013 );
1014
1015 client.complete(&[Message::user("a")]).await.expect("ok");
1016 fake.reject_with("user said no");
1017 let _ = client.complete(&[Message::user("b")]).await;
1018
1019 let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1020 let mut stmt = conn
1021 .prepare(
1022 "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1023 )
1024 .expect("prepare");
1025 let rows: Vec<String> = stmt
1026 .query_map([], |r| r.get::<_, String>(0))
1027 .expect("query")
1028 .map(|r| r.expect("row"))
1029 .collect();
1030 assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1031 }
1032
/// A single assistant text block is returned unchanged.
#[test]
fn extract_text_pulls_text_from_single_block() {
    let message = SamplingMessage::assistant_text("hello");
    let result = CreateMessageResult::new(message, "fake".into());
    let text = extract_text(&result).unwrap();
    assert_eq!(text, "hello");
}
1042
/// An assistant message with no content blocks is rejected.
#[test]
fn extract_text_rejects_empty_content() {
    let message = SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new());
    let result = CreateMessageResult::new(message, "fake".into());
    assert!(extract_text(&result).is_err());
}
1052
/// A message whose role is `User` is rejected even though it carries text.
#[test]
fn extract_text_rejects_non_assistant_role() {
    let message = SamplingMessage::user_text("hello");
    let result = CreateMessageResult::new(message, "fake".into());
    assert!(extract_text(&result).is_err());
}
1063
/// Each fake error variant maps to its expected `(label, is_forbidden)` pair.
#[test]
fn sampling_error_classify_maps_fake_variants() {
    let cases = [
        (
            FakeSamplingError::Refused { reason: "x".into() },
            "client_refused",
            true,
        ),
        (
            FakeSamplingError::Transport { message: "x".into() },
            "transport_error",
            false,
        ),
        (
            FakeSamplingError::MalformedResponse { message: "x".into() },
            "malformed_response",
            false,
        ),
    ];

    for (fake_err, want_label, want_forbidden) in cases {
        let (label, forbidden) = SamplingError::Fake(fake_err).classify();
        assert_eq!(label, want_label);
        assert_eq!(forbidden, want_forbidden);
    }
}
1090
1091 #[tokio::test]
1117 async fn sampling_llm_client_uses_coordinator_in_production_path() {
1118 let h = harness().await;
1119 let fake: Arc<dyn SamplingClient> =
1120 Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1121 let coord: Arc<dyn SamplingClient> =
1122 super::super::SamplingCoordinator::with_settings(
1123 fake.clone(),
1124 Duration::from_millis(50),
1125 10,
1126 );
1127 let client = SamplingLlmClient::with_sampling_client(
1128 coord,
1129 h.write_handle.clone(),
1130 Some("alice".into()),
1131 );
1132 let result = client
1133 .complete(&[Message::user("test")])
1134 .await
1135 .expect("ok");
1136 assert_eq!(result.role, Role::Assistant);
1137 assert_eq!(result.content, "ok");
1138 assert_eq!(
1141 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1142 1,
1143 "one logical call → one audit row, even through coordinator"
1144 );
1145 }
1146
1147 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1158 async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1159 let response = serde_json::to_string(&(0..5)
1163 .map(|i| serde_json::json!({
1164 "task_index": i,
1165 "response": format!("response-{i}"),
1166 }))
1167 .collect::<Vec<_>>())
1168 .unwrap();
1169
1170 let h = harness().await;
1171 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1172 let coord: Arc<dyn SamplingClient> =
1173 super::super::SamplingCoordinator::with_settings(
1174 fake.clone(),
1175 Duration::from_secs(5),
1177 10,
1178 );
1179 let client = SamplingLlmClient::with_sampling_client(
1180 coord,
1181 h.write_handle.clone(),
1182 Some("alice".into()),
1183 );
1184
1185 let mut futs = Vec::new();
1188 for i in 0..5 {
1189 let c = client.clone();
1190 futs.push(tokio::spawn(async move {
1191 c.complete(&[Message::user(format!("task-{i}"))]).await
1192 }));
1193 }
1194 for f in futs {
1195 f.await.expect("join").expect("ok");
1196 }
1197
1198 assert_eq!(
1200 fake.record_requests().len(),
1201 1,
1202 "5 logical calls within window must coalesce to 1 inner RPC"
1203 );
1204 assert_eq!(
1206 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1207 5,
1208 "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1209 );
1210 }
1211
1212 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1223 async fn coordinator_max_batch_one_acts_as_passthrough() {
1224 let h = harness().await;
1225 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1226 let coord: Arc<dyn SamplingClient> =
1227 super::super::SamplingCoordinator::with_settings(
1228 fake.clone(),
1229 Duration::from_secs(5),
1230 1,
1233 );
1234 let client = SamplingLlmClient::with_sampling_client(
1235 coord,
1236 h.write_handle.clone(),
1237 None,
1238 );
1239 let mut futs = Vec::new();
1240 for _ in 0..3 {
1241 let c = client.clone();
1242 futs.push(tokio::spawn(async move {
1243 c.complete(&[Message::user("hi")]).await
1244 }));
1245 }
1246 for f in futs {
1247 f.await.expect("join").expect("ok");
1248 }
1249 assert_eq!(
1251 fake.record_requests().len(),
1252 3,
1253 "max_batch=1 must pass through every submission as its own RPC"
1254 );
1255 assert_eq!(
1256 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1257 3
1258 );
1259 }
1260}