// turbomcp_client/sampling.rs

use std::sync::Arc;

use async_trait::async_trait;
use turbomcp_protocol::types::{CreateMessageRequest, CreateMessageResult};

22#[async_trait]
27pub trait SamplingHandler: Send + Sync + std::fmt::Debug {
28 async fn handle_create_message(
36 &self,
37 request: CreateMessageRequest,
38 ) -> Result<CreateMessageResult, Box<dyn std::error::Error + Send + Sync>>;
39}
40
41#[derive(Debug)]
46pub struct DelegatingSamplingHandler {
47 llm_clients: Vec<Arc<dyn LLMServerClient>>,
49 user_handler: Arc<dyn UserInteractionHandler>,
51}
52
53#[async_trait]
55pub trait LLMServerClient: Send + Sync + std::fmt::Debug {
56 async fn create_message(
58 &self,
59 request: CreateMessageRequest,
60 ) -> Result<CreateMessageResult, Box<dyn std::error::Error + Send + Sync>>;
61
62 async fn get_server_info(&self)
64 -> Result<ServerInfo, Box<dyn std::error::Error + Send + Sync>>;
65}
66
67#[async_trait]
69pub trait UserInteractionHandler: Send + Sync + std::fmt::Debug {
70 async fn approve_request(
72 &self,
73 request: &CreateMessageRequest,
74 ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
75
76 async fn approve_response(
78 &self,
79 request: &CreateMessageRequest,
80 response: &CreateMessageResult,
81 ) -> Result<Option<CreateMessageResult>, Box<dyn std::error::Error + Send + Sync>>;
82}
83
/// Descriptive metadata about an [`LLMServerClient`] backend, as returned by
/// [`LLMServerClient::get_server_info`].
///
/// `PartialEq`/`Eq` allow values to be compared (e.g. in tests or caches) and
/// `Default` yields an empty record with no models or capabilities.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerInfo {
    /// Name of the backend.
    pub name: String,
    /// Models offered by the backend (presumably model identifiers — confirm).
    pub models: Vec<String>,
    /// Capability labels advertised by the backend (format not defined here).
    pub capabilities: Vec<String>,
}

92#[async_trait]
93impl SamplingHandler for DelegatingSamplingHandler {
94 async fn handle_create_message(
95 &self,
96 request: CreateMessageRequest,
97 ) -> Result<CreateMessageResult, Box<dyn std::error::Error + Send + Sync>> {
98 if !self.user_handler.approve_request(&request).await? {
100 return Err("User rejected sampling request".into());
101 }
102
103 let selected_client = self.select_llm_client(&request).await?;
105
106 let result = selected_client.create_message(request.clone()).await?;
108
109 let approved_result = self
111 .user_handler
112 .approve_response(&request, &result)
113 .await?;
114
115 Ok(approved_result.unwrap_or(result))
116 }
117}
118
119impl DelegatingSamplingHandler {
120 pub fn new(
122 llm_clients: Vec<Arc<dyn LLMServerClient>>,
123 user_handler: Arc<dyn UserInteractionHandler>,
124 ) -> Self {
125 Self {
126 llm_clients,
127 user_handler,
128 }
129 }
130
131 async fn select_llm_client(
133 &self,
134 _request: &CreateMessageRequest,
135 ) -> Result<Arc<dyn LLMServerClient>, Box<dyn std::error::Error + Send + Sync>> {
136 if let Some(first_client) = self.llm_clients.first() {
140 Ok(first_client.clone())
141 } else {
142 Err("No LLM servers configured".into())
143 }
144 }
145}
146
/// [`UserInteractionHandler`] that approves every request and keeps every
/// backend response unchanged, without asking anyone.
///
/// NOTE(review): this bypasses the approval gate entirely; presumably meant
/// for tests or trusted non-interactive setups — confirm before using it
/// elsewhere.
///
/// As a stateless zero-sized marker it is `Copy` and `Default`, so it can be
/// created with `AutoApprovingUserHandler::default()` or passed around freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct AutoApprovingUserHandler;

151#[async_trait]
152impl UserInteractionHandler for AutoApprovingUserHandler {
153 async fn approve_request(
154 &self,
155 _request: &CreateMessageRequest,
156 ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
157 Ok(true) }
159
160 async fn approve_response(
161 &self,
162 _request: &CreateMessageRequest,
163 _response: &CreateMessageResult,
164 ) -> Result<Option<CreateMessageResult>, Box<dyn std::error::Error + Send + Sync>> {
165 Ok(None) }
167}