Skip to main content

tuitbot_server/routes/
mcp.rs

1//! MCP governance and telemetry endpoints.
2
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use axum::extract::{Path, Query, State};
7use axum::Json;
8use serde::Deserialize;
9use serde_json::{json, Value};
10use tuitbot_core::config::Config;
11use tuitbot_core::mcp_policy::templates;
12use tuitbot_core::mcp_policy::types::PolicyTemplateName;
13use tuitbot_core::storage::{mcp_telemetry, rate_limits};
14
15use crate::error::ApiError;
16use crate::state::AppState;
17
18// ---------------------------------------------------------------------------
19// Query types
20// ---------------------------------------------------------------------------
21
22#[derive(Deserialize)]
23pub struct TimeWindowQuery {
24    /// Lookback window in hours (default: 24).
25    #[serde(default = "default_hours")]
26    pub hours: u32,
27}
28
/// Serde default for `TimeWindowQuery::hours`: a one-day lookback.
fn default_hours() -> u32 {
    const ONE_DAY_IN_HOURS: u32 = 24;
    ONE_DAY_IN_HOURS
}
32
33#[derive(Deserialize)]
34pub struct RecentQuery {
35    /// Number of recent entries to return (default: 50).
36    #[serde(default = "default_limit")]
37    pub limit: u32,
38}
39
/// Serde default for `RecentQuery::limit`.
fn default_limit() -> u32 {
    const DEFAULT_RECENT_ENTRIES: u32 = 50;
    DEFAULT_RECENT_ENTRIES
}
43
44// ---------------------------------------------------------------------------
45// Policy endpoints
46// ---------------------------------------------------------------------------
47
48/// `GET /api/mcp/policy` — current MCP policy config + rate limit usage + v2 fields.
49pub async fn get_policy(State(state): State<Arc<AppState>>) -> Result<Json<Value>, ApiError> {
50    let config = read_config(&state)?;
51
52    let rate_limit_info = match rate_limits::get_all_rate_limits(&state.db).await {
53        Ok(limits) => {
54            let mcp = limits.iter().find(|l| l.action_type == "mcp_mutation");
55            match mcp {
56                Some(rl) => json!({
57                    "used": rl.request_count,
58                    "max": rl.max_requests,
59                    "period_seconds": rl.period_seconds,
60                    "period_start": rl.period_start,
61                }),
62                None => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
63            }
64        }
65        Err(_) => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
66    };
67
68    Ok(Json(json!({
69        "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
70        "require_approval_for": config.mcp_policy.require_approval_for,
71        "blocked_tools": config.mcp_policy.blocked_tools,
72        "dry_run_mutations": config.mcp_policy.dry_run_mutations,
73        "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
74        "mode": format!("{}", config.mode),
75        "rate_limit": rate_limit_info,
76        "template": config.mcp_policy.template,
77        "rules": config.mcp_policy.rules,
78        "rate_limits": config.mcp_policy.rate_limits,
79    })))
80}
81
82/// `PATCH /api/mcp/policy` — update MCP policy config fields.
83///
84/// Accepts partial JSON with `mcp_policy` fields and merges into config.
85pub async fn patch_policy(
86    State(state): State<Arc<AppState>>,
87    Json(patch): Json<Value>,
88) -> Result<Json<Value>, ApiError> {
89    if !patch.is_object() {
90        return Err(ApiError::BadRequest(
91            "request body must be a JSON object".to_string(),
92        ));
93    }
94
95    // Wrap the patch under `mcp_policy` key for the settings merge.
96    let wrapped = json!({ "mcp_policy": patch });
97
98    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
99        ApiError::BadRequest(format!(
100            "could not read config file {}: {e}",
101            state.config_path.display()
102        ))
103    })?;
104
105    let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
106        ApiError::BadRequest(format!("failed to parse existing config: {e}"))
107    })?;
108
109    let patch_toml = json_to_toml(&wrapped)
110        .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
111
112    merge_toml(&mut toml_value, &patch_toml);
113
114    let merged_str = toml::to_string_pretty(&toml_value)
115        .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
116
117    let config: Config = toml::from_str(&merged_str)
118        .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
119
120    std::fs::write(&state.config_path, &merged_str).map_err(|e| {
121        ApiError::BadRequest(format!(
122            "could not write config file {}: {e}",
123            state.config_path.display()
124        ))
125    })?;
126
127    Ok(Json(json!({
128        "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
129        "require_approval_for": config.mcp_policy.require_approval_for,
130        "blocked_tools": config.mcp_policy.blocked_tools,
131        "dry_run_mutations": config.mcp_policy.dry_run_mutations,
132        "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
133        "template": config.mcp_policy.template,
134        "rules": config.mcp_policy.rules,
135        "rate_limits": config.mcp_policy.rate_limits,
136    })))
137}
138
139// ---------------------------------------------------------------------------
140// Template endpoints
141// ---------------------------------------------------------------------------
142
143/// `GET /api/mcp/policy/templates` — list available policy templates.
144pub async fn list_templates() -> Json<Value> {
145    let templates = templates::list_templates();
146    Json(json!(templates))
147}
148
149/// `POST /api/mcp/policy/templates/{name}` — apply a template.
150pub async fn apply_template(
151    State(state): State<Arc<AppState>>,
152    Path(name): Path<String>,
153) -> Result<Json<Value>, ApiError> {
154    let template_name: PolicyTemplateName =
155        name.parse().map_err(|e: String| ApiError::BadRequest(e))?;
156
157    let template = templates::get_template(&template_name);
158
159    // Build a patch that sets the template and its rules/rate_limits
160    let patch = json!({
161        "template": template_name,
162        "rules": template.rules,
163        "rate_limits": template.rate_limits,
164    });
165
166    // Wrap under mcp_policy and merge into config
167    let wrapped = json!({ "mcp_policy": patch });
168
169    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
170        ApiError::BadRequest(format!(
171            "could not read config file {}: {e}",
172            state.config_path.display()
173        ))
174    })?;
175
176    let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
177        ApiError::BadRequest(format!("failed to parse existing config: {e}"))
178    })?;
179
180    let patch_toml = json_to_toml(&wrapped)
181        .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
182
183    merge_toml(&mut toml_value, &patch_toml);
184
185    let merged_str = toml::to_string_pretty(&toml_value)
186        .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
187
188    let config: Config = toml::from_str(&merged_str)
189        .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
190
191    std::fs::write(&state.config_path, &merged_str).map_err(|e| {
192        ApiError::BadRequest(format!(
193            "could not write config file {}: {e}",
194            state.config_path.display()
195        ))
196    })?;
197
198    // Initialize rate limit rows for the new template limits
199    if let Err(e) =
200        rate_limits::init_policy_rate_limits(&state.db, &config.mcp_policy.rate_limits).await
201    {
202        tracing::warn!("Failed to initialize policy rate limits: {e}");
203    }
204
205    Ok(Json(json!({
206        "applied_template": template_name,
207        "description": template.description,
208        "rules_count": config.mcp_policy.rules.len(),
209        "rate_limits_count": config.mcp_policy.rate_limits.len(),
210    })))
211}
212
213// ---------------------------------------------------------------------------
214// Telemetry endpoints
215// ---------------------------------------------------------------------------
216
217/// `GET /api/mcp/telemetry/summary` — aggregate stats over a time window.
218pub async fn telemetry_summary(
219    State(state): State<Arc<AppState>>,
220    Query(params): Query<TimeWindowQuery>,
221) -> Result<Json<Value>, ApiError> {
222    let since = since_timestamp(params.hours);
223    let summary = mcp_telemetry::get_summary(&state.db, &since).await?;
224    Ok(Json(serde_json::to_value(summary).unwrap()))
225}
226
227/// `GET /api/mcp/telemetry/metrics` — per-tool metrics over a time window.
228pub async fn telemetry_metrics(
229    State(state): State<Arc<AppState>>,
230    Query(params): Query<TimeWindowQuery>,
231) -> Result<Json<Value>, ApiError> {
232    let since = since_timestamp(params.hours);
233    let metrics = mcp_telemetry::get_metrics_since(&state.db, &since).await?;
234    Ok(Json(json!(metrics)))
235}
236
237/// `GET /api/mcp/telemetry/errors` — error breakdown over a time window.
238pub async fn telemetry_errors(
239    State(state): State<Arc<AppState>>,
240    Query(params): Query<TimeWindowQuery>,
241) -> Result<Json<Value>, ApiError> {
242    let since = since_timestamp(params.hours);
243    let errors = mcp_telemetry::get_error_breakdown(&state.db, &since).await?;
244    Ok(Json(json!(errors)))
245}
246
247/// `GET /api/mcp/telemetry/recent` — recent tool executions.
248pub async fn telemetry_recent(
249    State(state): State<Arc<AppState>>,
250    Query(params): Query<RecentQuery>,
251) -> Result<Json<Value>, ApiError> {
252    let entries = mcp_telemetry::get_recent_entries(&state.db, params.limit).await?;
253    Ok(Json(json!(entries)))
254}
255
256// ---------------------------------------------------------------------------
257// Helpers
258// ---------------------------------------------------------------------------
259
260fn read_config(state: &AppState) -> Result<Config, ApiError> {
261    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
262        ApiError::BadRequest(format!(
263            "could not read config file {}: {e}",
264            state.config_path.display()
265        ))
266    })?;
267    let config: Config = toml::from_str(&contents)
268        .map_err(|e| ApiError::BadRequest(format!("failed to parse config: {e}")))?;
269    Ok(config)
270}
271
/// Returns the instant `hours` hours before now as a UTC ISO-8601 string
/// (`YYYY-MM-DDTHH:MM:SSZ`), used as the lower bound of telemetry queries.
///
/// The subtraction saturates at the Unix epoch, so a window reaching back
/// before 1970 yields `1970-01-01T00:00:00Z`. A system clock set before the
/// epoch is treated as the epoch itself.
fn since_timestamp(hours: u32) -> String {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // `u32::MAX * 3600` fits easily in u64, so only the subtraction can underflow.
    format_utc_timestamp(now.saturating_sub(u64::from(hours) * 3600))
}

