1use super::accounts::DEFAULT_ACCOUNT_ID;
7use super::DbPool;
8use crate::config::{IntervalsConfig, LimitsConfig};
9use crate::error::StorageError;
10use chrono::{DateTime, Utc};
11
12#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
14pub struct RateLimit {
15 pub action_type: String,
17 pub request_count: i64,
19 pub period_start: String,
21 pub max_requests: i64,
23 pub period_seconds: i64,
25}
26
27pub async fn init_rate_limits_for(
32 pool: &DbPool,
33 account_id: &str,
34 config: &LimitsConfig,
35 intervals: &IntervalsConfig,
36) -> Result<(), StorageError> {
37 let _ = intervals;
39
40 let defaults: Vec<(&str, i64, i64)> = vec![
41 ("reply", i64::from(config.max_replies_per_day), 86400),
42 ("tweet", i64::from(config.max_tweets_per_day), 86400),
43 ("thread", i64::from(config.max_threads_per_week), 604800),
44 ("search", 300, 900),
45 ("mention_check", 180, 900),
46 ];
47
48 for (action_type, max_requests, period_seconds) in defaults {
49 sqlx::query(
50 "INSERT OR IGNORE INTO rate_limits \
51 (account_id, action_type, request_count, period_start, max_requests, period_seconds) \
52 VALUES (?, ?, 0, strftime('%Y-%m-%dT%H:%M:%SZ', 'now'), ?, ?)",
53 )
54 .bind(account_id)
55 .bind(action_type)
56 .bind(max_requests)
57 .bind(period_seconds)
58 .execute(pool)
59 .await
60 .map_err(|e| StorageError::Query { source: e })?;
61 }
62
63 Ok(())
64}
65
66pub async fn init_rate_limits(
71 pool: &DbPool,
72 config: &LimitsConfig,
73 intervals: &IntervalsConfig,
74) -> Result<(), StorageError> {
75 init_rate_limits_for(pool, DEFAULT_ACCOUNT_ID, config, intervals).await
76}
77
78pub async fn init_mcp_rate_limit_for(
82 pool: &DbPool,
83 account_id: &str,
84 max_per_hour: u32,
85) -> Result<(), StorageError> {
86 sqlx::query(
87 "INSERT OR IGNORE INTO rate_limits \
88 (account_id, action_type, request_count, period_start, max_requests, period_seconds) \
89 VALUES (?, 'mcp_mutation', 0, strftime('%Y-%m-%dT%H:%M:%SZ', 'now'), ?, 3600)",
90 )
91 .bind(account_id)
92 .bind(i64::from(max_per_hour))
93 .execute(pool)
94 .await
95 .map_err(|e| StorageError::Query { source: e })?;
96
97 Ok(())
98}
99
100pub async fn init_mcp_rate_limit(pool: &DbPool, max_per_hour: u32) -> Result<(), StorageError> {
104 init_mcp_rate_limit_for(pool, DEFAULT_ACCOUNT_ID, max_per_hour).await
105}
106
107pub async fn check_rate_limit_for(
116 pool: &DbPool,
117 account_id: &str,
118 action_type: &str,
119) -> Result<bool, StorageError> {
120 let mut tx = pool
121 .begin()
122 .await
123 .map_err(|e| StorageError::Connection { source: e })?;
124
125 let row = sqlx::query_as::<_, RateLimit>(
126 "SELECT action_type, request_count, period_start, max_requests, period_seconds \
127 FROM rate_limits WHERE account_id = ? AND action_type = ?",
128 )
129 .bind(account_id)
130 .bind(action_type)
131 .fetch_optional(&mut *tx)
132 .await
133 .map_err(|e| StorageError::Query { source: e })?;
134
135 let limit = match row {
136 Some(l) => l,
137 None => {
138 tx.commit()
139 .await
140 .map_err(|e| StorageError::Connection { source: e })?;
141 return Ok(true);
142 }
143 };
144
145 let now = Utc::now();
146 let period_start = limit.period_start.parse::<DateTime<Utc>>().unwrap_or(now);
147
148 let elapsed = now.signed_duration_since(period_start).num_seconds();
149
150 if elapsed >= limit.period_seconds {
151 sqlx::query(
152 "UPDATE rate_limits SET request_count = 0, \
153 period_start = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
154 WHERE account_id = ? AND action_type = ?",
155 )
156 .bind(account_id)
157 .bind(action_type)
158 .execute(&mut *tx)
159 .await
160 .map_err(|e| StorageError::Query { source: e })?;
161
162 tx.commit()
163 .await
164 .map_err(|e| StorageError::Connection { source: e })?;
165 return Ok(true);
166 }
167
168 let allowed = limit.request_count < limit.max_requests;
169
170 tx.commit()
171 .await
172 .map_err(|e| StorageError::Connection { source: e })?;
173
174 Ok(allowed)
175}
176
177pub async fn check_rate_limit(pool: &DbPool, action_type: &str) -> Result<bool, StorageError> {
186 check_rate_limit_for(pool, DEFAULT_ACCOUNT_ID, action_type).await
187}
188
189pub async fn check_and_increment_rate_limit_for(
195 pool: &DbPool,
196 account_id: &str,
197 action_type: &str,
198) -> Result<bool, StorageError> {
199 let mut tx = pool
200 .begin()
201 .await
202 .map_err(|e| StorageError::Connection { source: e })?;
203
204 let row = sqlx::query_as::<_, RateLimit>(
205 "SELECT action_type, request_count, period_start, max_requests, period_seconds \
206 FROM rate_limits WHERE account_id = ? AND action_type = ?",
207 )
208 .bind(account_id)
209 .bind(action_type)
210 .fetch_optional(&mut *tx)
211 .await
212 .map_err(|e| StorageError::Query { source: e })?;
213
214 let limit = match row {
215 Some(l) => l,
216 None => {
217 tx.commit()
218 .await
219 .map_err(|e| StorageError::Connection { source: e })?;
220 return Ok(true);
221 }
222 };
223
224 let now = Utc::now();
225 let period_start = limit.period_start.parse::<DateTime<Utc>>().unwrap_or(now);
226
227 let elapsed = now.signed_duration_since(period_start).num_seconds();
228
229 let current_count = if elapsed >= limit.period_seconds {
230 sqlx::query(
231 "UPDATE rate_limits SET request_count = 0, \
232 period_start = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
233 WHERE account_id = ? AND action_type = ?",
234 )
235 .bind(account_id)
236 .bind(action_type)
237 .execute(&mut *tx)
238 .await
239 .map_err(|e| StorageError::Query { source: e })?;
240 0
241 } else {
242 limit.request_count
243 };
244
245 if current_count < limit.max_requests {
246 sqlx::query(
247 "UPDATE rate_limits SET request_count = request_count + 1 \
248 WHERE account_id = ? AND action_type = ?",
249 )
250 .bind(account_id)
251 .bind(action_type)
252 .execute(&mut *tx)
253 .await
254 .map_err(|e| StorageError::Query { source: e })?;
255
256 tx.commit()
257 .await
258 .map_err(|e| StorageError::Connection { source: e })?;
259 Ok(true)
260 } else {
261 tx.commit()
262 .await
263 .map_err(|e| StorageError::Connection { source: e })?;
264 Ok(false)
265 }
266}
267
268pub async fn check_and_increment_rate_limit(
274 pool: &DbPool,
275 action_type: &str,
276) -> Result<bool, StorageError> {
277 check_and_increment_rate_limit_for(pool, DEFAULT_ACCOUNT_ID, action_type).await
278}
279
280pub async fn increment_rate_limit_for(
284 pool: &DbPool,
285 account_id: &str,
286 action_type: &str,
287) -> Result<(), StorageError> {
288 sqlx::query(
289 "UPDATE rate_limits SET request_count = request_count + 1 \
290 WHERE account_id = ? AND action_type = ?",
291 )
292 .bind(account_id)
293 .bind(action_type)
294 .execute(pool)
295 .await
296 .map_err(|e| StorageError::Query { source: e })?;
297
298 Ok(())
299}
300
301pub async fn increment_rate_limit(pool: &DbPool, action_type: &str) -> Result<(), StorageError> {
305 increment_rate_limit_for(pool, DEFAULT_ACCOUNT_ID, action_type).await
306}
307
308#[derive(Debug, Clone, serde::Serialize)]
310pub struct ActionUsage {
311 pub used: i64,
312 pub max: i64,
313}
314
315#[derive(Debug, Clone, serde::Serialize)]
317pub struct DailyUsage {
318 pub replies: ActionUsage,
319 pub tweets: ActionUsage,
320 pub threads: ActionUsage,
321}
322
323pub async fn get_daily_usage_for(
328 pool: &DbPool,
329 account_id: &str,
330) -> Result<DailyUsage, StorageError> {
331 let limits = get_all_rate_limits_for(pool, account_id).await?;
332
333 let mut usage = DailyUsage {
334 replies: ActionUsage { used: 0, max: 0 },
335 tweets: ActionUsage { used: 0, max: 0 },
336 threads: ActionUsage { used: 0, max: 0 },
337 };
338
339 for limit in limits {
340 let target = match limit.action_type.as_str() {
341 "reply" => &mut usage.replies,
342 "tweet" => &mut usage.tweets,
343 "thread" => &mut usage.threads,
344 _ => continue,
345 };
346 target.used = limit.request_count;
347 target.max = limit.max_requests;
348 }
349
350 Ok(usage)
351}
352
353pub async fn get_daily_usage(pool: &DbPool) -> Result<DailyUsage, StorageError> {
358 get_daily_usage_for(pool, DEFAULT_ACCOUNT_ID).await
359}
360
361pub async fn get_all_rate_limits_for(
365 pool: &DbPool,
366 account_id: &str,
367) -> Result<Vec<RateLimit>, StorageError> {
368 sqlx::query_as::<_, RateLimit>(
369 "SELECT action_type, request_count, period_start, max_requests, period_seconds \
370 FROM rate_limits WHERE account_id = ? ORDER BY action_type",
371 )
372 .bind(account_id)
373 .fetch_all(pool)
374 .await
375 .map_err(|e| StorageError::Query { source: e })
376}
377
378pub async fn get_all_rate_limits(pool: &DbPool) -> Result<Vec<RateLimit>, StorageError> {
382 get_all_rate_limits_for(pool, DEFAULT_ACCOUNT_ID).await
383}
384
385use crate::mcp_policy::types::{PolicyRateLimit, RateLimitDimension};
390
391pub async fn init_policy_rate_limits_for(
395 pool: &DbPool,
396 account_id: &str,
397 limits: &[PolicyRateLimit],
398) -> Result<(), StorageError> {
399 for limit in limits {
400 sqlx::query(
401 "INSERT OR IGNORE INTO rate_limits \
402 (account_id, action_type, request_count, period_start, max_requests, period_seconds) \
403 VALUES (?, ?, 0, strftime('%Y-%m-%dT%H:%M:%SZ', 'now'), ?, ?)",
404 )
405 .bind(account_id)
406 .bind(&limit.key)
407 .bind(i64::from(limit.max_count))
408 .bind(limit.period_seconds as i64)
409 .execute(pool)
410 .await
411 .map_err(|e| StorageError::Query { source: e })?;
412 }
413 Ok(())
414}
415
416pub async fn init_policy_rate_limits(
420 pool: &DbPool,
421 limits: &[PolicyRateLimit],
422) -> Result<(), StorageError> {
423 init_policy_rate_limits_for(pool, DEFAULT_ACCOUNT_ID, limits).await
424}
425
426pub async fn check_policy_rate_limits_for(
430 pool: &DbPool,
431 account_id: &str,
432 tool_name: &str,
433 category: &str,
434 limits: &[PolicyRateLimit],
435) -> Result<Option<String>, StorageError> {
436 for limit in limits {
437 let matches = match limit.dimension {
438 RateLimitDimension::Tool => limit.match_value == tool_name,
439 RateLimitDimension::Category => limit.match_value == category,
440 RateLimitDimension::EngagementType => limit.match_value == tool_name,
441 RateLimitDimension::Global => true,
442 };
443
444 if !matches {
445 continue;
446 }
447
448 let allowed = check_rate_limit_for(pool, account_id, &limit.key).await?;
449 if !allowed {
450 return Ok(Some(limit.key.clone()));
451 }
452 }
453 Ok(None)
454}
455
456pub async fn check_policy_rate_limits(
460 pool: &DbPool,
461 tool_name: &str,
462 category: &str,
463 limits: &[PolicyRateLimit],
464) -> Result<Option<String>, StorageError> {
465 check_policy_rate_limits_for(pool, DEFAULT_ACCOUNT_ID, tool_name, category, limits).await
466}
467
468pub async fn record_policy_rate_limits_for(
470 pool: &DbPool,
471 account_id: &str,
472 tool_name: &str,
473 category: &str,
474 limits: &[PolicyRateLimit],
475) -> Result<(), StorageError> {
476 for limit in limits {
477 let matches = match limit.dimension {
478 RateLimitDimension::Tool => limit.match_value == tool_name,
479 RateLimitDimension::Category => limit.match_value == category,
480 RateLimitDimension::EngagementType => limit.match_value == tool_name,
481 RateLimitDimension::Global => true,
482 };
483
484 if matches {
485 let _ = increment_rate_limit_for(pool, account_id, &limit.key).await;
487 }
488 }
489 Ok(())
490}
491
492pub async fn record_policy_rate_limits(
494 pool: &DbPool,
495 tool_name: &str,
496 category: &str,
497 limits: &[PolicyRateLimit],
498) -> Result<(), StorageError> {
499 record_policy_rate_limits_for(pool, DEFAULT_ACCOUNT_ID, tool_name, category, limits).await
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    /// Small limits so tests can hit the ceiling quickly: 3 replies, 2 tweets, 1 thread.
    fn test_limits_config() -> LimitsConfig {
        LimitsConfig {
            max_replies_per_day: 3,
            max_tweets_per_day: 2,
            max_threads_per_week: 1,
            min_action_delay_seconds: 30,
            max_action_delay_seconds: 120,
            max_replies_per_author_per_day: 1,
            banned_phrases: vec![],
            product_mention_ratio: 0.2,
        }
    }

    fn test_intervals_config() -> IntervalsConfig {
        IntervalsConfig {
            mentions_check_seconds: 300,
            discovery_search_seconds: 600,
            content_post_window_seconds: 14400,
            thread_interval_seconds: 604800,
        }
    }

    /// Seeds the default-account rate limits from the test configs.
    async fn seed(pool: &DbPool) {
        init_rate_limits(pool, &test_limits_config(), &test_intervals_config())
            .await
            .expect("init");
    }

    /// Increments `action` on the default account `times` times.
    async fn bump(pool: &DbPool, action: &str, times: usize) {
        for _ in 0..times {
            increment_rate_limit(pool, action).await.expect("inc");
        }
    }

    /// Finds the row for `action`, panicking with the action name if it is missing.
    fn row<'a>(limits: &'a [RateLimit], action: &str) -> &'a RateLimit {
        limits
            .iter()
            .find(|l| l.action_type == action)
            .expect(action)
    }

    #[tokio::test]
    async fn init_creates_all_rate_limit_rows() {
        let pool = init_test_db().await.expect("init db");
        init_rate_limits(&pool, &test_limits_config(), &test_intervals_config())
            .await
            .expect("init rate limits");

        let limits = get_all_rate_limits(&pool).await.expect("get limits");
        assert_eq!(limits.len(), 5);

        let reply = row(&limits, "reply");
        assert_eq!(
            (reply.max_requests, reply.period_seconds, reply.request_count),
            (3, 86400, 0)
        );

        let thread = row(&limits, "thread");
        assert_eq!((thread.max_requests, thread.period_seconds), (1, 604800));
    }

    #[tokio::test]
    async fn init_preserves_existing_counters() {
        let pool = init_test_db().await.expect("init db");
        let (limits_cfg, intervals_cfg) = (test_limits_config(), test_intervals_config());

        init_rate_limits(&pool, &limits_cfg, &intervals_cfg)
            .await
            .expect("first init");
        increment_rate_limit(&pool, "reply")
            .await
            .expect("increment");
        init_rate_limits(&pool, &limits_cfg, &intervals_cfg)
            .await
            .expect("second init");

        let limits = get_all_rate_limits(&pool).await.expect("get limits");
        assert_eq!(
            row(&limits, "reply").request_count,
            1,
            "counter should be preserved"
        );
    }

    #[tokio::test]
    async fn check_rate_limit_allows_under_max() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;

        let allowed = check_rate_limit(&pool, "reply").await.expect("check");
        assert!(allowed);
    }

    #[tokio::test]
    async fn check_rate_limit_blocks_at_max() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;
        bump(&pool, "reply", 3).await;

        let allowed = check_rate_limit(&pool, "reply").await.expect("check");
        assert!(!allowed);
    }

    #[tokio::test]
    async fn check_rate_limit_resets_expired_period() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;
        bump(&pool, "reply", 3).await;

        // Push the window start 25 hours into the past so the 24-hour window has elapsed.
        sqlx::query(
            "UPDATE rate_limits SET period_start = strftime('%Y-%m-%dT%H:%M:%SZ', 'now', '-25 hours') \
             WHERE action_type = 'reply'",
        )
        .execute(&pool)
        .await
        .expect("backdate");

        let allowed = check_rate_limit(&pool, "reply").await.expect("check");
        assert!(allowed);

        let limits = get_all_rate_limits(&pool).await.expect("get");
        assert_eq!(row(&limits, "reply").request_count, 0);
    }

    #[tokio::test]
    async fn check_rate_limit_unknown_type_allows() {
        let pool = init_test_db().await.expect("init db");
        let allowed = check_rate_limit(&pool, "nonexistent").await.expect("check");
        assert!(allowed);
    }

    #[tokio::test]
    async fn check_and_increment_works() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;

        for attempt in ["1", "2", "3"] {
            let granted = check_and_increment_rate_limit(&pool, "reply")
                .await
                .expect(attempt);
            assert!(granted);
        }
        let fourth = check_and_increment_rate_limit(&pool, "reply")
            .await
            .expect("4");
        assert!(!fourth);

        let limits = get_all_rate_limits(&pool).await.expect("get");
        assert_eq!(row(&limits, "reply").request_count, 3);
    }

    #[tokio::test]
    async fn increment_rate_limit_works() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;
        bump(&pool, "tweet", 2).await;

        let limits = get_all_rate_limits(&pool).await.expect("get");
        assert_eq!(row(&limits, "tweet").request_count, 2);
    }

    #[tokio::test]
    async fn get_all_rate_limits_ordered() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;

        let limits = get_all_rate_limits(&pool).await.expect("get");
        let types: Vec<&str> = limits.iter().map(|l| l.action_type.as_str()).collect();
        assert!(
            types.windows(2).all(|pair| pair[0] <= pair[1]),
            "should be sorted by action_type"
        );
    }

    #[tokio::test]
    async fn daily_usage_returns_correct_counts() {
        let pool = init_test_db().await.expect("init db");
        seed(&pool).await;
        bump(&pool, "reply", 2).await;
        bump(&pool, "tweet", 1).await;

        let usage = get_daily_usage(&pool).await.expect("get usage");

        assert_eq!((usage.replies.used, usage.replies.max), (2, 3));
        assert_eq!((usage.tweets.used, usage.tweets.max), (1, 2));
        assert_eq!((usage.threads.used, usage.threads.max), (0, 1));
    }
}