1use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use axum::extract::{Path, Query, State};
7use axum::Json;
8use serde::Deserialize;
9use serde_json::{json, Value};
10use tuitbot_core::config::Config;
11use tuitbot_core::mcp_policy::templates;
12use tuitbot_core::mcp_policy::types::PolicyTemplateName;
13use tuitbot_core::storage::{mcp_telemetry, rate_limits};
14
15use crate::error::ApiError;
16use crate::state::AppState;
17
18#[derive(Deserialize)]
23pub struct TimeWindowQuery {
24 #[serde(default = "default_hours")]
26 pub hours: u32,
27}
28
/// Serde default for `TimeWindowQuery::hours`: a 24-hour window.
fn default_hours() -> u32 {
    const DEFAULT_WINDOW_HOURS: u32 = 24;
    DEFAULT_WINDOW_HOURS
}
32
33#[derive(Deserialize)]
34pub struct RecentQuery {
35 #[serde(default = "default_limit")]
37 pub limit: u32,
38}
39
/// Serde default for `RecentQuery::limit`: 50 entries.
fn default_limit() -> u32 {
    const DEFAULT_RECENT_LIMIT: u32 = 50;
    DEFAULT_RECENT_LIMIT
}
43
44pub async fn get_policy(State(state): State<Arc<AppState>>) -> Result<Json<Value>, ApiError> {
50 let config = read_config(&state)?;
51
52 let rate_limit_info = match rate_limits::get_all_rate_limits(&state.db).await {
53 Ok(limits) => {
54 let mcp = limits.iter().find(|l| l.action_type == "mcp_mutation");
55 match mcp {
56 Some(rl) => json!({
57 "used": rl.request_count,
58 "max": rl.max_requests,
59 "period_seconds": rl.period_seconds,
60 "period_start": rl.period_start,
61 }),
62 None => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
63 }
64 }
65 Err(_) => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
66 };
67
68 Ok(Json(json!({
69 "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
70 "require_approval_for": config.mcp_policy.require_approval_for,
71 "blocked_tools": config.mcp_policy.blocked_tools,
72 "dry_run_mutations": config.mcp_policy.dry_run_mutations,
73 "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
74 "mode": format!("{}", config.mode),
75 "rate_limit": rate_limit_info,
76 "template": config.mcp_policy.template,
77 "rules": config.mcp_policy.rules,
78 "rate_limits": config.mcp_policy.rate_limits,
79 })))
80}
81
82pub async fn patch_policy(
86 State(state): State<Arc<AppState>>,
87 Json(patch): Json<Value>,
88) -> Result<Json<Value>, ApiError> {
89 if !patch.is_object() {
90 return Err(ApiError::BadRequest(
91 "request body must be a JSON object".to_string(),
92 ));
93 }
94
95 let wrapped = json!({ "mcp_policy": patch });
97
98 let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
99 ApiError::BadRequest(format!(
100 "could not read config file {}: {e}",
101 state.config_path.display()
102 ))
103 })?;
104
105 let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
106 ApiError::BadRequest(format!("failed to parse existing config: {e}"))
107 })?;
108
109 let patch_toml = json_to_toml(&wrapped)
110 .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
111
112 merge_toml(&mut toml_value, &patch_toml);
113
114 let merged_str = toml::to_string_pretty(&toml_value)
115 .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
116
117 let config: Config = toml::from_str(&merged_str)
118 .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
119
120 std::fs::write(&state.config_path, &merged_str).map_err(|e| {
121 ApiError::BadRequest(format!(
122 "could not write config file {}: {e}",
123 state.config_path.display()
124 ))
125 })?;
126
127 Ok(Json(json!({
128 "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
129 "require_approval_for": config.mcp_policy.require_approval_for,
130 "blocked_tools": config.mcp_policy.blocked_tools,
131 "dry_run_mutations": config.mcp_policy.dry_run_mutations,
132 "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
133 "template": config.mcp_policy.template,
134 "rules": config.mcp_policy.rules,
135 "rate_limits": config.mcp_policy.rate_limits,
136 })))
137}
138
139pub async fn list_templates() -> Json<Value> {
145 let templates = templates::list_templates();
146 Json(json!(templates))
147}
148
149pub async fn apply_template(
151 State(state): State<Arc<AppState>>,
152 Path(name): Path<String>,
153) -> Result<Json<Value>, ApiError> {
154 let template_name: PolicyTemplateName =
155 name.parse().map_err(|e: String| ApiError::BadRequest(e))?;
156
157 let template = templates::get_template(&template_name);
158
159 let patch = json!({
161 "template": template_name,
162 "rules": template.rules,
163 "rate_limits": template.rate_limits,
164 });
165
166 let wrapped = json!({ "mcp_policy": patch });
168
169 let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
170 ApiError::BadRequest(format!(
171 "could not read config file {}: {e}",
172 state.config_path.display()
173 ))
174 })?;
175
176 let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
177 ApiError::BadRequest(format!("failed to parse existing config: {e}"))
178 })?;
179
180 let patch_toml = json_to_toml(&wrapped)
181 .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
182
183 merge_toml(&mut toml_value, &patch_toml);
184
185 let merged_str = toml::to_string_pretty(&toml_value)
186 .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
187
188 let config: Config = toml::from_str(&merged_str)
189 .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
190
191 std::fs::write(&state.config_path, &merged_str).map_err(|e| {
192 ApiError::BadRequest(format!(
193 "could not write config file {}: {e}",
194 state.config_path.display()
195 ))
196 })?;
197
198 if let Err(e) =
200 rate_limits::init_policy_rate_limits(&state.db, &config.mcp_policy.rate_limits).await
201 {
202 tracing::warn!("Failed to initialize policy rate limits: {e}");
203 }
204
205 Ok(Json(json!({
206 "applied_template": template_name,
207 "description": template.description,
208 "rules_count": config.mcp_policy.rules.len(),
209 "rate_limits_count": config.mcp_policy.rate_limits.len(),
210 })))
211}
212
213pub async fn telemetry_summary(
219 State(state): State<Arc<AppState>>,
220 Query(params): Query<TimeWindowQuery>,
221) -> Result<Json<Value>, ApiError> {
222 let since = since_timestamp(params.hours);
223 let summary = mcp_telemetry::get_summary(&state.db, &since).await?;
224 Ok(Json(serde_json::to_value(summary).unwrap()))
225}
226
227pub async fn telemetry_metrics(
229 State(state): State<Arc<AppState>>,
230 Query(params): Query<TimeWindowQuery>,
231) -> Result<Json<Value>, ApiError> {
232 let since = since_timestamp(params.hours);
233 let metrics = mcp_telemetry::get_metrics_since(&state.db, &since).await?;
234 Ok(Json(json!(metrics)))
235}
236
237pub async fn telemetry_errors(
239 State(state): State<Arc<AppState>>,
240 Query(params): Query<TimeWindowQuery>,
241) -> Result<Json<Value>, ApiError> {
242 let since = since_timestamp(params.hours);
243 let errors = mcp_telemetry::get_error_breakdown(&state.db, &since).await?;
244 Ok(Json(json!(errors)))
245}
246
247pub async fn telemetry_recent(
249 State(state): State<Arc<AppState>>,
250 Query(params): Query<RecentQuery>,
251) -> Result<Json<Value>, ApiError> {
252 let entries = mcp_telemetry::get_recent_entries(&state.db, params.limit).await?;
253 Ok(Json(json!(entries)))
254}
255
256fn read_config(state: &AppState) -> Result<Config, ApiError> {
261 let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
262 ApiError::BadRequest(format!(
263 "could not read config file {}: {e}",
264 state.config_path.display()
265 ))
266 })?;
267 let config: Config = toml::from_str(&contents)
268 .map_err(|e| ApiError::BadRequest(format!("failed to parse config: {e}")))?;
269 Ok(config)
270}
271
/// Formats the instant `hours` hours before now as `YYYY-MM-DDTHH:MM:SSZ`.
///
/// The current time is read from the system clock as seconds since the Unix
/// epoch (treated as 0 if the clock reports a time before the epoch). The
/// subtraction saturates, so a window reaching past the epoch yields
/// `1970-01-01T00:00:00Z`.
fn since_timestamp(hours: u32) -> String {
    /// Converts a day count relative to 1970-01-01 into `(year, month, day)`.
    ///
    /// Months are counted from March internally (index 0 maps to month 3), so
    /// January and February belong to the following calendar year.
    fn civil_from_days(days: i64) -> (i64, i64, i64) {
        let shifted = days + 719468;
        let era = shifted.div_euclid(146097);
        let day_of_era = shifted.rem_euclid(146097);
        let year_of_era =
            (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
        let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
        let month_index = (5 * day_of_year + 2) / 153;
        let day = day_of_year - (153 * month_index + 2) / 5 + 1;
        let month = if month_index < 10 {
            month_index + 3
        } else {
            month_index - 9
        };
        let year = year_of_era + era * 400 + i64::from(month <= 2);
        (year, month, day)
    }

    let window_secs = u64::from(hours) * 3600;
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    let cutoff = now_secs.saturating_sub(window_secs) as i64;

    let (year, month, day) = civil_from_days(cutoff.div_euclid(86400));
    let second_of_day = cutoff.rem_euclid(86400);
    let hour = second_of_day / 3600;
    let minute = second_of_day % 3600 / 60;
    let second = second_of_day % 60;

    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
}
301
302fn merge_toml(base: &mut toml::Value, patch: &toml::Value) {
304 match (base, patch) {
305 (toml::Value::Table(base_table), toml::Value::Table(patch_table)) => {
306 for (key, patch_val) in patch_table {
307 if let Some(base_val) = base_table.get_mut(key) {
308 merge_toml(base_val, patch_val);
309 } else {
310 base_table.insert(key.clone(), patch_val.clone());
311 }
312 }
313 }
314 (base, _) => {
315 *base = patch.clone();
316 }
317 }
318}
319
320fn json_to_toml(json: &serde_json::Value) -> Result<toml::Value, String> {
322 match json {
323 serde_json::Value::Object(map) => {
324 let mut table = toml::map::Map::new();
325 for (key, val) in map {
326 if val.is_null() {
327 continue;
328 }
329 table.insert(key.clone(), json_to_toml(val)?);
330 }
331 Ok(toml::Value::Table(table))
332 }
333 serde_json::Value::Array(arr) => {
334 let values: Result<Vec<_>, _> = arr.iter().map(json_to_toml).collect();
335 Ok(toml::Value::Array(values?))
336 }
337 serde_json::Value::String(s) => Ok(toml::Value::String(s.clone())),
338 serde_json::Value::Number(n) => {
339 if let Some(i) = n.as_i64() {
340 Ok(toml::Value::Integer(i))
341 } else if let Some(f) = n.as_f64() {
342 Ok(toml::Value::Float(f))
343 } else {
344 Err(format!("unsupported number: {n}"))
345 }
346 }
347 serde_json::Value::Bool(b) => Ok(toml::Value::Boolean(*b)),
348 serde_json::Value::Null => Err("null values are not supported in TOML arrays".to_string()),
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a TOML table value from `(key, value)` pairs.
    fn table(entries: Vec<(&str, toml::Value)>) -> toml::Value {
        let mut map = toml::map::Map::new();
        for (key, value) in entries {
            map.insert(key.to_string(), value);
        }
        toml::Value::Table(map)
    }

    #[test]
    fn since_timestamp_is_valid_utc() {
        let ts = since_timestamp(24);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
        // `YYYY-MM-DDTHH:MM:SSZ` is always 20 characters for four-digit years.
        assert_eq!(ts.len(), 20);
    }

    #[test]
    fn since_timestamp_saturates_at_epoch() {
        // A window far longer than the time since 1970 clamps to the epoch,
        // which pins the date arithmetic to a known value.
        assert_eq!(since_timestamp(u32::MAX), "1970-01-01T00:00:00Z");
    }

    #[test]
    fn merge_toml_merges_nested_tables() {
        let mut base = table(vec![(
            "policy",
            table(vec![
                ("keep", toml::Value::Integer(1)),
                ("replace", toml::Value::Integer(2)),
            ]),
        )]);
        let patch = table(vec![(
            "policy",
            table(vec![
                ("replace", toml::Value::Integer(3)),
                ("added", toml::Value::Boolean(true)),
            ]),
        )]);

        merge_toml(&mut base, &patch);

        let expected = table(vec![(
            "policy",
            table(vec![
                ("keep", toml::Value::Integer(1)),
                ("replace", toml::Value::Integer(3)),
                ("added", toml::Value::Boolean(true)),
            ]),
        )]);
        assert_eq!(base, expected);
    }

    #[test]
    fn json_to_toml_skips_null_fields_and_rejects_null_elements() {
        let converted = json_to_toml(&serde_json::json!({
            "n": 1,
            "skip": null,
            "list": [true, "s"],
        }))
        .expect("object without null array elements converts");

        let expected = table(vec![
            ("n", toml::Value::Integer(1)),
            (
                "list",
                toml::Value::Array(vec![
                    toml::Value::Boolean(true),
                    toml::Value::String("s".to_string()),
                ]),
            ),
        ]);
        assert_eq!(converted, expected);

        assert!(json_to_toml(&serde_json::json!([null])).is_err());
    }
}