/// Formats seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC,
/// proleptic Gregorian calendar).
///
/// Kept separate from [`since_timestamp`] so the calendar arithmetic is a pure
/// function of its input and can be checked with fixed values.
fn format_utc_timestamp(epoch_secs: u64) -> String {
    // Clamp rather than wrap to a negative value for inputs beyond i64::MAX.
    let secs = i64::try_from(epoch_secs).unwrap_or(i64::MAX);
    let days = secs.div_euclid(86_400);
    let day_secs = secs.rem_euclid(86_400);
    let h = day_secs / 3600;
    let m = (day_secs % 3600) / 60;
    let s = day_secs % 60;

    // Howard Hinnant's `civil_from_days`: count days from 0000-03-01 so each
    // leap day falls at the end of its (March-based) year.
    let z = days + 719_468;
    let era = z.div_euclid(146_097);
    let doe = z.rem_euclid(146_097); // day of 400-year era, [0, 146096]
    let yoe = (doe - doe / 1460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let y = yoe + era * 400;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // month index with March = 0, [0, 11]
    let d = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = if month <= 2 { y + 1 } else { y };

    format!("{year:04}-{month:02}-{d:02}T{h:02}:{m:02}:{s:02}Z")
}
301
302/// Recursively merge `patch` into `base`.
303fn merge_toml(base: &mut toml::Value, patch: &toml::Value) {
304    match (base, patch) {
305        (toml::Value::Table(base_table), toml::Value::Table(patch_table)) => {
306            for (key, patch_val) in patch_table {
307                if let Some(base_val) = base_table.get_mut(key) {
308                    merge_toml(base_val, patch_val);
309                } else {
310                    base_table.insert(key.clone(), patch_val.clone());
311                }
312            }
313        }
314        (base, _) => {
315            *base = patch.clone();
316        }
317    }
318}
319
320/// Convert JSON to TOML, skipping nulls in objects.
321fn json_to_toml(json: &serde_json::Value) -> Result<toml::Value, String> {
322    match json {
323        serde_json::Value::Object(map) => {
324            let mut table = toml::map::Map::new();
325            for (key, val) in map {
326                if val.is_null() {
327                    continue;
328                }
329                table.insert(key.clone(), json_to_toml(val)?);
330            }
331            Ok(toml::Value::Table(table))
332        }
333        serde_json::Value::Array(arr) => {
334            let values: Result<Vec<_>, _> = arr.iter().map(json_to_toml).collect();
335            Ok(toml::Value::Array(values?))
336        }
337        serde_json::Value::String(s) => Ok(toml::Value::String(s.clone())),
338        serde_json::Value::Number(n) => {
339            if let Some(i) = n.as_i64() {
340                Ok(toml::Value::Integer(i))
341            } else if let Some(f) = n.as_f64() {
342                Ok(toml::Value::Float(f))
343            } else {
344                Err(format!("unsupported number: {n}"))
345            }
346        }
347        serde_json::Value::Bool(b) => Ok(toml::Value::Boolean(*b)),
348        serde_json::Value::Null => Err("null values are not supported in TOML arrays".to_string()),
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    // --- since_timestamp ---

    #[test]
    fn since_timestamp_is_valid_utc() {
        let ts = since_timestamp(24);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_zero_hours_is_now() {
        let ts = since_timestamp(0);
        assert!(ts.ends_with('Z'));
        // Should be close to current time
        assert!(ts.starts_with("20"));
    }

    #[test]
    fn since_timestamp_large_hours() {
        let ts = since_timestamp(8760); // 1 year
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_format_correct() {
        let ts = since_timestamp(48);
        // Should match YYYY-MM-DDTHH:MM:SSZ
        assert_eq!(ts.len(), 20);
        assert_eq!(&ts[4..5], "-");
        assert_eq!(&ts[7..8], "-");
        assert_eq!(&ts[10..11], "T");
        assert_eq!(&ts[13..14], ":");
        assert_eq!(&ts[16..17], ":");
    }

    #[test]
    fn since_timestamp_saturates_at_unix_epoch() {
        // u32::MAX hours is far longer than the time elapsed since 1970, so the
        // subtraction saturates to epoch second 0. This pins the exact date
        // arithmetic, not just the shape of the output.
        assert_eq!(since_timestamp(u32::MAX), "1970-01-01T00:00:00Z");
    }

    // --- merge_toml ---

    #[test]
    fn merge_toml_adds_new_keys() {
        let mut base: toml::Value = "key1 = \"value1\"".parse().unwrap();
        let patch: toml::Value = "key2 = \"value2\"".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(
            base.as_table().unwrap().get("key2").unwrap().as_str(),
            Some("value2")
        );
    }

    #[test]
    fn merge_toml_overwrites_existing() {
        let mut base: toml::Value = "key = \"old\"".parse().unwrap();
        let patch: toml::Value = "key = \"new\"".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(
            base.as_table().unwrap().get("key").unwrap().as_str(),
            Some("new")
        );
    }

    #[test]
    fn merge_toml_deep_merge() {
        let mut base: toml::Value = "[section]\na = 1".parse().unwrap();
        let patch: toml::Value = "[section]\nb = 2".parse().unwrap();
        merge_toml(&mut base, &patch);
        let section = base.as_table().unwrap().get("section").unwrap();
        assert_eq!(section.get("a").unwrap().as_integer(), Some(1));
        assert_eq!(section.get("b").unwrap().as_integer(), Some(2));
    }

    #[test]
    fn merge_toml_replaces_arrays_wholesale() {
        // Arrays are not element-merged: the patch array replaces the old one.
        let mut base: toml::Value = "list = [1, 2, 3]".parse().unwrap();
        let patch: toml::Value = "list = [4]".parse().unwrap();
        merge_toml(&mut base, &patch);
        let list = base.get("list").unwrap().as_array().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].as_integer(), Some(4));
    }

    // --- json_to_toml ---

    #[test]
    fn json_to_toml_string() {
        let json = serde_json::json!("hello");
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_str(), Some("hello"));
    }

    #[test]
    fn json_to_toml_integer() {
        let json = serde_json::json!(42);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_integer(), Some(42));
    }

    #[test]
    fn json_to_toml_float() {
        // 2.5 is exactly representable. The previous literal 3.14 tripped
        // clippy's deny-by-default `approx_constant` lint (an approximate PI).
        let json = serde_json::json!(2.5);
        let toml = json_to_toml(&json).unwrap();
        assert!((toml.as_float().unwrap() - 2.5).abs() < f64::EPSILON);
    }

    #[test]
    fn json_to_toml_boolean() {
        let json = serde_json::json!(true);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_bool(), Some(true));
    }

    #[test]
    fn json_to_toml_array() {
        let json = serde_json::json!([1, 2, 3]);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_array().unwrap().len(), 3);
    }

    #[test]
    fn json_to_toml_object() {
        let json = serde_json::json!({"key": "value"});
        let toml = json_to_toml(&json).unwrap();
        assert!(toml.as_table().is_some());
    }

    #[test]
    fn json_to_toml_nested_object_becomes_nested_table() {
        let json = serde_json::json!({"mcp_policy": {"blocked_tools": ["a", "b"]}});
        let converted = json_to_toml(&json).unwrap();
        let tools = converted
            .get("mcp_policy")
            .and_then(|policy| policy.get("blocked_tools"))
            .and_then(|tools| tools.as_array())
            .unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].as_str(), Some("a"));
    }

    #[test]
    fn json_to_toml_skips_null_in_objects() {
        let json = serde_json::json!({"key": "value", "null_key": null});
        let toml = json_to_toml(&json).unwrap();
        let table = toml.as_table().unwrap();
        assert!(table.contains_key("key"));
        assert!(!table.contains_key("null_key"));
    }

    #[test]
    fn json_to_toml_null_in_array_errors() {
        let json = serde_json::json!([null]);
        assert!(json_to_toml(&json).is_err());
    }

    // --- TimeWindowQuery defaults ---

    #[test]
    fn time_window_query_defaults() {
        let json = "{}";
        let q: TimeWindowQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.hours, 24);
    }

    // --- RecentQuery defaults ---

    #[test]
    fn recent_query_defaults() {
        let json = "{}";
        let q: RecentQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.limit, 50);
    }
}