1use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42use async_trait::async_trait;
43use rmcp::model::{
44 CreateMessageRequestParams, CreateMessageResult, ModelHint,
45 ModelPreferences, Role as RmcpRole, SamplingMessage,
46 SamplingMessageContent,
47};
48use rmcp::service::{Peer, RoleServer, ServiceError};
49use solo_core::{Error as CoreError, LlmClient, Message, Result as CoreResult, Role};
50use solo_storage::{AuditEvent, AuditOperation, AuditResult, WriteHandle};
51
/// Time budget applied to one sampling round-trip unless the caller
/// overrides it via `SamplingLlmClient::with_timeout`: 30 seconds.
pub const DEFAULT_SAMPLING_TIMEOUT: Duration = Duration::from_secs(30);

/// `max_tokens` value placed on each sampling request unless the caller
/// overrides it via `SamplingLlmClient::with_max_tokens`.
const DEFAULT_SAMPLING_MAX_TOKENS: u32 = 512;
67
68#[derive(Debug)]
73pub enum SamplingError {
74 Service(ServiceError),
76 #[cfg(any(test, feature = "test-support"))]
79 Fake(crate::test_support::FakeSamplingError),
80}
81
82impl std::fmt::Display for SamplingError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 Self::Service(e) => write!(f, "{e}"),
86 #[cfg(any(test, feature = "test-support"))]
87 Self::Fake(e) => write!(f, "{e}"),
88 }
89 }
90}
91
92impl std::error::Error for SamplingError {}
93
94impl SamplingError {
95 pub fn classify(&self) -> (&'static str, bool) {
103 match self {
104 Self::Service(_) => ("transport_error", false),
105 #[cfg(any(test, feature = "test-support"))]
106 Self::Fake(e) => match e {
107 crate::test_support::FakeSamplingError::Refused { .. } => {
108 ("client_refused", true)
109 }
110 crate::test_support::FakeSamplingError::Transport { .. } => {
111 ("transport_error", false)
112 }
113 crate::test_support::FakeSamplingError::MalformedResponse {
114 ..
115 } => ("malformed_response", false),
116 },
117 }
118 }
119}
120
121#[async_trait]
129pub trait SamplingClient: Send + Sync {
130 async fn create_message(
131 &self,
132 params: CreateMessageRequestParams,
133 ) -> Result<CreateMessageResult, SamplingError>;
134}
135
136pub struct PeerSamplingClient {
140 peer: Peer<RoleServer>,
141}
142
143impl PeerSamplingClient {
144 pub fn new(peer: Peer<RoleServer>) -> Self {
145 Self { peer }
146 }
147}
148
149#[async_trait]
150impl SamplingClient for PeerSamplingClient {
151 async fn create_message(
152 &self,
153 params: CreateMessageRequestParams,
154 ) -> Result<CreateMessageResult, SamplingError> {
155 self.peer
156 .create_message(params)
157 .await
158 .map_err(SamplingError::Service)
159 }
160}
161
162#[derive(Clone)]
173pub struct SamplingLlmClient {
174 sampling_client: Arc<dyn SamplingClient>,
176 write_handle: WriteHandle,
181 audit_principal: Option<String>,
185 max_tokens: u32,
189 timeout: Duration,
192}
193
194impl SamplingLlmClient {
195 pub fn new(
200 peer: Peer<RoleServer>,
201 write_handle: WriteHandle,
202 audit_principal: Option<String>,
203 ) -> Self {
204 Self::with_sampling_client(
205 Arc::new(PeerSamplingClient::new(peer)),
206 write_handle,
207 audit_principal,
208 )
209 }
210
211 pub fn with_sampling_client(
216 sampling_client: Arc<dyn SamplingClient>,
217 write_handle: WriteHandle,
218 audit_principal: Option<String>,
219 ) -> Self {
220 Self {
221 sampling_client,
222 write_handle,
223 audit_principal,
224 max_tokens: DEFAULT_SAMPLING_MAX_TOKENS,
225 timeout: DEFAULT_SAMPLING_TIMEOUT,
226 }
227 }
228
229 pub fn with_max_tokens(mut self, n: u32) -> Self {
231 self.max_tokens = n.max(1);
232 self
233 }
234
235 pub fn with_timeout(mut self, t: Duration) -> Self {
237 self.timeout = t;
238 self
239 }
240
241 fn build_request(&self, messages: &[Message]) -> CreateMessageRequestParams {
246 let mut system_parts: Vec<String> = Vec::new();
250 let mut samp_messages: Vec<SamplingMessage> = Vec::new();
251 for m in messages {
252 match m.role {
253 Role::System => system_parts.push(m.content.clone()),
254 Role::User => {
255 samp_messages.push(SamplingMessage::user_text(&m.content));
256 }
257 Role::Assistant => {
258 samp_messages
259 .push(SamplingMessage::assistant_text(&m.content));
260 }
261 }
262 }
263 let preferences = ModelPreferences::new()
266 .with_hints(vec![ModelHint::new("claude")])
267 .with_intelligence_priority(0.7)
268 .with_speed_priority(0.3)
269 .with_cost_priority(0.4);
270 let mut params =
271 CreateMessageRequestParams::new(samp_messages, self.max_tokens)
272 .with_model_preferences(preferences);
273 if !system_parts.is_empty() {
274 params = params.with_system_prompt(system_parts.join("\n\n"));
275 }
276 params
277 }
278
279 fn audit_event(
296 &self,
297 params: &CreateMessageRequestParams,
298 outcome: SamplingOutcome,
299 ) -> AuditEvent {
300 let raw_prompt_chars: usize = params
301 .messages
302 .iter()
303 .flat_map(|m| m.content.iter())
304 .filter_map(|c| c.as_text().map(|t| t.text.len()))
305 .sum::<usize>()
306 + params
307 .system_prompt
308 .as_ref()
309 .map(|s| s.len())
310 .unwrap_or(0);
311 let prompt_chars = next_pow2_bucket(raw_prompt_chars);
314 let input_tokens_est = next_pow2_bucket(raw_prompt_chars / 4) as u64;
319 let model_hint = params
320 .model_preferences
321 .as_ref()
322 .and_then(|p| p.hints.as_ref())
323 .and_then(|h| h.first())
324 .and_then(|h| h.name.clone())
325 .unwrap_or_else(|| "(none)".to_string());
326
327 let mut details = serde_json::Map::new();
328 details.insert(
329 "model_hint".to_string(),
330 serde_json::Value::String(model_hint),
331 );
332 details.insert(
333 "messages_count".to_string(),
334 serde_json::Value::Number(params.messages.len().into()),
335 );
336 details.insert(
337 "max_tokens".to_string(),
338 serde_json::Value::Number(params.max_tokens.into()),
339 );
340 details.insert(
341 "prompt_chars".to_string(),
342 serde_json::Value::Number(prompt_chars.into()),
343 );
344 details.insert(
345 "input_tokens_est".to_string(),
346 serde_json::Value::Number(input_tokens_est.into()),
347 );
348
349 let result = match &outcome {
350 SamplingOutcome::Ok {
351 duration_ms,
352 model,
353 output_chars,
354 } => {
355 let bucketed_output_chars = next_pow2_bucket(*output_chars);
362 let output_tokens_est = next_pow2_bucket(*output_chars / 4) as u64;
363 details.insert(
364 "duration_ms".to_string(),
365 serde_json::Value::Number((*duration_ms).into()),
366 );
367 details.insert(
368 "model".to_string(),
369 serde_json::Value::String(model.clone()),
370 );
371 details.insert(
372 "output_chars".to_string(),
373 serde_json::Value::Number(bucketed_output_chars.into()),
374 );
375 details.insert(
376 "output_tokens_est".to_string(),
377 serde_json::Value::Number(output_tokens_est.into()),
378 );
379 AuditResult::Ok
380 }
381 SamplingOutcome::Forbidden {
382 reason,
383 duration_ms,
384 } => {
385 details.insert(
386 "duration_ms".to_string(),
387 serde_json::Value::Number((*duration_ms).into()),
388 );
389 details.insert(
390 "reason".to_string(),
391 serde_json::Value::String(reason.to_string()),
392 );
393 AuditResult::Forbidden
394 }
395 SamplingOutcome::Error {
396 reason,
397 duration_ms,
398 } => {
399 details.insert(
400 "duration_ms".to_string(),
401 serde_json::Value::Number((*duration_ms).into()),
402 );
403 details.insert(
404 "reason".to_string(),
405 serde_json::Value::String(reason.to_string()),
406 );
407 AuditResult::Error
408 }
409 };
410
411 AuditEvent {
412 ts_ms: chrono::Utc::now().timestamp_millis(),
413 principal_subject: self.audit_principal.clone(),
414 operation: AuditOperation::LlmSamplingCall,
415 target_id: None,
416 result,
417 details: Some(serde_json::Value::Object(details)),
418 }
419 }
420}
421
/// How one sampling call ended, with the data `audit_event` needs to
/// describe it.
enum SamplingOutcome {
    /// The call returned usable text.
    Ok {
        /// Wall-clock duration of the call in milliseconds.
        duration_ms: u64,
        /// Model name reported in the response.
        model: String,
        /// `len()` of the extracted response text, before bucketing.
        output_chars: usize,
    },
    /// The call failed and the error was classified as a refusal.
    Forbidden {
        /// Static category label written to the audit row.
        reason: &'static str,
        /// Wall-clock duration of the call in milliseconds.
        duration_ms: u64,
    },
    /// The call failed for any other reason, including timeouts and
    /// malformed responses.
    Error {
        /// Static category label written to the audit row.
        reason: &'static str,
        /// Wall-clock duration of the call in milliseconds.
        duration_ms: u64,
    },
}
438
439#[async_trait]
440impl LlmClient for SamplingLlmClient {
441 fn name(&self) -> &str {
442 "mcp-sampling"
443 }
444
445 async fn complete(&self, messages: &[Message]) -> CoreResult<Message> {
446 let params = self.build_request(messages);
447 let start = Instant::now();
448
449 let rpc = tokio::time::timeout(
453 self.timeout,
454 self.sampling_client.create_message(params.clone()),
455 )
456 .await;
457 let duration_ms = start.elapsed().as_millis().min(u128::from(u64::MAX))
458 as u64;
459
460 let (core_result, outcome): (CoreResult<Message>, SamplingOutcome) =
461 match rpc {
462 Ok(Ok(result)) => {
463 match extract_text(&result) {
464 Ok(text) => {
465 let output_chars = text.len();
466 let outcome = SamplingOutcome::Ok {
467 duration_ms,
468 model: result.model.clone(),
469 output_chars,
470 };
471 (Ok(Message::assistant(text)), outcome)
472 }
473 Err(reason) => (
474 Err(CoreError::llm(format!(
475 "mcp sampling: malformed response: {reason}"
476 ))),
477 SamplingOutcome::Error {
478 reason: "malformed_response",
479 duration_ms,
480 },
481 ),
482 }
483 }
484 Ok(Err(e)) => {
485 let (category, is_forbidden) = e.classify();
486 let outcome = if is_forbidden {
487 SamplingOutcome::Forbidden {
488 reason: category,
489 duration_ms,
490 }
491 } else {
492 SamplingOutcome::Error {
493 reason: category,
494 duration_ms,
495 }
496 };
497 let err = if is_forbidden {
498 CoreError::forbidden(format!("mcp sampling: {e}"))
499 } else {
500 CoreError::llm(format!("mcp sampling: {e}"))
501 };
502 (Err(err), outcome)
503 }
504 Err(_elapsed) => (
505 Err(CoreError::llm(format!(
506 "mcp sampling: timeout after {}ms",
507 duration_ms
508 ))),
509 SamplingOutcome::Error {
510 reason: "timeout",
511 duration_ms,
512 },
513 ),
514 };
515
516 let event = self.audit_event(¶ms, outcome);
541 match (
542 core_result,
543 self.write_handle.emit_llm_sampling_audit(event).await,
544 ) {
545 (Ok(text), Ok(())) => Ok(text),
546 (Ok(_text), Err(audit_err)) => {
547 Err(CoreError::storage(format!(
552 "mcp sampling: audit emit failed: {audit_err}"
553 )))
554 }
555 (Err(core_err), Ok(())) => Err(core_err),
556 (Err(core_err), Err(audit_err)) => {
557 tracing::error!(
561 audit_error = %audit_err,
562 core_error = %core_err,
563 "mcp sampling: audit emit failed alongside core \
564 error; surfacing core error to caller"
565 );
566 Err(core_err)
567 }
568 }
569 }
570}
571
/// Rounds `n` up to the nearest power of two so that only a coarse size
/// bucket, not the exact value, is exposed.
///
/// * `0` stays `0` (a bare `next_power_of_two` would turn it into `1`).
/// * Exact powers of two are returned unchanged.
/// * Values above the largest power of two representable in `usize`
///   saturate to `usize::MAX`. Plain `next_power_of_two` would panic in
///   debug builds and return `0` in release builds for those inputs.
fn next_pow2_bucket(n: usize) -> usize {
    match n {
        0 => 0,
        _ => n.checked_next_power_of_two().unwrap_or(usize::MAX),
    }
}
596
597fn extract_text(result: &CreateMessageResult) -> Result<String, &'static str> {
603 if result.message.role != RmcpRole::Assistant {
604 return Err("response role was not Assistant");
605 }
606 let mut out = String::new();
607 for content in result.message.content.iter() {
608 if let SamplingMessageContent::Text(text) = content {
609 if !out.is_empty() {
610 out.push('\n');
611 }
612 out.push_str(&text.text);
613 }
614 }
615 if out.is_empty() {
616 Err("no text content blocks")
617 } else {
618 Ok(out)
619 }
620}
621
622pub fn build_sampling_steward(
651 peer: Peer<RoleServer>,
652 write_handle: WriteHandle,
653 audit_principal: Option<String>,
654 steward_config: solo_steward::StewardConfig,
655 sampling_config: solo_storage::SamplingConfig,
656) -> Arc<solo_steward::Steward> {
657 let inner: Arc<dyn SamplingClient> = Arc::new(PeerSamplingClient::new(peer));
658 let coordinator: Arc<dyn SamplingClient> = super::SamplingCoordinator::with_settings(
659 inner,
660 std::time::Duration::from_millis(sampling_config.coalesce_window_ms),
661 sampling_config.coalesce_max_requests as usize,
662 );
663 let client = SamplingLlmClient::with_sampling_client(
664 coordinator,
665 write_handle,
666 audit_principal,
667 )
668 .with_max_tokens(steward_config.abstraction_max_tokens.min(65_536) as u32);
669 Arc::new(solo_steward::Steward::new(Arc::new(client), steward_config))
670}
671
672#[cfg(test)]
673mod tests {
674 use super::*;
675 use crate::test_support::{FakeMcpClient, FakeResponse, FakeSamplingError};
676 use rmcp::model::CreateMessageResult;
677 use solo_core::TenantId;
678 use solo_storage::{
679 EmbedderConfig, HnswParams, InitParams, KeyMaterial, StubEmbedder,
680 TenantHandle, TenantRegistry, TenantRegistryParams, init,
681 open_sqlcipher,
682 };
683 use std::path::PathBuf;
684 use std::sync::Arc;
685 use tempfile::TempDir;
686 use zeroize::Zeroizing;
687
688 const TEST_PASSPHRASE: &str = "v0.9.0-p2-sampling-tests";
689
690 struct Harness {
700 _tmp: TempDir,
701 _registry: Arc<TenantRegistry>,
702 _tenant: Arc<TenantHandle>,
703 write_handle: solo_storage::WriteHandle,
704 db_path: PathBuf,
705 key: KeyMaterial,
706 }
707
708 async fn harness() -> Harness {
709 let tmp = TempDir::new().expect("tempdir");
710 let data_dir = tmp.path().to_path_buf();
711 let _ = init(InitParams {
712 data_dir: data_dir.clone(),
713 passphrase: Zeroizing::new(TEST_PASSPHRASE.into()),
714 force: false,
715 embedder: EmbedderConfig {
716 name: "stub".into(),
717 version: "v1".into(),
718 dim: 32,
719 dtype: "f32".into(),
720 },
721 })
722 .expect("init");
723
724 let cfg = solo_storage::SoloConfig::read(
725 &data_dir.join("solo.config.toml"),
726 )
727 .expect("read cfg");
728 let key = KeyMaterial::derive(
729 TEST_PASSPHRASE,
730 &cfg.salt_bytes().expect("salt"),
731 )
732 .expect("derive key");
733
734 let embedder: Arc<dyn solo_core::Embedder> =
735 Arc::new(StubEmbedder::new("stub", "v1", 32));
736 let registry = Arc::new(
737 TenantRegistry::open(TenantRegistryParams {
738 data_dir: data_dir.clone(),
739 key: key.clone(),
740 embedder: embedder.clone(),
741 hnsw_params: HnswParams::default(),
742 steward: None,
743 runtime_handle: Some(tokio::runtime::Handle::current()),
744 steward_factory: None,
745 triples_batch_signal: None,
746 })
747 .expect("open registry"),
748 );
749
750 let tenant_id = TenantId::default_tenant();
751 let tenant = registry
752 .get_or_open(&tenant_id)
753 .await
754 .expect("get_or_open default tenant");
755 let write_handle = tenant.write().clone();
756 let db_path = tenant.db_path().to_path_buf();
757
758 Harness {
759 _tmp: tmp,
760 _registry: registry,
761 _tenant: tenant,
762 write_handle,
763 db_path,
764 key,
765 }
766 }
767
768 fn count_audit_rows(db_path: &std::path::Path, key: &KeyMaterial, op: &str) -> i64 {
772 let conn = open_sqlcipher(db_path, key).expect("open db");
773 conn.query_row(
774 "SELECT COUNT(*) FROM audit_events WHERE operation = ?",
775 rusqlite::params![op],
776 |r| r.get(0),
777 )
778 .expect("count")
779 }
780
781 fn latest_sampling_audit_details(
784 db_path: &std::path::Path,
785 key: &KeyMaterial,
786 ) -> (String, Option<String>, serde_json::Value) {
787 let conn = open_sqlcipher(db_path, key).expect("open db");
788 let (result, principal, details_str): (String, Option<String>, Option<String>) = conn
789 .query_row(
790 "SELECT result, principal_subject, details_json
791 FROM audit_events
792 WHERE operation = 'llm.sampling_call'
793 ORDER BY ts_ms DESC, rowid DESC
794 LIMIT 1",
795 [],
796 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
797 )
798 .expect("query");
799 let details: serde_json::Value =
800 serde_json::from_str(&details_str.expect("details_json present"))
801 .expect("parse details");
802 (result, principal, details)
803 }
804
805 #[tokio::test]
809 async fn sampling_complete_happy_path_returns_text() {
810 let h = harness().await;
811 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("derived theme")));
812 let client = SamplingLlmClient::with_sampling_client(
813 fake.clone(),
814 h.write_handle.clone(),
815 Some("alice".into()),
816 );
817 let messages = vec![Message::user("summarise these episodes")];
818 let result = client.complete(&messages).await.expect("ok");
819 assert_eq!(result.role, Role::Assistant);
820 assert_eq!(result.content, "derived theme");
821
822 assert_eq!(
824 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
825 1
826 );
827 let (result_str, principal, details) =
828 latest_sampling_audit_details(&h.db_path, &h.key);
829 assert_eq!(result_str, "ok");
830 assert_eq!(principal.as_deref(), Some("alice"));
831 assert_eq!(details["model_hint"], "claude");
832 assert_eq!(details["model"], "fake-claude");
833 assert_eq!(details["messages_count"], 1);
834 assert_eq!(details["max_tokens"], 512);
835 }
836
837 #[tokio::test]
841 async fn audit_row_omits_raw_prompt_text() {
842 let h = harness().await;
843 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
844 let client = SamplingLlmClient::with_sampling_client(
845 fake,
846 h.write_handle.clone(),
847 None,
848 );
849 let secret = "THE-USER-ID-IS-bobby-1234";
850 let messages = vec![
851 Message::system("you are a friendly assistant"),
852 Message::user(secret),
853 ];
854 client.complete(&messages).await.expect("ok");
855
856 let (_, _, details) =
857 latest_sampling_audit_details(&h.db_path, &h.key);
858 let serialised =
859 serde_json::to_string(&details).expect("serialise details");
860 assert!(
861 !serialised.contains(secret),
862 "audit details must not carry raw prompt content; was: {serialised}"
863 );
864 assert!(
865 !serialised.contains("you are a friendly assistant"),
866 "audit details must not carry system prompt; was: {serialised}"
867 );
868 assert_eq!(details["messages_count"], 1);
870 assert!(details["prompt_chars"].as_u64().unwrap() > 0);
871 }
872
873 #[tokio::test]
882 async fn audit_row_bucket_prompt_chars_to_pow2() {
883 let h = harness().await;
884 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
885 let client = SamplingLlmClient::with_sampling_client(
886 fake,
887 h.write_handle.clone(),
888 None,
889 );
890 client
892 .complete(&[Message::system("hello "), Message::user("x")])
893 .await
894 .expect("ok");
895 let (_, _, details) =
896 latest_sampling_audit_details(&h.db_path, &h.key);
897 assert_eq!(
898 details["prompt_chars"].as_u64().unwrap(),
899 8,
900 "prompt_chars must be bucketed to next pow2 (7 → 8). \
901 raw count is a privacy side-channel; see Fix 4 F6 in \
902 v0.9.1 P1 dev log. got details={details}"
903 );
904 }
905
906 #[tokio::test]
915 async fn audit_row_bucket_prompt_chars_is_stable_within_bucket() {
916 let h = harness().await;
917 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
918 let client = SamplingLlmClient::with_sampling_client(
919 fake,
920 h.write_handle.clone(),
921 None,
922 );
923 client
925 .complete(&[Message::user("hello")])
926 .await
927 .expect("ok");
928 let (_, _, details_5) =
929 latest_sampling_audit_details(&h.db_path, &h.key);
930 client
932 .complete(&[Message::user("hellooo")])
933 .await
934 .expect("ok");
935 let (_, _, details_7) =
936 latest_sampling_audit_details(&h.db_path, &h.key);
937 assert_eq!(
938 details_5["prompt_chars"], details_7["prompt_chars"],
939 "5 chars and 7 chars must hash to the same bucket (8) — \
940 otherwise the bucketing is leaking raw fidelity. \
941 5-char details: {details_5}, 7-char details: {details_7}"
942 );
943 assert_eq!(details_5["prompt_chars"].as_u64().unwrap(), 8);
944 }
945
946 #[test]
949 fn next_pow2_bucket_table() {
950 assert_eq!(next_pow2_bucket(0), 0, "0 stays 0");
951 assert_eq!(next_pow2_bucket(1), 1, "1 stays 1");
952 assert_eq!(next_pow2_bucket(2), 2, "2 stays 2");
953 assert_eq!(next_pow2_bucket(3), 4, "3 rounds up to 4");
954 assert_eq!(next_pow2_bucket(4), 4, "4 stays 4");
955 assert_eq!(next_pow2_bucket(5), 8);
956 assert_eq!(next_pow2_bucket(6), 8, "6-char prompt (brief case) → 8");
957 assert_eq!(next_pow2_bucket(7), 8);
958 assert_eq!(next_pow2_bucket(8), 8);
959 assert_eq!(next_pow2_bucket(9), 16);
960 assert_eq!(next_pow2_bucket(1023), 1024);
961 assert_eq!(next_pow2_bucket(1024), 1024);
962 assert_eq!(next_pow2_bucket(1025), 2048);
963 }
964
965 #[tokio::test]
968 async fn client_refusal_returns_forbidden_and_audits() {
969 let h = harness().await;
970 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ignored")));
971 fake.reject_with("user dismissed approval");
972 let client = SamplingLlmClient::with_sampling_client(
973 fake,
974 h.write_handle.clone(),
975 Some("alice".into()),
976 );
977 let err = client
978 .complete(&[Message::user("anything")])
979 .await
980 .unwrap_err();
981 match err {
982 CoreError::Forbidden(_) => {}
983 other => panic!("expected Forbidden, got {other:?}"),
984 }
985 let (result_str, _, details) =
986 latest_sampling_audit_details(&h.db_path, &h.key);
987 assert_eq!(result_str, "forbidden");
988 assert_eq!(details["reason"], "client_refused");
989 }
990
991 #[tokio::test]
999 async fn timeout_returns_error_with_timeout_reason() {
1000 let h = harness().await;
1001 let fake = Arc::new(FakeMcpClient::new(FakeResponse::slow(
1002 "late",
1003 Duration::from_millis(800),
1004 )));
1005 let client = SamplingLlmClient::with_sampling_client(
1006 fake,
1007 h.write_handle.clone(),
1008 None,
1009 )
1010 .with_timeout(Duration::from_millis(30));
1011 let err = client
1012 .complete(&[Message::user("hello")])
1013 .await
1014 .unwrap_err();
1015 match err {
1016 CoreError::Llm(msg) => assert!(msg.contains("timeout")),
1017 other => panic!("expected Llm, got {other:?}"),
1018 }
1019 let (result_str, _, details) =
1020 latest_sampling_audit_details(&h.db_path, &h.key);
1021 assert_eq!(result_str, "error");
1022 assert_eq!(details["reason"], "timeout");
1023 }
1024
1025 #[tokio::test]
1029 async fn malformed_response_returns_error_with_reason() {
1030 let h = harness().await;
1031 let fake = Arc::new(FakeMcpClient::new(FakeResponse::EmptyContent));
1032 let client = SamplingLlmClient::with_sampling_client(
1033 fake,
1034 h.write_handle.clone(),
1035 None,
1036 );
1037 let err = client
1038 .complete(&[Message::user("hi")])
1039 .await
1040 .unwrap_err();
1041 assert!(matches!(err, CoreError::Llm(_)));
1042 let (result_str, _, details) =
1043 latest_sampling_audit_details(&h.db_path, &h.key);
1044 assert_eq!(result_str, "error");
1045 assert_eq!(details["reason"], "malformed_response");
1046 }
1047
1048 #[tokio::test]
1051 async fn no_principal_emits_audit_with_null_principal() {
1052 let h = harness().await;
1053 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1054 let client = SamplingLlmClient::with_sampling_client(
1055 fake,
1056 h.write_handle.clone(),
1057 None,
1058 );
1059 client.complete(&[Message::user("hi")]).await.expect("ok");
1060 let (_, principal, _) =
1061 latest_sampling_audit_details(&h.db_path, &h.key);
1062 assert_eq!(principal, None);
1063 }
1064
1065 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1070 async fn parallel_completes_serialise_audit_rows() {
1071 let h = harness().await;
1072 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1073 let client = SamplingLlmClient::with_sampling_client(
1074 fake.clone(),
1075 h.write_handle.clone(),
1076 Some("alice".into()),
1077 );
1078 let mut futs = Vec::new();
1079 for _ in 0..8 {
1080 let c = client.clone();
1081 futs.push(tokio::spawn(async move {
1082 c.complete(&[Message::user("hi")]).await
1083 }));
1084 }
1085 for f in futs {
1086 f.await.expect("join").expect("ok");
1087 }
1088 assert_eq!(
1089 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1090 8,
1091 "8 parallel calls must land 8 audit rows"
1092 );
1093
1094 assert_eq!(fake.record_requests().len(), 8);
1096 }
1097
1098 #[tokio::test]
1102 async fn build_request_splits_system_from_messages() {
1103 let h = harness().await;
1104 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1105 let client = SamplingLlmClient::with_sampling_client(
1106 fake.clone(),
1107 h.write_handle.clone(),
1108 None,
1109 );
1110 client
1111 .complete(&[
1112 Message::system("be terse"),
1113 Message::user("question"),
1114 Message::assistant("answer"),
1115 ])
1116 .await
1117 .expect("ok");
1118 let recorded = fake.record_requests();
1119 assert_eq!(recorded.len(), 1);
1120 let req = &recorded[0];
1121 assert_eq!(
1122 req.system_prompt.as_deref(),
1123 Some("be terse"),
1124 "Role::System must map to system_prompt"
1125 );
1126 assert_eq!(req.messages.len(), 2);
1127 assert_eq!(req.messages[0].role, RmcpRole::User);
1129 assert_eq!(req.messages[1].role, RmcpRole::Assistant);
1130 }
1131
1132 #[tokio::test]
1135 async fn build_request_includes_claude_model_hint() {
1136 let h = harness().await;
1137 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1138 let client = SamplingLlmClient::with_sampling_client(
1139 fake.clone(),
1140 h.write_handle.clone(),
1141 None,
1142 );
1143 client
1144 .complete(&[Message::user("hi")])
1145 .await
1146 .expect("ok");
1147 let recorded = fake.record_requests();
1148 let prefs = recorded[0].model_preferences.as_ref().expect("prefs");
1149 let hint = prefs
1150 .hints
1151 .as_ref()
1152 .and_then(|h| h.first())
1153 .and_then(|h| h.name.clone())
1154 .expect("hint name");
1155 assert_eq!(hint, "claude");
1156 }
1157
1158 #[tokio::test]
1161 async fn with_max_tokens_overrides_default() {
1162 let h = harness().await;
1163 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1164 let client = SamplingLlmClient::with_sampling_client(
1165 fake.clone(),
1166 h.write_handle.clone(),
1167 None,
1168 )
1169 .with_max_tokens(2048);
1170 client
1171 .complete(&[Message::user("hi")])
1172 .await
1173 .expect("ok");
1174 let recorded = fake.record_requests();
1175 assert_eq!(recorded[0].max_tokens, 2048);
1176 }
1177
1178 #[tokio::test]
1181 async fn reconfigurable_fake_distinguishes_audit_rows() {
1182 let h = harness().await;
1183 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1184 let client = SamplingLlmClient::with_sampling_client(
1185 fake.clone(),
1186 h.write_handle.clone(),
1187 Some("alice".into()),
1188 );
1189
1190 client.complete(&[Message::user("a")]).await.expect("ok");
1191 fake.reject_with("user said no");
1192 let _ = client.complete(&[Message::user("b")]).await;
1193
1194 let conn = open_sqlcipher(&h.db_path, &h.key).expect("open");
1195 let mut stmt = conn
1196 .prepare(
1197 "SELECT result FROM audit_events WHERE operation = 'llm.sampling_call' ORDER BY ts_ms ASC, rowid ASC",
1198 )
1199 .expect("prepare");
1200 let rows: Vec<String> = stmt
1201 .query_map([], |r| r.get::<_, String>(0))
1202 .expect("query")
1203 .map(|r| r.expect("row"))
1204 .collect();
1205 assert_eq!(rows, vec!["ok".to_string(), "forbidden".to_string()]);
1206 }
1207
1208 #[test]
1210 fn extract_text_pulls_text_from_single_block() {
1211 let result = CreateMessageResult::new(
1212 SamplingMessage::assistant_text("hello"),
1213 "fake".into(),
1214 );
1215 assert_eq!(extract_text(&result).unwrap(), "hello");
1216 }
1217
1218 #[test]
1220 fn extract_text_rejects_empty_content() {
1221 let result = CreateMessageResult::new(
1222 SamplingMessage::new_multiple(RmcpRole::Assistant, Vec::new()),
1223 "fake".into(),
1224 );
1225 assert!(extract_text(&result).is_err());
1226 }
1227
1228 #[test]
1231 fn extract_text_rejects_non_assistant_role() {
1232 let result = CreateMessageResult::new(
1233 SamplingMessage::user_text("hello"),
1234 "fake".into(),
1235 );
1236 assert!(extract_text(&result).is_err());
1237 }
1238
1239 #[test]
1242 fn sampling_error_classify_maps_fake_variants() {
1243 let refused = SamplingError::Fake(FakeSamplingError::Refused {
1244 reason: "x".into(),
1245 });
1246 let (cat, forb) = refused.classify();
1247 assert_eq!(cat, "client_refused");
1248 assert!(forb);
1249
1250 let transport = SamplingError::Fake(FakeSamplingError::Transport {
1251 message: "x".into(),
1252 });
1253 let (cat, forb) = transport.classify();
1254 assert_eq!(cat, "transport_error");
1255 assert!(!forb);
1256
1257 let malformed =
1258 SamplingError::Fake(FakeSamplingError::MalformedResponse {
1259 message: "x".into(),
1260 });
1261 let (cat, forb) = malformed.classify();
1262 assert_eq!(cat, "malformed_response");
1263 assert!(!forb);
1264 }
1265
1266 #[tokio::test]
1292 async fn sampling_llm_client_uses_coordinator_in_production_path() {
1293 let h = harness().await;
1294 let fake: Arc<dyn SamplingClient> =
1295 Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1296 let coord: Arc<dyn SamplingClient> =
1297 super::super::SamplingCoordinator::with_settings(
1298 fake.clone(),
1299 Duration::from_millis(50),
1300 10,
1301 );
1302 let client = SamplingLlmClient::with_sampling_client(
1303 coord,
1304 h.write_handle.clone(),
1305 Some("alice".into()),
1306 );
1307 let result = client
1308 .complete(&[Message::user("test")])
1309 .await
1310 .expect("ok");
1311 assert_eq!(result.role, Role::Assistant);
1312 assert_eq!(result.content, "ok");
1313 assert_eq!(
1316 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1317 1,
1318 "one logical call → one audit row, even through coordinator"
1319 );
1320 }
1321
1322 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1333 async fn coordinator_coalesces_concurrent_calls_into_one_inner_rpc() {
1334 let response = serde_json::to_string(&(0..5)
1338 .map(|i| serde_json::json!({
1339 "task_index": i,
1340 "response": format!("response-{i}"),
1341 }))
1342 .collect::<Vec<_>>())
1343 .unwrap();
1344
1345 let h = harness().await;
1346 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text(&response)));
1347 let coord: Arc<dyn SamplingClient> =
1348 super::super::SamplingCoordinator::with_settings(
1349 fake.clone(),
1350 Duration::from_secs(5),
1352 10,
1353 );
1354 let client = SamplingLlmClient::with_sampling_client(
1355 coord,
1356 h.write_handle.clone(),
1357 Some("alice".into()),
1358 );
1359
1360 let mut futs = Vec::new();
1363 for i in 0..5 {
1364 let c = client.clone();
1365 futs.push(tokio::spawn(async move {
1366 c.complete(&[Message::user(format!("task-{i}"))]).await
1367 }));
1368 }
1369 for f in futs {
1370 f.await.expect("join").expect("ok");
1371 }
1372
1373 assert_eq!(
1375 fake.record_requests().len(),
1376 1,
1377 "5 logical calls within window must coalesce to 1 inner RPC"
1378 );
1379 assert_eq!(
1381 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1382 5,
1383 "5 logical calls → 5 audit rows (coordinator doesn't merge audits)"
1384 );
1385 }
1386
1387 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1398 async fn coordinator_max_batch_one_acts_as_passthrough() {
1399 let h = harness().await;
1400 let fake = Arc::new(FakeMcpClient::new(FakeResponse::text("ok")));
1401 let coord: Arc<dyn SamplingClient> =
1402 super::super::SamplingCoordinator::with_settings(
1403 fake.clone(),
1404 Duration::from_secs(5),
1405 1,
1408 );
1409 let client = SamplingLlmClient::with_sampling_client(
1410 coord,
1411 h.write_handle.clone(),
1412 None,
1413 );
1414 let mut futs = Vec::new();
1415 for _ in 0..3 {
1416 let c = client.clone();
1417 futs.push(tokio::spawn(async move {
1418 c.complete(&[Message::user("hi")]).await
1419 }));
1420 }
1421 for f in futs {
1422 f.await.expect("join").expect("ok");
1423 }
1424 assert_eq!(
1426 fake.record_requests().len(),
1427 3,
1428 "max_batch=1 must pass through every submission as its own RPC"
1429 );
1430 assert_eq!(
1431 count_audit_rows(&h.db_path, &h.key, "llm.sampling_call"),
1432 3
1433 );
1434 }
1435}