Skip to main content

tuitbot_server/routes/
mcp.rs

//! MCP governance and telemetry endpoints.
2
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use axum::extract::{Path, Query, State};
7use axum::Json;
8use serde::Deserialize;
9use serde_json::{json, Value};
10use tuitbot_core::config::Config;
11use tuitbot_core::mcp_policy::templates;
12use tuitbot_core::mcp_policy::types::PolicyTemplateName;
13use tuitbot_core::storage::{mcp_telemetry, rate_limits};
14
15use crate::error::ApiError;
16use crate::state::AppState;
17
// ---------------------------------------------------------------------------
// Query types
// ---------------------------------------------------------------------------
21
22#[derive(Deserialize)]
23pub struct TimeWindowQuery {
24    /// Lookback window in hours (default: 24).
25    #[serde(default = "default_hours")]
26    pub hours: u32,
27}
28
29fn default_hours() -> u32 {
30    24
31}
32
33#[derive(Deserialize)]
34pub struct RecentQuery {
35    /// Number of recent entries to return (default: 50).
36    #[serde(default = "default_limit")]
37    pub limit: u32,
38}
39
40fn default_limit() -> u32 {
41    50
42}
43
// ---------------------------------------------------------------------------
// Policy endpoints
// ---------------------------------------------------------------------------
47
48/// `GET /api/mcp/policy` — current MCP policy config + rate limit usage + v2 fields.
49pub async fn get_policy(State(state): State<Arc<AppState>>) -> Result<Json<Value>, ApiError> {
50    let config = read_config(&state)?;
51
52    let rate_limit_info = match rate_limits::get_all_rate_limits(&state.db).await {
53        Ok(limits) => {
54            let mcp = limits.iter().find(|l| l.action_type == "mcp_mutation");
55            match mcp {
56                Some(rl) => json!({
57                    "used": rl.request_count,
58                    "max": rl.max_requests,
59                    "period_seconds": rl.period_seconds,
60                    "period_start": rl.period_start,
61                }),
62                None => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
63            }
64        }
65        Err(_) => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
66    };
67
68    Ok(Json(json!({
69        "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
70        "require_approval_for": config.mcp_policy.require_approval_for,
71        "blocked_tools": config.mcp_policy.blocked_tools,
72        "dry_run_mutations": config.mcp_policy.dry_run_mutations,
73        "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
74        "mode": format!("{}", config.mode),
75        "rate_limit": rate_limit_info,
76        "template": config.mcp_policy.template,
77        "rules": config.mcp_policy.rules,
78        "rate_limits": config.mcp_policy.rate_limits,
79    })))
80}
81
82/// `PATCH /api/mcp/policy` — update MCP policy config fields.
83///
84/// Accepts partial JSON with `mcp_policy` fields and merges into config.
85pub async fn patch_policy(
86    State(state): State<Arc<AppState>>,
87    Json(patch): Json<Value>,
88) -> Result<Json<Value>, ApiError> {
89    if !patch.is_object() {
90        return Err(ApiError::BadRequest(
91            "request body must be a JSON object".to_string(),
92        ));
93    }
94
95    // Wrap the patch under `mcp_policy` key for the settings merge.
96    let wrapped = json!({ "mcp_policy": patch });
97
98    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
99        ApiError::BadRequest(format!(
100            "could not read config file {}: {e}",
101            state.config_path.display()
102        ))
103    })?;
104
105    let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
106        ApiError::BadRequest(format!("failed to parse existing config: {e}"))
107    })?;
108
109    let patch_toml = json_to_toml(&wrapped)
110        .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
111
112    merge_toml(&mut toml_value, &patch_toml);
113
114    let merged_str = toml::to_string_pretty(&toml_value)
115        .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
116
117    let config: Config = toml::from_str(&merged_str)
118        .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
119
120    std::fs::write(&state.config_path, &merged_str).map_err(|e| {
121        ApiError::BadRequest(format!(
122            "could not write config file {}: {e}",
123            state.config_path.display()
124        ))
125    })?;
126
127    Ok(Json(json!({
128        "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
129        "require_approval_for": config.mcp_policy.require_approval_for,
130        "blocked_tools": config.mcp_policy.blocked_tools,
131        "dry_run_mutations": config.mcp_policy.dry_run_mutations,
132        "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
133        "template": config.mcp_policy.template,
134        "rules": config.mcp_policy.rules,
135        "rate_limits": config.mcp_policy.rate_limits,
136    })))
137}
138
139// ---------------------------------------------------------------------------
140// Template endpoints
141// ---------------------------------------------------------------------------
142
143/// `GET /api/mcp/policy/templates` — list available policy templates.
144pub async fn list_templates() -> Json<Value> {
145    let templates = templates::list_templates();
146    Json(json!(templates))
147}
148
149/// `POST /api/mcp/policy/templates/{name}` — apply a template.
150pub async fn apply_template(
151    State(state): State<Arc<AppState>>,
152    Path(name): Path<String>,
153) -> Result<Json<Value>, ApiError> {
154    let template_name: PolicyTemplateName =
155        name.parse().map_err(|e: String| ApiError::BadRequest(e))?;
156
157    let template = templates::get_template(&template_name);
158
159    // Build a patch that sets the template and its rules/rate_limits
160    let patch = json!({
161        "template": template_name,
162        "rules": template.rules,
163        "rate_limits": template.rate_limits,
164    });
165
166    // Wrap under mcp_policy and merge into config
167    let wrapped = json!({ "mcp_policy": patch });
168
169    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
170        ApiError::BadRequest(format!(
171            "could not read config file {}: {e}",
172            state.config_path.display()
173        ))
174    })?;
175
176    let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
177        ApiError::BadRequest(format!("failed to parse existing config: {e}"))
178    })?;
179
180    let patch_toml = json_to_toml(&wrapped)
181        .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
182
183    merge_toml(&mut toml_value, &patch_toml);
184
185    let merged_str = toml::to_string_pretty(&toml_value)
186        .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
187
188    let config: Config = toml::from_str(&merged_str)
189        .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
190
191    std::fs::write(&state.config_path, &merged_str).map_err(|e| {
192        ApiError::BadRequest(format!(
193            "could not write config file {}: {e}",
194            state.config_path.display()
195        ))
196    })?;
197
198    // Initialize rate limit rows for the new template limits
199    if let Err(e) =
200        rate_limits::init_policy_rate_limits(&state.db, &config.mcp_policy.rate_limits).await
201    {
202        tracing::warn!("Failed to initialize policy rate limits: {e}");
203    }
204
205    Ok(Json(json!({
206        "applied_template": template_name,
207        "description": template.description,
208        "rules_count": config.mcp_policy.rules.len(),
209        "rate_limits_count": config.mcp_policy.rate_limits.len(),
210    })))
211}
212
// ---------------------------------------------------------------------------
// Telemetry endpoints
// ---------------------------------------------------------------------------
216
217/// `GET /api/mcp/telemetry/summary` — aggregate stats over a time window.
218pub async fn telemetry_summary(
219    State(state): State<Arc<AppState>>,
220    Query(params): Query<TimeWindowQuery>,
221) -> Result<Json<Value>, ApiError> {
222    let since = since_timestamp(params.hours);
223    let summary = mcp_telemetry::get_summary(&state.db, &since).await?;
224    // `serde_json::to_value` can only fail if T contains a non-string map key or
225    // an f64 NaN/infinity; the telemetry summary struct contains neither.
226    let value = serde_json::to_value(summary)
227        .map_err(|e| ApiError::Internal(format!("serialization error: {e}")))?;
228    Ok(Json(value))
229}
230
231/// `GET /api/mcp/telemetry/metrics` — per-tool metrics over a time window.
232pub async fn telemetry_metrics(
233    State(state): State<Arc<AppState>>,
234    Query(params): Query<TimeWindowQuery>,
235) -> Result<Json<Value>, ApiError> {
236    let since = since_timestamp(params.hours);
237    let metrics = mcp_telemetry::get_metrics_since(&state.db, &since).await?;
238    Ok(Json(json!(metrics)))
239}
240
241/// `GET /api/mcp/telemetry/errors` — error breakdown over a time window.
242pub async fn telemetry_errors(
243    State(state): State<Arc<AppState>>,
244    Query(params): Query<TimeWindowQuery>,
245) -> Result<Json<Value>, ApiError> {
246    let since = since_timestamp(params.hours);
247    let errors = mcp_telemetry::get_error_breakdown(&state.db, &since).await?;
248    Ok(Json(json!(errors)))
249}
250
251/// `GET /api/mcp/telemetry/recent` — recent tool executions.
252pub async fn telemetry_recent(
253    State(state): State<Arc<AppState>>,
254    Query(params): Query<RecentQuery>,
255) -> Result<Json<Value>, ApiError> {
256    let entries = mcp_telemetry::get_recent_entries(&state.db, params.limit).await?;
257    Ok(Json(json!(entries)))
258}
259
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
263
264fn read_config(state: &AppState) -> Result<Config, ApiError> {
265    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
266        ApiError::BadRequest(format!(
267            "could not read config file {}: {e}",
268            state.config_path.display()
269        ))
270    })?;
271    let config: Config = toml::from_str(&contents)
272        .map_err(|e| ApiError::BadRequest(format!("failed to parse config: {e}")))?;
273    Ok(config)
274}
275
/// Returns the UTC instant `hours` hours before now, formatted as
/// `YYYY-MM-DDTHH:MM:SSZ`.
///
/// The lookback saturates at the Unix epoch. A system clock set before the epoch
/// is treated as reading exactly the epoch.
fn since_timestamp(hours: u32) -> String {
    let now_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    };
    let lookback_secs = 3600 * u64::from(hours);
    // The result is at most the current epoch second, so it fits in an i64.
    format_epoch_secs_utc(now_secs.saturating_sub(lookback_secs) as i64)
}

