1use std::sync::Arc;
9
10use parking_lot::RwLock;
11use rmcp::{
12 model::{
13 CancelledNotificationParam, ClientInfo, CreateElicitationRequestParam,
14 CreateElicitationResult, LoggingLevel, LoggingMessageNotificationParam,
15 ProgressNotificationParam, ResourceUpdatedNotificationParam,
16 },
17 service::{NotificationContext, RequestContext},
18 ClientHandler, RoleClient,
19};
20use tokio::sync::mpsc;
21use tracing::{debug, error, info, warn};
22
23use crate::{
24 approval::{ApprovalManager, ApprovalMode, ApprovalOutcome, ApprovalParams},
25 inventory::ToolInventory,
26 tenant::TenantContext,
27};
28
/// Asks the consumer of a refresh channel to re-sync one MCP server.
///
/// `SmgClientHandler` sends one of these whenever the server notifies it that
/// its tool, resource, or prompt list changed. What the receiver re-fetches is
/// decided outside this file (presumably the tool inventory — TODO confirm).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RefreshRequest {
    /// Key of the server whose lists changed.
    pub server_key: String,
}
34
35#[derive(Debug, Clone)]
37pub struct HandlerRequestContext {
38 pub request_id: String,
39 pub approval_mode: ApprovalMode,
40 pub tenant_ctx: TenantContext,
41}
42
43impl HandlerRequestContext {
44 pub fn new(
45 request_id: impl Into<String>,
46 approval_mode: ApprovalMode,
47 tenant_ctx: TenantContext,
48 ) -> Self {
49 Self {
50 request_id: request_id.into(),
51 approval_mode,
52 tenant_ctx,
53 }
54 }
55}
56
57#[derive(Clone)]
58pub struct SmgClientHandler {
59 server_key: Arc<str>,
60 approval_manager: Arc<ApprovalManager>,
61 #[expect(dead_code)]
62 tool_inventory: Arc<ToolInventory>,
63 client_info: ClientInfo,
64 request_ctx: Arc<RwLock<Option<HandlerRequestContext>>>,
65 refresh_tx: Option<mpsc::Sender<RefreshRequest>>,
66}
67
68impl SmgClientHandler {
69 pub fn new(
70 server_key: impl AsRef<str>,
71 approval_manager: Arc<ApprovalManager>,
72 tool_inventory: Arc<ToolInventory>,
73 ) -> Self {
74 let mut client_info = ClientInfo::default();
75 client_info.client_info.name = "smg".to_string();
76 client_info.client_info.version = env!("CARGO_PKG_VERSION").to_string();
77
78 Self {
79 server_key: Arc::from(server_key.as_ref()),
80 approval_manager,
81 tool_inventory,
82 client_info,
83 request_ctx: Arc::new(RwLock::new(None)),
84 refresh_tx: None,
85 }
86 }
87
88 #[must_use]
89 pub fn with_refresh_channel(mut self, tx: mpsc::Sender<RefreshRequest>) -> Self {
90 self.refresh_tx = Some(tx);
91 self
92 }
93
94 #[must_use]
95 pub fn with_client_info(mut self, info: ClientInfo) -> Self {
96 self.client_info = info;
97 self
98 }
99
100 pub fn set_request_context(&self, ctx: HandlerRequestContext) {
101 *self.request_ctx.write() = Some(ctx);
102 }
103
104 pub fn clear_request_context(&self) {
105 *self.request_ctx.write() = None;
106 }
107
108 pub fn request_context(&self) -> Option<HandlerRequestContext> {
109 self.request_ctx.read().clone()
110 }
111
112 pub fn server_key(&self) -> &str {
113 &self.server_key
114 }
115
116 fn send_refresh(&self) {
117 if let Some(tx) = &self.refresh_tx {
118 let _ = tx
119 .try_send(RefreshRequest {
120 server_key: self.server_key.to_string(),
121 })
122 .map_err(|e| {
123 warn!(
124 server_key = %self.server_key,
125 error = %e,
126 "Failed to send refresh request"
127 );
128 });
129 }
130 }
131}
132
133impl ClientHandler for SmgClientHandler {
134 async fn create_elicitation(
135 &self,
136 request: CreateElicitationRequestParam,
137 context: RequestContext<RoleClient>,
138 ) -> Result<CreateElicitationResult, rmcp::ErrorData> {
139 use crate::annotations::ToolAnnotations;
140
141 let elicitation_id = match &context.id {
142 rmcp::model::RequestId::String(s) => s.to_string(),
143 rmcp::model::RequestId::Number(n) => n.to_string(),
144 };
145
146 let req_ctx = self.request_ctx.read().clone().ok_or_else(|| {
148 rmcp::ErrorData::internal_error("No request context set for elicitation", None)
149 })?;
150
151 let message = &request.message;
153
154 let hints = ToolAnnotations::default();
156
157 let params = ApprovalParams {
158 request_id: &req_ctx.request_id,
159 server_key: &self.server_key,
160 elicitation_id: &elicitation_id,
161 tool_name: "elicitation",
162 hints: &hints,
163 message,
164 tenant_ctx: &req_ctx.tenant_ctx,
165 };
166
167 let outcome = self
168 .approval_manager
169 .handle_approval(req_ctx.approval_mode, params)
170 .await
171 .map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
172
173 match outcome {
174 ApprovalOutcome::Decided(decision) => {
175 if decision.is_allowed() {
176 Ok(CreateElicitationResult {
177 action: rmcp::model::ElicitationAction::Accept,
178 content: None,
179 })
180 } else {
181 Ok(CreateElicitationResult {
182 action: rmcp::model::ElicitationAction::Decline,
183 content: None,
184 })
185 }
186 }
187 ApprovalOutcome::Pending { rx, .. } => {
188 match rx.await {
190 Ok(decision) => {
191 if decision.is_approved() {
192 Ok(CreateElicitationResult {
193 action: rmcp::model::ElicitationAction::Accept,
194 content: None,
195 })
196 } else {
197 Ok(CreateElicitationResult {
198 action: rmcp::model::ElicitationAction::Decline,
199 content: None,
200 })
201 }
202 }
203 Err(_) => Err(rmcp::ErrorData::internal_error(
204 "Approval channel closed",
205 None,
206 )),
207 }
208 }
209 }
210 }
211
212 async fn on_cancelled(
213 &self,
214 params: CancelledNotificationParam,
215 _context: NotificationContext<RoleClient>,
216 ) {
217 info!(
218 server_key = %self.server_key,
219 request_id = %params.request_id,
220 reason = ?params.reason,
221 "MCP server cancelled request"
222 );
223 }
224
225 async fn on_progress(
226 &self,
227 params: ProgressNotificationParam,
228 _context: NotificationContext<RoleClient>,
229 ) {
230 debug!(
231 server_key = %self.server_key,
232 token = ?params.progress_token,
233 progress = %params.progress,
234 total = ?params.total,
235 message = ?params.message,
236 "MCP server progress"
237 );
238 }
239
240 async fn on_resource_updated(
241 &self,
242 params: ResourceUpdatedNotificationParam,
243 _context: NotificationContext<RoleClient>,
244 ) {
245 info!(
246 server_key = %self.server_key,
247 uri = %params.uri,
248 "MCP server resource updated"
249 );
250 }
251
252 async fn on_resource_list_changed(&self, _context: NotificationContext<RoleClient>) {
253 info!(server_key = %self.server_key, "MCP server resource list changed");
254 self.send_refresh();
255 }
256
257 async fn on_tool_list_changed(&self, _context: NotificationContext<RoleClient>) {
258 info!(server_key = %self.server_key, "MCP server tool list changed");
259 self.send_refresh();
260 }
261
262 async fn on_prompt_list_changed(&self, _context: NotificationContext<RoleClient>) {
263 info!(server_key = %self.server_key, "MCP server prompt list changed");
264 self.send_refresh();
265 }
266
267 fn get_info(&self) -> ClientInfo {
268 self.client_info.clone()
269 }
270
271 async fn on_logging_message(
272 &self,
273 params: LoggingMessageNotificationParam,
274 _context: NotificationContext<RoleClient>,
275 ) {
276 let logger = params.logger.as_deref().unwrap_or("mcp");
277
278 match params.level {
279 LoggingLevel::Emergency | LoggingLevel::Alert | LoggingLevel::Critical => {
280 error!(
281 server_key = %self.server_key,
282 logger = %logger,
283 level = ?params.level,
284 "MCP: {}",
285 params.data
286 );
287 }
288 LoggingLevel::Error => {
289 error!(
290 server_key = %self.server_key,
291 logger = %logger,
292 "MCP: {}",
293 params.data
294 );
295 }
296 LoggingLevel::Warning => {
297 warn!(
298 server_key = %self.server_key,
299 logger = %logger,
300 "MCP: {}",
301 params.data
302 );
303 }
304 LoggingLevel::Notice | LoggingLevel::Info => {
305 info!(
306 server_key = %self.server_key,
307 logger = %logger,
308 "MCP: {}",
309 params.data
310 );
311 }
312 LoggingLevel::Debug => {
313 debug!(
314 server_key = %self.server_key,
315 logger = %logger,
316 "MCP: {}",
317 params.data
318 );
319 }
320 }
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::approval::{audit::AuditLog, policy::PolicyEngine};

    /// Builds a handler for `"test-server"` backed by fresh approval and
    /// inventory components.
    fn test_handler() -> SmgClientHandler {
        let audit_log = Arc::new(AuditLog::new());
        let policy_engine = Arc::new(PolicyEngine::new(audit_log.clone()));
        let approval_manager = Arc::new(ApprovalManager::new(policy_engine, audit_log));
        let tool_inventory = Arc::new(ToolInventory::new());

        SmgClientHandler::new("test-server", approval_manager, tool_inventory)
    }

    #[test]
    fn test_handler_creation() {
        let handler = test_handler();
        assert_eq!(handler.server_key(), "test-server");
        assert!(handler.request_context().is_none());
    }

    #[test]
    fn test_request_context() {
        let handler = test_handler();

        let ctx = HandlerRequestContext::new(
            "req-1",
            ApprovalMode::PolicyOnly,
            TenantContext::new("tenant-1"),
        );

        handler.set_request_context(ctx);
        let retrieved = handler
            .request_context()
            .expect("request context should be set");
        assert_eq!(retrieved.request_id, "req-1");

        handler.clear_request_context();
        assert!(handler.request_context().is_none());
    }

    #[test]
    fn test_client_info() {
        let handler = test_handler();
        let info = handler.get_info();
        assert_eq!(info.client_info.name, "smg");
        assert_eq!(info.client_info.version, env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn test_with_client_info_overrides_default() {
        let mut info = ClientInfo::default();
        info.client_info.name = "custom".to_string();
        let handler = test_handler().with_client_info(info);
        assert_eq!(handler.get_info().client_info.name, "custom");
    }

    #[test]
    fn test_with_refresh_channel() {
        let (tx, _rx) = mpsc::channel(10);
        let handler = test_handler().with_refresh_channel(tx);
        assert!(handler.refresh_tx.is_some());
    }

    #[test]
    fn test_send_refresh_enqueues_request() {
        let (tx, mut rx) = mpsc::channel(10);
        let handler = test_handler().with_refresh_channel(tx);

        handler.send_refresh();

        let request = rx.try_recv().expect("refresh request should be queued");
        assert_eq!(request.server_key, "test-server");
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_send_refresh_without_channel_is_noop() {
        // Must neither panic nor block when no channel is configured.
        test_handler().send_refresh();
    }

    #[test]
    fn test_send_refresh_drops_when_channel_full() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = test_handler().with_refresh_channel(tx);

        handler.send_refresh();
        // The channel is now full: this request is dropped (and logged) rather
        // than blocking the caller.
        handler.send_refresh();

        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_err());
    }
}