1use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use axum::extract::{Path, Query, State};
7use axum::Json;
8use serde::Deserialize;
9use serde_json::{json, Value};
10use tuitbot_core::config::Config;
11use tuitbot_core::mcp_policy::templates;
12use tuitbot_core::mcp_policy::types::PolicyTemplateName;
13use tuitbot_core::storage::{mcp_telemetry, rate_limits};
14
15use crate::error::ApiError;
16use crate::state::AppState;
17
18#[derive(Deserialize)]
23pub struct TimeWindowQuery {
24 #[serde(default = "default_hours")]
26 pub hours: u32,
27}
28
/// Serde default for `TimeWindowQuery::hours`: a one-day window.
fn default_hours() -> u32 {
    const HOURS_PER_DAY: u32 = 24;
    HOURS_PER_DAY
}
32
33#[derive(Deserialize)]
34pub struct RecentQuery {
35 #[serde(default = "default_limit")]
37 pub limit: u32,
38}
39
/// Serde default for `RecentQuery::limit`.
fn default_limit() -> u32 {
    const DEFAULT_RECENT_ENTRIES: u32 = 50;
    DEFAULT_RECENT_ENTRIES
}
43
44pub async fn get_policy(State(state): State<Arc<AppState>>) -> Result<Json<Value>, ApiError> {
50 let config = read_config(&state)?;
51
52 let rate_limit_info = match rate_limits::get_all_rate_limits(&state.db).await {
53 Ok(limits) => {
54 let mcp = limits.iter().find(|l| l.action_type == "mcp_mutation");
55 match mcp {
56 Some(rl) => json!({
57 "used": rl.request_count,
58 "max": rl.max_requests,
59 "period_seconds": rl.period_seconds,
60 "period_start": rl.period_start,
61 }),
62 None => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
63 }
64 }
65 Err(_) => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
66 };
67
68 Ok(Json(json!({
69 "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
70 "require_approval_for": config.mcp_policy.require_approval_for,
71 "blocked_tools": config.mcp_policy.blocked_tools,
72 "dry_run_mutations": config.mcp_policy.dry_run_mutations,
73 "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
74 "mode": format!("{}", config.mode),
75 "rate_limit": rate_limit_info,
76 "template": config.mcp_policy.template,
77 "rules": config.mcp_policy.rules,
78 "rate_limits": config.mcp_policy.rate_limits,
79 })))
80}
81
82pub async fn patch_policy(
86 State(state): State<Arc<AppState>>,
87 Json(patch): Json<Value>,
88) -> Result<Json<Value>, ApiError> {
89 if !patch.is_object() {
90 return Err(ApiError::BadRequest(
91 "request body must be a JSON object".to_string(),
92 ));
93 }
94
95 let wrapped = json!({ "mcp_policy": patch });
97
98 let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
99 ApiError::BadRequest(format!(
100 "could not read config file {}: {e}",
101 state.config_path.display()
102 ))
103 })?;
104
105 let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
106 ApiError::BadRequest(format!("failed to parse existing config: {e}"))
107 })?;
108
109 let patch_toml = json_to_toml(&wrapped)
110 .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
111
112 merge_toml(&mut toml_value, &patch_toml);
113
114 let merged_str = toml::to_string_pretty(&toml_value)
115 .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
116
117 let config: Config = toml::from_str(&merged_str)
118 .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
119
120 std::fs::write(&state.config_path, &merged_str).map_err(|e| {
121 ApiError::BadRequest(format!(
122 "could not write config file {}: {e}",
123 state.config_path.display()
124 ))
125 })?;
126
127 Ok(Json(json!({
128 "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
129 "require_approval_for": config.mcp_policy.require_approval_for,
130 "blocked_tools": config.mcp_policy.blocked_tools,
131 "dry_run_mutations": config.mcp_policy.dry_run_mutations,
132 "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
133 "template": config.mcp_policy.template,
134 "rules": config.mcp_policy.rules,
135 "rate_limits": config.mcp_policy.rate_limits,
136 })))
137}
138
139pub async fn list_templates() -> Json<Value> {
145 let templates = templates::list_templates();
146 Json(json!(templates))
147}
148
149pub async fn apply_template(
151 State(state): State<Arc<AppState>>,
152 Path(name): Path<String>,
153) -> Result<Json<Value>, ApiError> {
154 let template_name: PolicyTemplateName =
155 name.parse().map_err(|e: String| ApiError::BadRequest(e))?;
156
157 let template = templates::get_template(&template_name);
158
159 let patch = json!({
161 "template": template_name,
162 "rules": template.rules,
163 "rate_limits": template.rate_limits,
164 });
165
166 let wrapped = json!({ "mcp_policy": patch });
168
169 let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
170 ApiError::BadRequest(format!(
171 "could not read config file {}: {e}",
172 state.config_path.display()
173 ))
174 })?;
175
176 let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
177 ApiError::BadRequest(format!("failed to parse existing config: {e}"))
178 })?;
179
180 let patch_toml = json_to_toml(&wrapped)
181 .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
182
183 merge_toml(&mut toml_value, &patch_toml);
184
185 let merged_str = toml::to_string_pretty(&toml_value)
186 .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
187
188 let config: Config = toml::from_str(&merged_str)
189 .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
190
191 std::fs::write(&state.config_path, &merged_str).map_err(|e| {
192 ApiError::BadRequest(format!(
193 "could not write config file {}: {e}",
194 state.config_path.display()
195 ))
196 })?;
197
198 if let Err(e) =
200 rate_limits::init_policy_rate_limits(&state.db, &config.mcp_policy.rate_limits).await
201 {
202 tracing::warn!("Failed to initialize policy rate limits: {e}");
203 }
204
205 Ok(Json(json!({
206 "applied_template": template_name,
207 "description": template.description,
208 "rules_count": config.mcp_policy.rules.len(),
209 "rate_limits_count": config.mcp_policy.rate_limits.len(),
210 })))
211}
212
213pub async fn telemetry_summary(
219 State(state): State<Arc<AppState>>,
220 Query(params): Query<TimeWindowQuery>,
221) -> Result<Json<Value>, ApiError> {
222 let since = since_timestamp(params.hours);
223 let summary = mcp_telemetry::get_summary(&state.db, &since).await?;
224 Ok(Json(serde_json::to_value(summary).unwrap()))
225}
226
227pub async fn telemetry_metrics(
229 State(state): State<Arc<AppState>>,
230 Query(params): Query<TimeWindowQuery>,
231) -> Result<Json<Value>, ApiError> {
232 let since = since_timestamp(params.hours);
233 let metrics = mcp_telemetry::get_metrics_since(&state.db, &since).await?;
234 Ok(Json(json!(metrics)))
235}
236
237pub async fn telemetry_errors(
239 State(state): State<Arc<AppState>>,
240 Query(params): Query<TimeWindowQuery>,
241) -> Result<Json<Value>, ApiError> {
242 let since = since_timestamp(params.hours);
243 let errors = mcp_telemetry::get_error_breakdown(&state.db, &since).await?;
244 Ok(Json(json!(errors)))
245}
246
247pub async fn telemetry_recent(
249 State(state): State<Arc<AppState>>,
250 Query(params): Query<RecentQuery>,
251) -> Result<Json<Value>, ApiError> {
252 let entries = mcp_telemetry::get_recent_entries(&state.db, params.limit).await?;
253 Ok(Json(json!(entries)))
254}
255
256fn read_config(state: &AppState) -> Result<Config, ApiError> {
261 let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
262 ApiError::BadRequest(format!(
263 "could not read config file {}: {e}",
264 state.config_path.display()
265 ))
266 })?;
267 let config: Config = toml::from_str(&contents)
268 .map_err(|e| ApiError::BadRequest(format!("failed to parse config: {e}")))?;
269 Ok(config)
270}
271
/// Formats the UTC instant `hours` hours before now as
/// `YYYY-MM-DDTHH:MM:SSZ`.
///
/// The subtraction saturates at the Unix epoch, so an oversized window yields
/// `1970-01-01T00:00:00Z` instead of underflowing. A system clock set before
/// the epoch is treated as the epoch itself.
fn since_timestamp(hours: u32) -> String {
    const SECS_PER_HOUR: u64 = 3600;
    const SECS_PER_DAY: i64 = 86400;

    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let cutoff = now_secs.saturating_sub(u64::from(hours) * SECS_PER_HOUR) as i64;

    // Split into a whole-day count (for the calendar date) and the seconds
    // elapsed within that day (for the clock time).
    let (year, month, day) = civil_from_days(cutoff.div_euclid(SECS_PER_DAY));
    let second_of_day = cutoff.rem_euclid(SECS_PER_DAY);
    let hh = second_of_day / 3600;
    let mm = second_of_day % 3600 / 60;
    let ss = second_of_day % 60;

    format!("{year:04}-{month:02}-{day:02}T{hh:02}:{mm:02}:{ss:02}Z")
}

