Skip to main content

smg_mcp/core/
handler.rs

1//! SMG client handler for MCP server notifications and elicitation.
2//!
3//! Implements RMCP's `ClientHandler` trait to handle:
4//! - Elicitation requests (approval flow)
5//! - Tool/resource/prompt list change notifications
6//! - Progress and logging notifications
7
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11use rmcp::{
12    model::{
13        CancelledNotificationParam, ClientInfo, CreateElicitationRequestParam,
14        CreateElicitationResult, LoggingLevel, LoggingMessageNotificationParam,
15        ProgressNotificationParam, ResourceUpdatedNotificationParam,
16    },
17    service::{NotificationContext, RequestContext},
18    ClientHandler, RoleClient,
19};
20use tokio::sync::mpsc;
21use tracing::{debug, error, info, warn};
22
23use crate::{
24    approval::{ApprovalManager, ApprovalMode, ApprovalOutcome, ApprovalParams},
25    inventory::ToolInventory,
26    tenant::TenantContext,
27};
28
/// Asks the receiver of the refresh channel to re-fetch one server's inventory.
///
/// Sent by `SmgClientHandler` when the server reports that its tool, resource,
/// or prompt list changed.
#[derive(Debug, Clone)]
pub struct RefreshRequest {
    /// Key of the MCP server whose inventory should be refreshed.
    pub server_key: String,
}
34
35/// Per-request context set before tool execution, cleared after.
36#[derive(Debug, Clone)]
37pub struct HandlerRequestContext {
38    pub request_id: String,
39    pub approval_mode: ApprovalMode,
40    pub tenant_ctx: TenantContext,
41}
42
43impl HandlerRequestContext {
44    pub fn new(
45        request_id: impl Into<String>,
46        approval_mode: ApprovalMode,
47        tenant_ctx: TenantContext,
48    ) -> Self {
49        Self {
50            request_id: request_id.into(),
51            approval_mode,
52            tenant_ctx,
53        }
54    }
55}
56
57#[derive(Clone)]
58pub struct SmgClientHandler {
59    server_key: Arc<str>,
60    approval_manager: Arc<ApprovalManager>,
61    #[expect(dead_code)]
62    tool_inventory: Arc<ToolInventory>,
63    client_info: ClientInfo,
64    request_ctx: Arc<RwLock<Option<HandlerRequestContext>>>,
65    refresh_tx: Option<mpsc::Sender<RefreshRequest>>,
66}
67
68impl SmgClientHandler {
69    pub fn new(
70        server_key: impl AsRef<str>,
71        approval_manager: Arc<ApprovalManager>,
72        tool_inventory: Arc<ToolInventory>,
73    ) -> Self {
74        let mut client_info = ClientInfo::default();
75        client_info.client_info.name = "smg".to_string();
76        client_info.client_info.version = env!("CARGO_PKG_VERSION").to_string();
77
78        Self {
79            server_key: Arc::from(server_key.as_ref()),
80            approval_manager,
81            tool_inventory,
82            client_info,
83            request_ctx: Arc::new(RwLock::new(None)),
84            refresh_tx: None,
85        }
86    }
87
88    #[must_use]
89    pub fn with_refresh_channel(mut self, tx: mpsc::Sender<RefreshRequest>) -> Self {
90        self.refresh_tx = Some(tx);
91        self
92    }
93
94    #[must_use]
95    pub fn with_client_info(mut self, info: ClientInfo) -> Self {
96        self.client_info = info;
97        self
98    }
99
100    pub fn set_request_context(&self, ctx: HandlerRequestContext) {
101        *self.request_ctx.write() = Some(ctx);
102    }
103
104    pub fn clear_request_context(&self) {
105        *self.request_ctx.write() = None;
106    }
107
108    pub fn request_context(&self) -> Option<HandlerRequestContext> {
109        self.request_ctx.read().clone()
110    }
111
112    pub fn server_key(&self) -> &str {
113        &self.server_key
114    }
115
116    fn send_refresh(&self) {
117        if let Some(tx) = &self.refresh_tx {
118            let _ = tx
119                .try_send(RefreshRequest {
120                    server_key: self.server_key.to_string(),
121                })
122                .map_err(|e| {
123                    warn!(
124                        server_key = %self.server_key,
125                        error = %e,
126                        "Failed to send refresh request"
127                    );
128                });
129        }
130    }
131}
132
133impl ClientHandler for SmgClientHandler {
134    async fn create_elicitation(
135        &self,
136        request: CreateElicitationRequestParam,
137        context: RequestContext<RoleClient>,
138    ) -> Result<CreateElicitationResult, rmcp::ErrorData> {
139        use crate::annotations::ToolAnnotations;
140
141        let elicitation_id = match &context.id {
142            rmcp::model::RequestId::String(s) => s.to_string(),
143            rmcp::model::RequestId::Number(n) => n.to_string(),
144        };
145
146        // Get request context
147        let req_ctx = self.request_ctx.read().clone().ok_or_else(|| {
148            rmcp::ErrorData::internal_error("No request context set for elicitation", None)
149        })?;
150
151        // Use message as the tool identifier (elicitation doesn't have tool name directly)
152        let message = &request.message;
153
154        // Default annotations (conservative - not read-only, potentially destructive)
155        let hints = ToolAnnotations::default();
156
157        let params = ApprovalParams {
158            request_id: &req_ctx.request_id,
159            server_key: &self.server_key,
160            elicitation_id: &elicitation_id,
161            tool_name: "elicitation",
162            hints: &hints,
163            message,
164            tenant_ctx: &req_ctx.tenant_ctx,
165        };
166
167        let outcome = self
168            .approval_manager
169            .handle_approval(req_ctx.approval_mode, params)
170            .await
171            .map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
172
173        match outcome {
174            ApprovalOutcome::Decided(decision) => {
175                if decision.is_allowed() {
176                    Ok(CreateElicitationResult {
177                        action: rmcp::model::ElicitationAction::Accept,
178                        content: None,
179                    })
180                } else {
181                    Ok(CreateElicitationResult {
182                        action: rmcp::model::ElicitationAction::Decline,
183                        content: None,
184                    })
185                }
186            }
187            ApprovalOutcome::Pending { rx, .. } => {
188                // Wait for user response
189                match rx.await {
190                    Ok(decision) => {
191                        if decision.is_approved() {
192                            Ok(CreateElicitationResult {
193                                action: rmcp::model::ElicitationAction::Accept,
194                                content: None,
195                            })
196                        } else {
197                            Ok(CreateElicitationResult {
198                                action: rmcp::model::ElicitationAction::Decline,
199                                content: None,
200                            })
201                        }
202                    }
203                    Err(_) => Err(rmcp::ErrorData::internal_error(
204                        "Approval channel closed",
205                        None,
206                    )),
207                }
208            }
209        }
210    }
211
212    async fn on_cancelled(
213        &self,
214        params: CancelledNotificationParam,
215        _context: NotificationContext<RoleClient>,
216    ) {
217        info!(
218            server_key = %self.server_key,
219            request_id = %params.request_id,
220            reason = ?params.reason,
221            "MCP server cancelled request"
222        );
223    }
224
225    async fn on_progress(
226        &self,
227        params: ProgressNotificationParam,
228        _context: NotificationContext<RoleClient>,
229    ) {
230        debug!(
231            server_key = %self.server_key,
232            token = ?params.progress_token,
233            progress = %params.progress,
234            total = ?params.total,
235            message = ?params.message,
236            "MCP server progress"
237        );
238    }
239
240    async fn on_resource_updated(
241        &self,
242        params: ResourceUpdatedNotificationParam,
243        _context: NotificationContext<RoleClient>,
244    ) {
245        info!(
246            server_key = %self.server_key,
247            uri = %params.uri,
248            "MCP server resource updated"
249        );
250    }
251
252    async fn on_resource_list_changed(&self, _context: NotificationContext<RoleClient>) {
253        info!(server_key = %self.server_key, "MCP server resource list changed");
254        self.send_refresh();
255    }
256
257    async fn on_tool_list_changed(&self, _context: NotificationContext<RoleClient>) {
258        info!(server_key = %self.server_key, "MCP server tool list changed");
259        self.send_refresh();
260    }
261
262    async fn on_prompt_list_changed(&self, _context: NotificationContext<RoleClient>) {
263        info!(server_key = %self.server_key, "MCP server prompt list changed");
264        self.send_refresh();
265    }
266
267    fn get_info(&self) -> ClientInfo {
268        self.client_info.clone()
269    }
270
271    async fn on_logging_message(
272        &self,
273        params: LoggingMessageNotificationParam,
274        _context: NotificationContext<RoleClient>,
275    ) {
276        let logger = params.logger.as_deref().unwrap_or("mcp");
277
278        match params.level {
279            LoggingLevel::Emergency | LoggingLevel::Alert | LoggingLevel::Critical => {
280                error!(
281                    server_key = %self.server_key,
282                    logger = %logger,
283                    level = ?params.level,
284                    "MCP: {}",
285                    params.data
286                );
287            }
288            LoggingLevel::Error => {
289                error!(
290                    server_key = %self.server_key,
291                    logger = %logger,
292                    "MCP: {}",
293                    params.data
294                );
295            }
296            LoggingLevel::Warning => {
297                warn!(
298                    server_key = %self.server_key,
299                    logger = %logger,
300                    "MCP: {}",
301                    params.data
302                );
303            }
304            LoggingLevel::Notice | LoggingLevel::Info => {
305                info!(
306                    server_key = %self.server_key,
307                    logger = %logger,
308                    "MCP: {}",
309                    params.data
310                );
311            }
312            LoggingLevel::Debug => {
313                debug!(
314                    server_key = %self.server_key,
315                    logger = %logger,
316                    "MCP: {}",
317                    params.data
318                );
319            }
320        }
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::approval::{audit::AuditLog, policy::PolicyEngine};

    /// Builds a handler for "test-server" with fresh approval/inventory state.
    fn test_handler() -> SmgClientHandler {
        let audit_log = Arc::new(AuditLog::new());
        let policy_engine = Arc::new(PolicyEngine::new(audit_log.clone()));
        let approval_manager = Arc::new(ApprovalManager::new(policy_engine, audit_log));
        let tool_inventory = Arc::new(ToolInventory::new());

        SmgClientHandler::new("test-server", approval_manager, tool_inventory)
    }

    fn test_context(request_id: &str) -> HandlerRequestContext {
        HandlerRequestContext::new(
            request_id,
            ApprovalMode::PolicyOnly,
            TenantContext::new("tenant-1"),
        )
    }

    #[test]
    fn test_handler_creation() {
        let handler = test_handler();
        assert_eq!(handler.server_key(), "test-server");
        assert!(handler.request_context().is_none());
    }

    #[test]
    fn test_request_context() {
        let handler = test_handler();

        let ctx = test_context("req-1");

        handler.set_request_context(ctx.clone());
        assert!(handler.request_context().is_some());

        let retrieved = handler.request_context().unwrap();
        assert_eq!(retrieved.request_id, "req-1");

        handler.clear_request_context();
        assert!(handler.request_context().is_none());
    }

    #[test]
    fn test_request_context_shared_across_clones() {
        let handler = test_handler();
        let clone = handler.clone();

        clone.set_request_context(test_context("req-2"));
        assert_eq!(
            handler.request_context().map(|c| c.request_id).as_deref(),
            Some("req-2")
        );

        handler.clear_request_context();
        assert!(clone.request_context().is_none());
    }

    #[test]
    fn test_client_info() {
        let handler = test_handler();
        let info = handler.get_info();
        assert_eq!(info.client_info.name, "smg");
    }

    #[test]
    fn test_with_client_info_overrides_default() {
        let mut info = ClientInfo::default();
        info.client_info.name = "custom".to_string();

        let handler = test_handler().with_client_info(info);
        assert_eq!(handler.get_info().client_info.name, "custom");
    }

    #[test]
    fn test_with_refresh_channel() {
        let (tx, mut rx) = mpsc::channel(10);
        let handler = test_handler().with_refresh_channel(tx);
        assert!(handler.refresh_tx.is_some());

        // The refresh must actually reach the receiver, tagged with our key.
        handler.send_refresh();
        let request = rx.try_recv().expect("refresh request should be queued");
        assert_eq!(request.server_key, "test-server");
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_send_refresh_without_channel_is_noop() {
        // Must not panic when no channel is configured.
        test_handler().send_refresh();
    }

    #[test]
    fn test_send_refresh_drops_when_channel_full() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = test_handler().with_refresh_channel(tx);

        handler.send_refresh();
        // Second send finds the channel full; it is dropped, not blocked on.
        handler.send_refresh();

        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_err());
    }
}
378}