Skip to main content

tuitbot_core/safety/
mod.rs

//! Safety module for rate limiting and duplicate prevention.
//!
//! Provides the `SafetyGuard` as the primary pre-flight check interface
//! for all automation loops. Combines rate limiting with deduplication
//! to prevent API abuse and duplicate content.
6
7pub mod dedup;
8pub mod redact;
9
10use crate::error::StorageError;
11use crate::storage::rate_limits;
12use crate::storage::{author_interactions, DbPool};
13
14pub use dedup::DedupChecker;
15
16/// Wraps rate limit database operations with a clean API.
17pub struct RateLimiter {
18    pool: DbPool,
19}
20
21impl RateLimiter {
22    /// Create a new rate limiter backed by the given database pool.
23    pub fn new(pool: DbPool) -> Self {
24        Self { pool }
25    }
26
27    /// Check if a reply action is allowed under the current rate limit.
28    pub async fn can_reply(&self) -> Result<bool, StorageError> {
29        rate_limits::check_rate_limit(&self.pool, "reply").await
30    }
31
32    /// Check if a tweet action is allowed under the current rate limit.
33    pub async fn can_tweet(&self) -> Result<bool, StorageError> {
34        rate_limits::check_rate_limit(&self.pool, "tweet").await
35    }
36
37    /// Check if a thread action is allowed under the current rate limit.
38    pub async fn can_thread(&self) -> Result<bool, StorageError> {
39        rate_limits::check_rate_limit(&self.pool, "thread").await
40    }
41
42    /// Check if a search action is allowed under the current rate limit.
43    pub async fn can_search(&self) -> Result<bool, StorageError> {
44        rate_limits::check_rate_limit(&self.pool, "search").await
45    }
46
47    /// Record a successful reply action (increments counter).
48    pub async fn record_reply(&self) -> Result<(), StorageError> {
49        rate_limits::increment_rate_limit(&self.pool, "reply").await
50    }
51
52    /// Record a successful tweet action (increments counter).
53    pub async fn record_tweet(&self) -> Result<(), StorageError> {
54        rate_limits::increment_rate_limit(&self.pool, "tweet").await
55    }
56
57    /// Record a successful thread action (increments counter).
58    pub async fn record_thread(&self) -> Result<(), StorageError> {
59        rate_limits::increment_rate_limit(&self.pool, "thread").await
60    }
61
62    /// Record a successful search action (increments counter).
63    pub async fn record_search(&self) -> Result<(), StorageError> {
64        rate_limits::increment_rate_limit(&self.pool, "search").await
65    }
66
67    /// Atomically check and claim a rate limit slot.
68    ///
69    /// Returns `Ok(true)` if permitted (counter incremented),
70    /// `Ok(false)` if the rate limit is reached.
71    /// Preferred over separate check + record for posting actions.
72    pub async fn acquire_posting_permit(&self, action_type: &str) -> Result<bool, StorageError> {
73        rate_limits::check_and_increment_rate_limit(&self.pool, action_type).await
74    }
75}
76
/// Reason an action was denied by the safety guard.
///
/// Returned as the inner `Err` of the guard's pre-flight checks. Every
/// payload is a `String` or an `i64`, so the type is `Eq` as well as
/// `PartialEq` and can be compared exactly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DenialReason {
    /// Action blocked by rate limiting.
    RateLimited {
        /// Which action type hit the limit.
        action_type: String,
        /// Current request count.
        current: i64,
        /// Maximum allowed requests.
        max: i64,
    },
    /// Already replied to this tweet.
    AlreadyReplied {
        /// The tweet ID that was already replied to.
        tweet_id: String,
    },
    /// Proposed reply is too similar to a recent reply.
    SimilarPhrasing,
    /// Reply contains a banned phrase.
    BannedPhrase {
        /// The banned phrase that was found.
        phrase: String,
    },
    /// Already reached the per-author daily reply limit.
    AuthorLimitReached,
    /// Replying to own tweet.
    SelfReply,
}
106
107impl std::fmt::Display for DenialReason {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        match self {
110            Self::RateLimited {
111                action_type,
112                current,
113                max,
114            } => write!(f, "Rate limited: {action_type} ({current}/{max})"),
115            Self::AlreadyReplied { tweet_id } => {
116                write!(f, "Already replied to tweet {tweet_id}")
117            }
118            Self::SimilarPhrasing => {
119                write!(f, "Reply phrasing too similar to recent replies")
120            }
121            Self::BannedPhrase { phrase } => {
122                write!(f, "Reply contains banned phrase: \"{phrase}\"")
123            }
124            Self::AuthorLimitReached => {
125                write!(f, "Already reached daily reply limit for this author")
126            }
127            Self::SelfReply => {
128                write!(f, "Cannot reply to own tweets")
129            }
130        }
131    }
132}
133
/// Check if any banned phrase appears in the text (case-insensitive).
///
/// Both `text` and each phrase are lowercased before the substring test.
/// Entries in `banned` that are empty or whitespace-only are ignored: an
/// empty pattern is contained in every string, so a single blank entry
/// would otherwise flag every text.
///
/// Returns the first matching banned phrase (in `banned` order, with its
/// original casing), or `None` if clean.
pub fn contains_banned_phrase(text: &str, banned: &[String]) -> Option<String> {
    let text_lower = text.to_lowercase();
    banned
        .iter()
        // Skip blank entries; they would match unconditionally.
        .filter(|phrase| !phrase.trim().is_empty())
        .find(|phrase| text_lower.contains(phrase.to_lowercase().as_str()))
        .cloned()
}
146
/// Report whether a tweet was written by the bot itself.
///
/// Returns `true` only when both IDs are non-empty and equal; an empty ID
/// on either side never counts as a match.
pub fn is_self_reply(tweet_author_id: &str, own_user_id: &str) -> bool {
    if tweet_author_id.is_empty() || own_user_id.is_empty() {
        return false;
    }
    tweet_author_id == own_user_id
}
151
152/// Combined safety guard for all automation loops.
153///
154/// Provides pre-flight checks that combine rate limiting with deduplication.
155/// All automation loops should call `SafetyGuard` methods before taking actions.
156pub struct SafetyGuard {
157    rate_limiter: RateLimiter,
158    dedup_checker: DedupChecker,
159    pool: DbPool,
160}
161
162impl SafetyGuard {
163    /// Create a new safety guard backed by the given database pool.
164    pub fn new(pool: DbPool) -> Self {
165        Self {
166            rate_limiter: RateLimiter::new(pool.clone()),
167            dedup_checker: DedupChecker::new(pool.clone()),
168            pool,
169        }
170    }
171
172    /// Check whether replying to a tweet is permitted.
173    ///
174    /// Checks rate limits, exact dedup, and optionally phrasing similarity.
175    /// Returns `Ok(Ok(()))` if allowed, `Ok(Err(DenialReason))` if blocked,
176    /// or `Err(StorageError)` on infrastructure failure.
177    pub async fn can_reply_to(
178        &self,
179        tweet_id: &str,
180        proposed_reply: Option<&str>,
181    ) -> Result<Result<(), DenialReason>, StorageError> {
182        // Check rate limit
183        if !self.rate_limiter.can_reply().await? {
184            let limits = rate_limits::get_all_rate_limits(&self.rate_limiter.pool).await?;
185            let reply_limit = limits.iter().find(|l| l.action_type == "reply");
186            let (current, max) = reply_limit
187                .map(|l| (l.request_count, l.max_requests))
188                .unwrap_or((0, 0));
189
190            tracing::debug!(
191                action = "reply",
192                current,
193                max,
194                "Action denied: rate limited"
195            );
196
197            return Ok(Err(DenialReason::RateLimited {
198                action_type: "reply".to_string(),
199                current,
200                max,
201            }));
202        }
203
204        // Check exact dedup
205        if self.dedup_checker.has_replied_to(tweet_id).await? {
206            tracing::debug!(tweet_id, "Action denied: already replied");
207            return Ok(Err(DenialReason::AlreadyReplied {
208                tweet_id: tweet_id.to_string(),
209            }));
210        }
211
212        // Check phrasing similarity
213        if let Some(reply_text) = proposed_reply {
214            if self
215                .dedup_checker
216                .is_phrasing_similar(reply_text, 20)
217                .await?
218            {
219                tracing::debug!("Action denied: similar phrasing");
220                return Ok(Err(DenialReason::SimilarPhrasing));
221            }
222        }
223
224        Ok(Ok(()))
225    }
226
227    /// Check whether posting an original tweet is permitted.
228    ///
229    /// Only checks rate limits (no dedup for original tweets).
230    pub async fn can_post_tweet(&self) -> Result<Result<(), DenialReason>, StorageError> {
231        if !self.rate_limiter.can_tweet().await? {
232            let limits = rate_limits::get_all_rate_limits(&self.rate_limiter.pool).await?;
233            let tweet_limit = limits.iter().find(|l| l.action_type == "tweet");
234            let (current, max) = tweet_limit
235                .map(|l| (l.request_count, l.max_requests))
236                .unwrap_or((0, 0));
237
238            tracing::debug!(
239                action = "tweet",
240                current,
241                max,
242                "Action denied: rate limited"
243            );
244
245            return Ok(Err(DenialReason::RateLimited {
246                action_type: "tweet".to_string(),
247                current,
248                max,
249            }));
250        }
251
252        Ok(Ok(()))
253    }
254
255    /// Check whether posting a thread is permitted.
256    ///
257    /// Only checks rate limits (no dedup for threads).
258    pub async fn can_post_thread(&self) -> Result<Result<(), DenialReason>, StorageError> {
259        if !self.rate_limiter.can_thread().await? {
260            let limits = rate_limits::get_all_rate_limits(&self.rate_limiter.pool).await?;
261            let thread_limit = limits.iter().find(|l| l.action_type == "thread");
262            let (current, max) = thread_limit
263                .map(|l| (l.request_count, l.max_requests))
264                .unwrap_or((0, 0));
265
266            tracing::debug!(
267                action = "thread",
268                current,
269                max,
270                "Action denied: rate limited"
271            );
272
273            return Ok(Err(DenialReason::RateLimited {
274                action_type: "thread".to_string(),
275                current,
276                max,
277            }));
278        }
279
280        Ok(Ok(()))
281    }
282
283    /// Check if replying to this author is permitted (per-author daily limit).
284    pub async fn check_author_limit(
285        &self,
286        author_id: &str,
287        max_per_day: u32,
288    ) -> Result<Result<(), DenialReason>, StorageError> {
289        let count =
290            author_interactions::get_author_reply_count_today(&self.pool, author_id).await?;
291        if count >= max_per_day as i64 {
292            tracing::debug!(
293                author_id,
294                count,
295                max = max_per_day,
296                "Action denied: author daily limit reached"
297            );
298            return Ok(Err(DenialReason::AuthorLimitReached));
299        }
300        Ok(Ok(()))
301    }
302
303    /// Check if a generated reply contains a banned phrase.
304    pub fn check_banned_phrases(reply_text: &str, banned: &[String]) -> Result<(), DenialReason> {
305        if let Some(phrase) = contains_banned_phrase(reply_text, banned) {
306            tracing::debug!(phrase = %phrase, "Action denied: banned phrase");
307            return Err(DenialReason::BannedPhrase { phrase });
308        }
309        Ok(())
310    }
311
312    /// Record a reply for an author interaction.
313    pub async fn record_author_interaction(
314        &self,
315        author_id: &str,
316        author_username: &str,
317    ) -> Result<(), StorageError> {
318        author_interactions::increment_author_interaction(&self.pool, author_id, author_username)
319            .await
320    }
321
322    /// Record a successful reply action.
323    pub async fn record_reply(&self) -> Result<(), StorageError> {
324        self.rate_limiter.record_reply().await
325    }
326
327    /// Record a successful tweet action.
328    pub async fn record_tweet(&self) -> Result<(), StorageError> {
329        self.rate_limiter.record_tweet().await
330    }
331
332    /// Record a successful thread action.
333    pub async fn record_thread(&self) -> Result<(), StorageError> {
334        self.rate_limiter.record_thread().await
335    }
336
337    /// Get a reference to the underlying rate limiter.
338    pub fn rate_limiter(&self) -> &RateLimiter {
339        &self.rate_limiter
340    }
341
342    /// Get a reference to the underlying dedup checker.
343    pub fn dedup_checker(&self) -> &DedupChecker {
344        &self.dedup_checker
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{IntervalsConfig, LimitsConfig};
    use crate::storage::init_test_db;
    use crate::storage::replies::{insert_reply, ReplySent};

    /// Limits small enough that a test can exhaust them in a few calls.
    fn test_limits() -> LimitsConfig {
        LimitsConfig {
            max_replies_per_day: 3,
            max_tweets_per_day: 2,
            max_threads_per_week: 1,
            min_action_delay_seconds: 30,
            max_action_delay_seconds: 120,
            max_replies_per_author_per_day: 1,
            banned_phrases: vec!["check out".to_string(), "you should try".to_string()],
            product_mention_ratio: 0.2,
        }
    }

    /// Interval configuration passed to `init_rate_limits` alongside the limits.
    fn test_intervals() -> IntervalsConfig {
        IntervalsConfig {
            mentions_check_seconds: 300,
            discovery_search_seconds: 600,
            content_post_window_seconds: 14400,
            thread_interval_seconds: 604800,
        }
    }

    /// Test database from `init_test_db` with rate limits initialised from
    /// the test configuration.
    async fn setup_pool() -> DbPool {
        let pool = init_test_db().await.expect("init db");
        rate_limits::init_rate_limits(&pool, &test_limits(), &test_intervals())
            .await
            .expect("init rate limits");
        pool
    }

    /// Initialised pool plus a `SafetyGuard` built on a clone of it.
    async fn setup_guard() -> (DbPool, SafetyGuard) {
        let pool = setup_pool().await;
        let guard = SafetyGuard::new(pool.clone());
        (pool, guard)
    }

    /// A sent reply to `target_id` with the given content, timestamped now.
    fn sample_reply(target_id: &str, content: &str) -> ReplySent {
        ReplySent {
            id: 0,
            target_tweet_id: target_id.to_string(),
            reply_tweet_id: Some("r_123".to_string()),
            reply_content: content.to_string(),
            llm_provider: None,
            llm_model: None,
            created_at: chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string(),
            status: "sent".to_string(),
            error_message: None,
        }
    }

    #[tokio::test]
    async fn rate_limiter_can_reply_and_record() {
        let limiter = RateLimiter::new(setup_pool().await);

        assert!(limiter.can_reply().await.expect("check"));
        // Reply limit is 3; three records exhaust it.
        for _ in 0..3 {
            limiter.record_reply().await.expect("record");
        }
        assert!(!limiter.can_reply().await.expect("check"));
    }

    #[tokio::test]
    async fn rate_limiter_acquire_posting_permit() {
        let limiter = RateLimiter::new(setup_pool().await);

        // Tweet limit is 2; the third permit must be refused.
        assert!(limiter.acquire_posting_permit("tweet").await.expect("1"));
        assert!(limiter.acquire_posting_permit("tweet").await.expect("2"));
        assert!(!limiter.acquire_posting_permit("tweet").await.expect("3"));
    }

    #[tokio::test]
    async fn safety_guard_allows_new_reply() {
        let (_pool, guard) = setup_guard().await;

        let result = guard.can_reply_to("tweet_1", None).await.expect("check");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn safety_guard_blocks_already_replied() {
        let (pool, guard) = setup_guard().await;

        let reply = sample_reply("tweet_1", "Some reply content");
        insert_reply(&pool, &reply).await.expect("insert");

        let result = guard.can_reply_to("tweet_1", None).await.expect("check");
        assert_eq!(
            result,
            Err(DenialReason::AlreadyReplied {
                tweet_id: "tweet_1".to_string()
            })
        );
    }

    #[tokio::test]
    async fn safety_guard_blocks_rate_limited() {
        let (_pool, guard) = setup_guard().await;

        // Exhaust the reply limit (max = 3)
        for _ in 0..3 {
            guard.record_reply().await.expect("record");
        }

        let result = guard.can_reply_to("tweet_new", None).await.expect("check");
        assert_eq!(
            result,
            Err(DenialReason::RateLimited {
                action_type: "reply".to_string(),
                current: 3,
                max: 3,
            })
        );
    }

    #[tokio::test]
    async fn safety_guard_blocks_similar_phrasing() {
        let (pool, guard) = setup_guard().await;

        let reply = sample_reply(
            "tweet_1",
            "This is a great tool for developers and engineers to use daily",
        );
        insert_reply(&pool, &reply).await.expect("insert");

        let result = guard
            .can_reply_to(
                "tweet_2",
                Some("This is a great tool for developers and engineers to use often"),
            )
            .await
            .expect("check");

        assert_eq!(result, Err(DenialReason::SimilarPhrasing));
    }

    #[tokio::test]
    async fn safety_guard_allows_different_phrasing() {
        let (pool, guard) = setup_guard().await;

        let reply = sample_reply(
            "tweet_1",
            "This is a great tool for developers and engineers to use daily",
        );
        insert_reply(&pool, &reply).await.expect("insert");

        let result = guard
            .can_reply_to(
                "tweet_2",
                Some("I love cooking pasta with fresh basil and tomatoes every day"),
            )
            .await
            .expect("check");

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn safety_guard_can_post_tweet_allowed() {
        let (_pool, guard) = setup_guard().await;

        let result = guard.can_post_tweet().await.expect("check");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn safety_guard_can_post_tweet_blocked() {
        let (_pool, guard) = setup_guard().await;

        // Exhaust tweet limit (max = 2)
        guard.record_tweet().await.expect("record");
        guard.record_tweet().await.expect("record");

        let result = guard.can_post_tweet().await.expect("check");
        // Pin the reason and counters, not just "some error".
        assert_eq!(
            result,
            Err(DenialReason::RateLimited {
                action_type: "tweet".to_string(),
                current: 2,
                max: 2,
            })
        );
    }

    #[tokio::test]
    async fn safety_guard_can_post_thread_allowed() {
        let (_pool, guard) = setup_guard().await;

        let result = guard.can_post_thread().await.expect("check");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn safety_guard_can_post_thread_blocked() {
        let (_pool, guard) = setup_guard().await;

        // Exhaust thread limit (max = 1)
        guard.record_thread().await.expect("record");

        let result = guard.can_post_thread().await.expect("check");
        // Pin the reason and counters, not just "some error".
        assert_eq!(
            result,
            Err(DenialReason::RateLimited {
                action_type: "thread".to_string(),
                current: 1,
                max: 1,
            })
        );
    }

    // Pure formatting check: nothing is awaited, so no async runtime is needed.
    #[test]
    fn denial_reason_display() {
        let rate = DenialReason::RateLimited {
            action_type: "reply".to_string(),
            current: 20,
            max: 20,
        };
        assert_eq!(rate.to_string(), "Rate limited: reply (20/20)");

        let replied = DenialReason::AlreadyReplied {
            tweet_id: "abc123".to_string(),
        };
        assert_eq!(replied.to_string(), "Already replied to tweet abc123");

        let similar = DenialReason::SimilarPhrasing;
        assert_eq!(
            similar.to_string(),
            "Reply phrasing too similar to recent replies"
        );

        let banned = DenialReason::BannedPhrase {
            phrase: "check out".to_string(),
        };
        assert_eq!(
            banned.to_string(),
            "Reply contains banned phrase: \"check out\""
        );

        let author = DenialReason::AuthorLimitReached;
        assert_eq!(
            author.to_string(),
            "Already reached daily reply limit for this author"
        );

        let self_reply = DenialReason::SelfReply;
        assert_eq!(self_reply.to_string(), "Cannot reply to own tweets");
    }

    #[test]
    fn contains_banned_phrase_detects_match() {
        let banned = vec!["check out".to_string(), "link in bio".to_string()];
        assert_eq!(
            contains_banned_phrase("You should check out this tool!", &banned),
            Some("check out".to_string())
        );
    }

    #[test]
    fn contains_banned_phrase_case_insensitive() {
        let banned = vec!["Check Out".to_string()];
        assert_eq!(
            contains_banned_phrase("check out this thing", &banned),
            Some("Check Out".to_string())
        );
    }

    #[test]
    fn contains_banned_phrase_no_match() {
        let banned = vec!["check out".to_string()];
        assert_eq!(
            contains_banned_phrase("This is a helpful reply", &banned),
            None
        );
    }

    #[test]
    fn is_self_reply_detects_self() {
        assert!(is_self_reply("user_123", "user_123"));
    }

    #[test]
    fn is_self_reply_different_users() {
        assert!(!is_self_reply("user_123", "user_456"));
    }

    #[test]
    fn is_self_reply_empty_ids() {
        assert!(!is_self_reply("", "user_123"));
        assert!(!is_self_reply("user_123", ""));
        assert!(!is_self_reply("", ""));
    }

    #[tokio::test]
    async fn safety_guard_check_author_limit_allows_first() {
        let (_pool, guard) = setup_guard().await;
        let result = guard
            .check_author_limit("author_1", 1)
            .await
            .expect("check");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn safety_guard_check_author_limit_blocks_over_limit() {
        let (_pool, guard) = setup_guard().await;
        guard
            .record_author_interaction("author_1", "alice")
            .await
            .expect("record");

        let result = guard
            .check_author_limit("author_1", 1)
            .await
            .expect("check");
        assert_eq!(result, Err(DenialReason::AuthorLimitReached));
    }

    #[test]
    fn check_banned_phrases_blocks_banned() {
        let banned = vec!["check out".to_string(), "I recommend".to_string()];
        let result = SafetyGuard::check_banned_phrases("You should check out this tool!", &banned);
        assert_eq!(
            result,
            Err(DenialReason::BannedPhrase {
                phrase: "check out".to_string()
            })
        );
    }

    #[test]
    fn check_banned_phrases_allows_clean() {
        let banned = vec!["check out".to_string()];
        let result = SafetyGuard::check_banned_phrases("Great insight on testing!", &banned);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn safety_guard_exposes_rate_limiter_and_dedup() {
        let (_pool, guard) = setup_guard().await;

        // Verify accessors work without panicking
        assert!(guard.rate_limiter().can_search().await.expect("search"));
        let phrases = guard
            .dedup_checker()
            .get_recent_reply_phrases(5)
            .await
            .expect("phrases");
        assert!(phrases.is_empty());
    }
}