/// Converts a count of days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple.
///
/// Works on 400-year eras of 146 097 days, with the year shifted to start on
/// 1 March so the leap day falls at the end of the shifted year.
fn civil_from_days(days_since_epoch: i64) -> (i64, i64, i64) {
    const DAYS_PER_ERA: i64 = 146097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719468;

    let shifted = days_since_epoch + EPOCH_SHIFT;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA);
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let march_based_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;
    let month = if march_based_month < 10 {
        march_based_month + 3
    } else {
        march_based_month - 9
    };

    // January and February belong to the following civil year.
    let shifted_year = year_of_era + era * 400;
    let year = if month <= 2 {
        shifted_year + 1
    } else {
        shifted_year
    };

    (year, month, day)
}
301
302fn merge_toml(base: &mut toml::Value, patch: &toml::Value) {
304 match (base, patch) {
305 (toml::Value::Table(base_table), toml::Value::Table(patch_table)) => {
306 for (key, patch_val) in patch_table {
307 if let Some(base_val) = base_table.get_mut(key) {
308 merge_toml(base_val, patch_val);
309 } else {
310 base_table.insert(key.clone(), patch_val.clone());
311 }
312 }
313 }
314 (base, _) => {
315 *base = patch.clone();
316 }
317 }
318}
319
320fn json_to_toml(json: &serde_json::Value) -> Result<toml::Value, String> {
322 match json {
323 serde_json::Value::Object(map) => {
324 let mut table = toml::map::Map::new();
325 for (key, val) in map {
326 if val.is_null() {
327 continue;
328 }
329 table.insert(key.clone(), json_to_toml(val)?);
330 }
331 Ok(toml::Value::Table(table))
332 }
333 serde_json::Value::Array(arr) => {
334 let values: Result<Vec<_>, _> = arr.iter().map(json_to_toml).collect();
335 Ok(toml::Value::Array(values?))
336 }
337 serde_json::Value::String(s) => Ok(toml::Value::String(s.clone())),
338 serde_json::Value::Number(n) => {
339 if let Some(i) = n.as_i64() {
340 Ok(toml::Value::Integer(i))
341 } else if let Some(f) = n.as_f64() {
342 Ok(toml::Value::Float(f))
343 } else {
344 Err(format!("unsupported number: {n}"))
345 }
346 }
347 serde_json::Value::Bool(b) => Ok(toml::Value::Boolean(*b)),
348 serde_json::Value::Null => Err("null values are not supported in TOML arrays".to_string()),
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    // ---- since_timestamp ----

    #[test]
    fn since_timestamp_is_valid_utc() {
        let ts = since_timestamp(24);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_zero_hours_is_now() {
        let ts = since_timestamp(0);
        assert!(ts.ends_with('Z'));
        assert!(ts.starts_with("20"));
    }

    #[test]
    fn since_timestamp_large_hours() {
        // 8760 hours = 365 days.
        let ts = since_timestamp(8760);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_format_correct() {
        let ts = since_timestamp(48);
        // "YYYY-MM-DDTHH:MM:SSZ"
        assert_eq!(ts.len(), 20);
        assert_eq!(&ts[4..5], "-");
        assert_eq!(&ts[7..8], "-");
        assert_eq!(&ts[10..11], "T");
        assert_eq!(&ts[13..14], ":");
        assert_eq!(&ts[16..17], ":");
    }

    #[test]
    fn since_timestamp_saturates_at_epoch() {
        // A window far longer than the time since 1970 clamps to the epoch,
        // which makes the full date/time formatting deterministic.
        assert_eq!(since_timestamp(u32::MAX), "1970-01-01T00:00:00Z");
    }

    // ---- merge_toml ----

    #[test]
    fn merge_toml_adds_new_keys() {
        let mut base: toml::Value = "key1 = \"value1\"".parse().unwrap();
        let patch: toml::Value = "key2 = \"value2\"".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(
            base.as_table().unwrap().get("key2").unwrap().as_str(),
            Some("value2")
        );
    }

    #[test]
    fn merge_toml_overwrites_existing() {
        let mut base: toml::Value = "key = \"old\"".parse().unwrap();
        let patch: toml::Value = "key = \"new\"".parse().unwrap();
        merge_toml(&mut base, &patch);
        assert_eq!(
            base.as_table().unwrap().get("key").unwrap().as_str(),
            Some("new")
        );
    }

    #[test]
    fn merge_toml_deep_merge() {
        let mut base: toml::Value = "[section]\na = 1".parse().unwrap();
        let patch: toml::Value = "[section]\nb = 2".parse().unwrap();
        merge_toml(&mut base, &patch);
        let section = base.as_table().unwrap().get("section").unwrap();
        assert_eq!(section.get("a").unwrap().as_integer(), Some(1));
        assert_eq!(section.get("b").unwrap().as_integer(), Some(2));
    }

    #[test]
    fn merge_toml_replaces_arrays_wholesale() {
        // Only tables are merged key by key; arrays are overwritten.
        let mut base: toml::Value = "list = [1, 2]".parse().unwrap();
        let patch: toml::Value = "list = [3]".parse().unwrap();
        merge_toml(&mut base, &patch);
        let list = base
            .get("list")
            .and_then(toml::Value::as_array)
            .expect("list should still be an array");
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].as_integer(), Some(3));
    }

    // ---- json_to_toml ----

    #[test]
    fn json_to_toml_string() {
        let json = serde_json::json!("hello");
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_str(), Some("hello"));
    }

    #[test]
    fn json_to_toml_integer() {
        let json = serde_json::json!(42);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_integer(), Some(42));
    }

    #[test]
    fn json_to_toml_float() {
        // 2.5 is exactly representable and, unlike 3.14, does not trip
        // clippy's `approx_constant` lint.
        let json = serde_json::json!(2.5);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_float(), Some(2.5));
    }

    #[test]
    fn json_to_toml_boolean() {
        let json = serde_json::json!(true);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_bool(), Some(true));
    }

    #[test]
    fn json_to_toml_array() {
        let json = serde_json::json!([1, 2, 3]);
        let toml = json_to_toml(&json).unwrap();
        assert_eq!(toml.as_array().unwrap().len(), 3);
    }

    #[test]
    fn json_to_toml_object() {
        let json = serde_json::json!({"key": "value"});
        let toml = json_to_toml(&json).unwrap();
        assert!(toml.as_table().is_some());
    }

    #[test]
    fn json_to_toml_skips_null_in_objects() {
        let json = serde_json::json!({"key": "value", "null_key": null});
        let toml = json_to_toml(&json).unwrap();
        let table = toml.as_table().unwrap();
        assert!(table.contains_key("key"));
        assert!(!table.contains_key("null_key"));
    }

    #[test]
    fn json_to_toml_null_in_array_errors() {
        let json = serde_json::json!([null]);
        assert!(json_to_toml(&json).is_err());
    }

    // ---- query defaults ----

    #[test]
    fn time_window_query_defaults() {
        let json = "{}";
        let q: TimeWindowQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.hours, 24);
    }

    #[test]
    fn recent_query_defaults() {
        let json = "{}";
        let q: RecentQuery = serde_json::from_str(json).expect("deser");
        assert_eq!(q.limit, 50);
    }
}