/// Formats seconds since the Unix epoch as an ISO-8601 UTC timestamp
/// (`YYYY-MM-DDTHH:MM:SSZ`) without a date/time crate.
///
/// The calendar date comes from Howard Hinnant's `civil_from_days` algorithm.
/// It counts days from 0000-03-01 so that leap days fall at the end of each
/// March-based year.
fn format_epoch_secs_utc(epoch_secs: i64) -> String {
    const SECS_PER_DAY: i64 = 86_400;
    const DAYS_PER_ERA: i64 = 146_097; // days in one 400-year Gregorian cycle
    const EPOCH_SHIFT_DAYS: i64 = 719_468; // days from 0000-03-01 to 1970-01-01

    let day_number = epoch_secs.div_euclid(SECS_PER_DAY);
    let secs_of_day = epoch_secs.rem_euclid(SECS_PER_DAY);
    let hour = secs_of_day / 3600;
    let minute = secs_of_day % 3600 / 60;
    let second = secs_of_day % 60;

    let shifted = day_number + EPOCH_SHIFT_DAYS;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA);
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let march_month = (5 * day_of_year + 2) / 153; // 0 = March, ..., 11 = February
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let march_year = year_of_era + era * 400;

    // January and February belong to the following calendar year.
    let (year, month) = if march_month < 10 {
        (march_year, march_month + 3)
    } else {
        (march_year + 1, march_month - 9)
    };

    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
}
305
306/// Recursively merge `patch` into `base`.
307fn merge_toml(base: &mut toml::Value, patch: &toml::Value) {
308    match (base, patch) {
309        (toml::Value::Table(base_table), toml::Value::Table(patch_table)) => {
310            for (key, patch_val) in patch_table {
311                if let Some(base_val) = base_table.get_mut(key) {
312                    merge_toml(base_val, patch_val);
313                } else {
314                    base_table.insert(key.clone(), patch_val.clone());
315                }
316            }
317        }
318        (base, _) => {
319            *base = patch.clone();
320        }
321    }
322}
323
324/// Convert JSON to TOML, skipping nulls in objects.
325fn json_to_toml(json: &serde_json::Value) -> Result<toml::Value, String> {
326    match json {
327        serde_json::Value::Object(map) => {
328            let mut table = toml::map::Map::new();
329            for (key, val) in map {
330                if val.is_null() {
331                    continue;
332                }
333                table.insert(key.clone(), json_to_toml(val)?);
334            }
335            Ok(toml::Value::Table(table))
336        }
337        serde_json::Value::Array(arr) => {
338            let values: Result<Vec<_>, _> = arr.iter().map(json_to_toml).collect();
339            Ok(toml::Value::Array(values?))
340        }
341        serde_json::Value::String(s) => Ok(toml::Value::String(s.clone())),
342        serde_json::Value::Number(n) => {
343            if let Some(i) = n.as_i64() {
344                Ok(toml::Value::Integer(i))
345            } else if let Some(f) = n.as_f64() {
346                Ok(toml::Value::Float(f))
347            } else {
348                Err(format!("unsupported number: {n}"))
349            }
350        }
351        serde_json::Value::Bool(b) => Ok(toml::Value::Boolean(*b)),
352        serde_json::Value::Null => Err("null values are not supported in TOML arrays".to_string()),
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    // --- since_timestamp ---

    #[test]
    fn since_timestamp_is_valid_utc() {
        let ts = since_timestamp(24);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_zero_hours_is_now() {
        let ts = since_timestamp(0);
        assert!(ts.ends_with('Z'));
        // Should be close to current time
        assert!(ts.starts_with("20"));
    }

    #[test]
    fn since_timestamp_large_hours() {
        let ts = since_timestamp(8760); // 1 year
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_format_correct() {
        let ts = since_timestamp(48);
        // Should match YYYY-MM-DDTHH:MM:SSZ
        assert_eq!(ts.len(), 20);
        assert_eq!(&ts[4..5], "-");
        assert_eq!(&ts[7..8], "-");
        assert_eq!(&ts[10..11], "T");
        assert_eq!(&ts[13..14], ":");
        assert_eq!(&ts[16..17], ":");
    }

    #[test]
    fn since_timestamp_saturates_at_unix_epoch() {
        // u32::MAX hours reaches far before 1970, so the subtraction clamps to 0.
        assert_eq!(since_timestamp(u32::MAX), "1970-01-01T00:00:00Z");
    }

    #[test]
    fn since_timestamp_orders_chronologically() {
        // Fixed-width ISO-8601 UTC strings compare in chronological order.
        assert!(since_timestamp(48) < since_timestamp(24));
        assert!(since_timestamp(24) < since_timestamp(0));
    }

    // --- merge_toml ---

    #[test]
    fn merge_toml_adds_new_keys() {
        let mut base: toml::Value = "key1 = \"value1\"".parse().unwrap();
        let patch: toml::Value = "key2 = \"value2\"".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(
            base.as_table().unwrap().get("key2").unwrap().as_str(),
            Some("value2")
        );
    }

    #[test]
    fn merge_toml_overwrites_existing() {
        let mut base: toml::Value = "key = \"old\"".parse().unwrap();
        let patch: toml::Value = "key = \"new\"".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(
            base.as_table().unwrap().get("key").unwrap().as_str(),
            Some("new")
        );
    }

    #[test]
    fn merge_toml_deep_merge() {
        let mut base: toml::Value = "[section]\na = 1".parse().unwrap();
        let patch: toml::Value = "[section]\nb = 2".parse().unwrap();
        merge_toml(&mut base, &patch);
        let section = base.as_table().unwrap().get("section").unwrap();
        assert_eq!(section.get("a").unwrap().as_integer(), Some(1));
        assert_eq!(section.get("b").unwrap().as_integer(), Some(2));
    }

    #[test]
    fn merge_toml_replaces_arrays_wholesale() {
        let mut base: toml::Value = "list = [1, 2, 3]".parse().unwrap();
        let patch: toml::Value = "list = [4]".parse().unwrap();
        merge_toml(&mut base, &patch);
        let list = base.get("list").unwrap().as_array().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].as_integer(), Some(4));
    }

    #[test]
    fn merge_toml_scalar_replaces_table() {
        let mut base: toml::Value = "[section]\na = 1".parse().unwrap();
        let patch: toml::Value = "section = 5".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(base.get("section").unwrap().as_integer(), Some(5));
    }

    // --- json_to_toml ---

    #[test]
    fn json_to_toml_string() {
        let json = serde_json::json!("hello");
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_str(), Some("hello"));
    }

    #[test]
    fn json_to_toml_integer() {
        let json = serde_json::json!(42);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_integer(), Some(42));
    }

    #[test]
    fn json_to_toml_float() {
        // 2.5 is exactly representable and, unlike 3.14, doesn't trip clippy's
        // deny-by-default `approx_constant` lint.
        let json = serde_json::json!(2.5);
        let toml = json_to_toml(&json).unwrap();
        assert!((toml.as_float().unwrap() - 2.5).abs() < 0.001);
    }

    #[test]
    fn json_to_toml_boolean() {
        let json = serde_json::json!(true);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_bool(), Some(true));
    }

    #[test]
    fn json_to_toml_array() {
        let json = serde_json::json!([1, 2, 3]);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_array().unwrap().len(), 3);
    }

    #[test]
    fn json_to_toml_object() {
        let json = serde_json::json!({"key": "value"});
        let toml = json_to_toml(&json).unwrap();
        assert!(toml.as_table().is_some());
    }

    #[test]
    fn json_to_toml_nested_object_becomes_table() {
        let json = serde_json::json!({"outer": {"inner": 1}});
        let toml = json_to_toml(&json).unwrap();
        let inner = toml
            .get("outer")
            .and_then(|outer| outer.get("inner"))
            .and_then(|v| v.as_integer());
        assert_eq!(inner, Some(1));
    }

    #[test]
    fn json_to_toml_skips_null_in_objects() {
        let json = serde_json::json!({"key": "value", "null_key": null});
        let toml = json_to_toml(&json).unwrap();
        let table = toml.as_table().unwrap();
        assert!(table.contains_key("key"));
        assert!(!table.contains_key("null_key"));
    }

    #[test]
    fn json_to_toml_null_in_array_errors() {
        let json = serde_json::json!([null]);
        assert!(json_to_toml(&json).is_err());
    }

    #[test]
    fn json_to_toml_top_level_null_errors() {
        assert!(json_to_toml(&serde_json::Value::Null).is_err());
    }

    // --- TimeWindowQuery defaults ---

    #[test]
    fn time_window_query_defaults() {
        let json = "{}";
        let q: TimeWindowQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.hours, 24);
    }

    // --- RecentQuery defaults ---

    #[test]
    fn recent_query_defaults() {
        let json = "{}";
        let q: RecentQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.limit, 50);
    }
}