1use super::loop_helpers::{ContentSafety, ContentStorage, TopicScorer, TweetGenerator};
8use super::schedule::{apply_slot_jitter, schedule_gate, ActiveSchedule};
9use super::scheduler::LoopScheduler;
10use rand::seq::SliceRandom;
11use rand::SeedableRng;
12use std::sync::Arc;
13use tokio_util::sync::CancellationToken;
14
/// Fraction of topic picks (only when a `TopicScorer` is configured) that try
/// the scorer's top-ranked topics; a uniform roll at or above this value
/// explores a random topic instead (epsilon-greedy with epsilon = 1 - ratio).
const EXPLOIT_RATIO: f64 = 0.80;
17
/// Loop that generates tweets on configured topics and posts them (or, in dry
/// run, only logs them), either on a fixed interval or at scheduled slots.
pub struct ContentLoop {
    /// Produces tweet text for a given topic.
    generator: Arc<dyn TweetGenerator>,
    /// Gate consulted before every post; a `false` answer yields `RateLimited`.
    safety: Arc<dyn ContentSafety>,
    /// Persists posted tweets, reports previous tweet times, and records actions.
    storage: Arc<dyn ContentStorage>,
    /// Optional ranking used for epsilon-greedy topic selection.
    topic_scorer: Option<Arc<dyn TopicScorer>>,
    /// Candidate topics to write about.
    topics: Vec<String>,
    /// Minimum seconds since the last tweet before interval mode posts again.
    post_window_secs: u64,
    /// When `true`, tweets are generated and logged but never posted.
    dry_run: bool,
}
28
/// Outcome of a single content-loop iteration (or a one-off `run_once`).
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers and tests can copy and compare
/// outcomes directly instead of only pattern-matching them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContentResult {
    /// A tweet was posted (or, in dry-run mode, would have been).
    Posted { topic: String, content: String },
    /// Skipped: the last tweet is younger than the configured posting window.
    TooSoon { elapsed_secs: u64, window_secs: u64 },
    /// Skipped: the safety gate refused another tweet.
    RateLimited,
    /// Nothing to post because no topics are configured.
    NoTopics,
    /// Generation or posting failed; `error` describes why.
    Failed { error: String },
}
43
44impl ContentLoop {
45 pub fn new(
47 generator: Arc<dyn TweetGenerator>,
48 safety: Arc<dyn ContentSafety>,
49 storage: Arc<dyn ContentStorage>,
50 topics: Vec<String>,
51 post_window_secs: u64,
52 dry_run: bool,
53 ) -> Self {
54 Self {
55 generator,
56 safety,
57 storage,
58 topic_scorer: None,
59 topics,
60 post_window_secs,
61 dry_run,
62 }
63 }
64
65 pub fn with_topic_scorer(mut self, scorer: Arc<dyn TopicScorer>) -> Self {
70 self.topic_scorer = Some(scorer);
71 self
72 }
73
74 pub async fn run(
76 &self,
77 cancel: CancellationToken,
78 scheduler: LoopScheduler,
79 schedule: Option<Arc<ActiveSchedule>>,
80 ) {
81 let slot_mode = schedule.as_ref().is_some_and(|s| s.has_preferred_times());
82
83 tracing::info!(
84 dry_run = self.dry_run,
85 topics = self.topics.len(),
86 window_secs = self.post_window_secs,
87 slot_mode = slot_mode,
88 "Content loop started"
89 );
90
91 if self.topics.is_empty() {
92 tracing::warn!("No topics configured, content loop has nothing to post");
93 cancel.cancelled().await;
94 return;
95 }
96
97 let min_recent = 3usize;
98 let max_recent = (self.topics.len() / 2)
99 .max(min_recent)
100 .min(self.topics.len());
101 let mut recent_topics: Vec<String> = Vec::with_capacity(max_recent);
102 let mut rng = rand::rngs::StdRng::from_entropy();
103
104 loop {
105 if cancel.is_cancelled() {
106 break;
107 }
108
109 if !schedule_gate(&schedule, &cancel).await {
110 break;
111 }
112
113 if slot_mode {
114 let sched = schedule.as_ref().expect("slot_mode requires schedule");
116
117 let today_posts = match self.storage.todays_tweet_times().await {
119 Ok(times) => times,
120 Err(e) => {
121 tracing::warn!(error = %e, "Failed to query today's tweet times");
122 Vec::new()
123 }
124 };
125
126 match sched.next_unused_slot(&today_posts) {
127 Some((wait, slot)) => {
128 let jittered_wait = apply_slot_jitter(wait);
129 tracing::info!(
130 slot = %slot.format(),
131 wait_secs = jittered_wait.as_secs(),
132 "Slot mode: sleeping until next posting slot"
133 );
134
135 tokio::select! {
136 _ = cancel.cancelled() => break,
137 _ = tokio::time::sleep(jittered_wait) => {},
138 }
139
140 if cancel.is_cancelled() {
141 break;
142 }
143
144 let result = self
146 .run_slot_iteration(&mut recent_topics, max_recent, &mut rng)
147 .await;
148 self.log_content_result(&result);
149 }
150 None => {
151 tracing::info!(
153 "Slot mode: all slots used today, sleeping until next active period"
154 );
155 if let Some(sched) = &schedule {
156 let wait = sched.time_until_active();
157 if wait.is_zero() {
158 tokio::select! {
160 _ = cancel.cancelled() => break,
161 _ = tokio::time::sleep(std::time::Duration::from_secs(3600)) => {},
162 }
163 } else {
164 tokio::select! {
165 _ = cancel.cancelled() => break,
166 _ = tokio::time::sleep(wait) => {},
167 }
168 }
169 } else {
170 tokio::select! {
171 _ = cancel.cancelled() => break,
172 _ = tokio::time::sleep(std::time::Duration::from_secs(3600)) => {},
173 }
174 }
175 }
176 }
177 } else {
178 let result = self
180 .run_iteration(&mut recent_topics, max_recent, &mut rng)
181 .await;
182 self.log_content_result(&result);
183
184 tokio::select! {
185 _ = cancel.cancelled() => break,
186 _ = scheduler.tick() => {},
187 }
188 }
189 }
190
191 tracing::info!("Content loop stopped");
192 }
193
194 fn log_content_result(&self, result: &ContentResult) {
196 match result {
197 ContentResult::Posted { topic, content } => {
198 tracing::info!(
199 topic = %topic,
200 chars = content.len(),
201 dry_run = self.dry_run,
202 "Content iteration: tweet posted"
203 );
204 }
205 ContentResult::TooSoon {
206 elapsed_secs,
207 window_secs,
208 } => {
209 tracing::debug!(
210 elapsed = elapsed_secs,
211 window = window_secs,
212 "Content iteration: too soon since last tweet"
213 );
214 }
215 ContentResult::RateLimited => {
216 tracing::info!("Content iteration: daily tweet limit reached");
217 }
218 ContentResult::NoTopics => {
219 tracing::warn!("Content iteration: no topics available");
220 }
221 ContentResult::Failed { error } => {
222 tracing::warn!(error = %error, "Content iteration: failed");
223 }
224 }
225 }
226
227 async fn run_slot_iteration(
229 &self,
230 recent_topics: &mut Vec<String>,
231 max_recent: usize,
232 rng: &mut impl rand::Rng,
233 ) -> ContentResult {
234 if !self.safety.can_post_tweet().await {
236 return ContentResult::RateLimited;
237 }
238
239 let topic = self.pick_topic_epsilon_greedy(recent_topics, rng).await;
241
242 let result = self.generate_and_post(&topic).await;
243
244 if matches!(result, ContentResult::Posted { .. }) {
246 if recent_topics.len() >= max_recent {
247 recent_topics.remove(0);
248 }
249 recent_topics.push(topic);
250 }
251
252 result
253 }
254
255 pub async fn run_once(&self, topic: Option<&str>) -> ContentResult {
260 let chosen_topic = match topic {
261 Some(t) => t.to_string(),
262 None => {
263 if self.topics.is_empty() {
264 return ContentResult::NoTopics;
265 }
266 let mut rng = rand::thread_rng();
267 self.topics
268 .choose(&mut rng)
269 .expect("topics is non-empty")
270 .clone()
271 }
272 };
273
274 if !self.safety.can_post_tweet().await {
276 return ContentResult::RateLimited;
277 }
278
279 self.generate_and_post(&chosen_topic).await
280 }
281
282 async fn run_iteration(
284 &self,
285 recent_topics: &mut Vec<String>,
286 max_recent: usize,
287 rng: &mut impl rand::Rng,
288 ) -> ContentResult {
289 match self.storage.last_tweet_time().await {
291 Ok(Some(last_time)) => {
292 let elapsed = chrono::Utc::now()
293 .signed_duration_since(last_time)
294 .num_seconds()
295 .max(0) as u64;
296
297 if elapsed < self.post_window_secs {
298 return ContentResult::TooSoon {
299 elapsed_secs: elapsed,
300 window_secs: self.post_window_secs,
301 };
302 }
303 }
304 Ok(None) => {
305 }
307 Err(e) => {
308 tracing::warn!(error = %e, "Failed to query last tweet time, proceeding anyway");
309 }
310 }
311
312 if !self.safety.can_post_tweet().await {
314 return ContentResult::RateLimited;
315 }
316
317 let topic = self.pick_topic_epsilon_greedy(recent_topics, rng).await;
319
320 let result = self.generate_and_post(&topic).await;
321
322 if matches!(result, ContentResult::Posted { .. }) {
324 if recent_topics.len() >= max_recent {
325 recent_topics.remove(0);
326 }
327 recent_topics.push(topic);
328 }
329
330 result
331 }
332
333 async fn pick_topic_epsilon_greedy(
342 &self,
343 recent_topics: &mut Vec<String>,
344 rng: &mut impl rand::Rng,
345 ) -> String {
346 if let Some(scorer) = &self.topic_scorer {
347 let roll: f64 = rng.gen();
348 if roll < EXPLOIT_RATIO {
349 if let Ok(top_topics) = scorer.get_top_topics(10).await {
351 let candidates: Vec<&String> = top_topics
353 .iter()
354 .filter(|t| self.topics.contains(t) && !recent_topics.contains(t))
355 .collect();
356
357 if !candidates.is_empty() {
358 let topic = candidates[0].clone();
359 tracing::debug!(topic = %topic, "Epsilon-greedy: exploiting top topic");
360 return topic;
361 }
362 }
363 tracing::debug!("Epsilon-greedy: no top topics available, falling back to random");
365 } else {
366 tracing::debug!("Epsilon-greedy: exploring random topic");
367 }
368 }
369
370 pick_topic(&self.topics, recent_topics, rng)
371 }
372
373 async fn generate_and_post(&self, topic: &str) -> ContentResult {
375 tracing::info!(topic = %topic, "Generating tweet on topic");
376
377 let content = match self.generator.generate_tweet(topic).await {
379 Ok(text) => text,
380 Err(e) => {
381 return ContentResult::Failed {
382 error: format!("Generation failed: {e}"),
383 }
384 }
385 };
386
387 let content = if content.len() > 280 {
389 tracing::debug!(
391 chars = content.len(),
392 "Generated tweet too long, retrying with shorter instruction"
393 );
394
395 let shorter_topic = format!("{topic} (IMPORTANT: keep under 280 characters)");
396 match self.generator.generate_tweet(&shorter_topic).await {
397 Ok(text) if text.len() <= 280 => text,
398 Ok(text) => {
399 tracing::warn!(
401 chars = text.len(),
402 "Retry still too long, truncating at word boundary"
403 );
404 truncate_at_word_boundary(&text, 280)
405 }
406 Err(e) => {
407 tracing::warn!(error = %e, "Retry generation failed, truncating original");
409 truncate_at_word_boundary(&content, 280)
410 }
411 }
412 } else {
413 content
414 };
415
416 if self.dry_run {
417 tracing::info!(
418 "DRY RUN: Would post tweet on topic '{}': \"{}\" ({} chars)",
419 topic,
420 content,
421 content.len()
422 );
423
424 let _ = self
425 .storage
426 .log_action(
427 "tweet",
428 "dry_run",
429 &format!("Topic '{}': {}", topic, truncate_display(&content, 80)),
430 )
431 .await;
432 } else {
433 if let Err(e) = self.storage.post_tweet(topic, &content).await {
434 tracing::error!(error = %e, "Failed to post tweet");
435 let _ = self
436 .storage
437 .log_action("tweet", "failure", &format!("Post failed: {e}"))
438 .await;
439 return ContentResult::Failed {
440 error: e.to_string(),
441 };
442 }
443
444 let _ = self
445 .storage
446 .log_action(
447 "tweet",
448 "success",
449 &format!("Topic '{}': {}", topic, truncate_display(&content, 80)),
450 )
451 .await;
452 }
453
454 ContentResult::Posted {
455 topic: topic.to_string(),
456 content,
457 }
458 }
459}
460
461fn pick_topic(topics: &[String], recent: &mut Vec<String>, rng: &mut impl rand::Rng) -> String {
464 let available: Vec<&String> = topics.iter().filter(|t| !recent.contains(t)).collect();
465
466 if available.is_empty() {
467 recent.clear();
469 topics.choose(rng).expect("topics is non-empty").clone()
470 } else {
471 available
472 .choose(rng)
473 .expect("available is non-empty")
474 .to_string()
475 }
476}
477
/// Shortens `s` to at most `max_len` bytes, preferring to cut at the last space
/// before the limit and appending `"..."` when the text was shortened.
///
/// Cut points are always moved back to a UTF-8 character boundary, so
/// multi-byte text (emoji, accents) never causes a slicing panic. If `max_len`
/// is too small to hold the ellipsis, the text is hard-cut without one. A space
/// at index 0 is not used as a cut point, because that would discard the whole
/// text.
fn truncate_at_word_boundary(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    /// Largest char-boundary index `<= idx`; requires `idx <= s.len()`.
    fn floor_char_boundary(s: &str, mut idx: usize) -> usize {
        while !s.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    }

    if s.len() <= max_len {
        return s.to_string();
    }

    // Here s.len() > max_len, so every index used below is within bounds.
    if max_len < ELLIPSIS.len() {
        return s[..floor_char_boundary(s, max_len)].to_string();
    }

    let cutoff = floor_char_boundary(s, max_len - ELLIPSIS.len());
    let end = match s[..cutoff].rfind(' ') {
        Some(pos) if pos > 0 => pos,
        _ => cutoff,
    };
    format!("{}{}", &s[..end], ELLIPSIS)
}
491
/// Shortens `s` for log/display purposes: text longer than `max_len` bytes is
/// cut to at most `max_len` bytes and suffixed with `"..."`.
///
/// The cut point is moved back to a UTF-8 character boundary, so non-ASCII
/// tweets no longer panic here. Previously this was reached right after a
/// successful post.
fn truncate_display(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // s.len() > max_len, so `end` starts in bounds; index 0 is always a boundary.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::ContentLoopError;
    use std::sync::Mutex;

    /// Generator that always returns the same canned response.
    struct MockGenerator {
        response: String,
    }

    #[async_trait::async_trait]
    impl TweetGenerator for MockGenerator {
        async fn generate_tweet(&self, _topic: &str) -> Result<String, ContentLoopError> {
            Ok(self.response.clone())
        }
    }

    /// Generator whose first call returns `first_response` and every later call
    /// returns `retry_response`; used to exercise the "too long, retry" path.
    struct OverlongGenerator {
        first_response: String,
        retry_response: String,
        call_count: Mutex<usize>,
    }

    #[async_trait::async_trait]
    impl TweetGenerator for OverlongGenerator {
        async fn generate_tweet(&self, _topic: &str) -> Result<String, ContentLoopError> {
            let mut count = self.call_count.lock().expect("lock");
            *count += 1;
            if *count == 1 {
                Ok(self.first_response.clone())
            } else {
                Ok(self.retry_response.clone())
            }
        }
    }

    /// Generator that always fails with an LLM error.
    struct FailingGenerator;

    #[async_trait::async_trait]
    impl TweetGenerator for FailingGenerator {
        async fn generate_tweet(&self, _topic: &str) -> Result<String, ContentLoopError> {
            Err(ContentLoopError::LlmFailure(
                "model unavailable".to_string(),
            ))
        }
    }

    /// Safety gate with fixed answers.
    struct MockSafety {
        can_tweet: bool,
        can_thread: bool,
    }

    #[async_trait::async_trait]
    impl ContentSafety for MockSafety {
        async fn can_post_tweet(&self) -> bool {
            self.can_tweet
        }
        async fn can_post_thread(&self) -> bool {
            self.can_thread
        }
    }

    /// In-memory storage that records posted tweets and logged actions so tests
    /// can count them. Thread operations are no-op stubs.
    struct MockStorage {
        last_tweet: Mutex<Option<chrono::DateTime<chrono::Utc>>>,
        posted_tweets: Mutex<Vec<(String, String)>>,
        actions: Mutex<Vec<(String, String, String)>>,
    }

    impl MockStorage {
        /// Creates storage whose `last_tweet_time` reports `last_tweet`.
        fn new(last_tweet: Option<chrono::DateTime<chrono::Utc>>) -> Self {
            Self {
                last_tweet: Mutex::new(last_tweet),
                posted_tweets: Mutex::new(Vec::new()),
                actions: Mutex::new(Vec::new()),
            }
        }

        /// Number of `post_tweet` calls received.
        fn posted_count(&self) -> usize {
            self.posted_tweets.lock().expect("lock").len()
        }

        /// Number of `log_action` calls received.
        fn action_count(&self) -> usize {
            self.actions.lock().expect("lock").len()
        }
    }

    #[async_trait::async_trait]
    impl ContentStorage for MockStorage {
        async fn last_tweet_time(
            &self,
        ) -> Result<Option<chrono::DateTime<chrono::Utc>>, ContentLoopError> {
            Ok(*self.last_tweet.lock().expect("lock"))
        }

        async fn last_thread_time(
            &self,
        ) -> Result<Option<chrono::DateTime<chrono::Utc>>, ContentLoopError> {
            Ok(None)
        }

        async fn todays_tweet_times(
            &self,
        ) -> Result<Vec<chrono::DateTime<chrono::Utc>>, ContentLoopError> {
            Ok(Vec::new())
        }

        async fn post_tweet(&self, topic: &str, content: &str) -> Result<(), ContentLoopError> {
            self.posted_tweets
                .lock()
                .expect("lock")
                .push((topic.to_string(), content.to_string()));
            Ok(())
        }

        async fn create_thread(
            &self,
            _topic: &str,
            _tweet_count: usize,
        ) -> Result<String, ContentLoopError> {
            Ok("thread-1".to_string())
        }

        async fn update_thread_status(
            &self,
            _thread_id: &str,
            _status: &str,
            _tweet_count: usize,
            _root_tweet_id: Option<&str>,
        ) -> Result<(), ContentLoopError> {
            Ok(())
        }

        async fn store_thread_tweet(
            &self,
            _thread_id: &str,
            _position: usize,
            _tweet_id: &str,
            _content: &str,
        ) -> Result<(), ContentLoopError> {
            Ok(())
        }

        async fn log_action(
            &self,
            action_type: &str,
            status: &str,
            message: &str,
        ) -> Result<(), ContentLoopError> {
            self.actions.lock().expect("lock").push((
                action_type.to_string(),
                status.to_string(),
                message.to_string(),
            ));
            Ok(())
        }
    }

    /// Four-topic fixture shared by most tests.
    fn make_topics() -> Vec<String> {
        vec![
            "Rust".to_string(),
            "CLI tools".to_string(),
            "Open source".to_string(),
            "Developer productivity".to_string(),
        ]
    }

    #[tokio::test]
    async fn run_once_posts_tweet() {
        let storage = Arc::new(MockStorage::new(None));
        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "Great tweet about Rust!".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage.clone(),
            make_topics(),
            14400,
            false,
        );

        let result = content.run_once(Some("Rust")).await;
        assert!(matches!(result, ContentResult::Posted { .. }));
        assert_eq!(storage.posted_count(), 1);
    }

    #[tokio::test]
    async fn run_once_dry_run_does_not_post() {
        let storage = Arc::new(MockStorage::new(None));
        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "Great tweet about Rust!".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage.clone(),
            make_topics(),
            14400,
            true,
        );

        let result = content.run_once(Some("Rust")).await;
        assert!(matches!(result, ContentResult::Posted { .. }));
        // Dry run never calls post_tweet...
        assert_eq!(storage.posted_count(), 0);
        // ...but still records exactly one "dry_run" action.
        assert_eq!(storage.action_count(), 1);
    }

    #[tokio::test]
    async fn run_once_rate_limited() {
        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: false,
                can_thread: true,
            }),
            Arc::new(MockStorage::new(None)),
            make_topics(),
            14400,
            false,
        );

        let result = content.run_once(None).await;
        assert!(matches!(result, ContentResult::RateLimited));
    }

    #[tokio::test]
    async fn run_once_no_topics_returns_no_topics() {
        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            Arc::new(MockStorage::new(None)),
            Vec::new(),
            14400,
            false,
        );

        let result = content.run_once(None).await;
        assert!(matches!(result, ContentResult::NoTopics));
    }

    #[tokio::test]
    async fn run_once_generation_failure() {
        let content = ContentLoop::new(
            Arc::new(FailingGenerator),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            Arc::new(MockStorage::new(None)),
            make_topics(),
            14400,
            false,
        );

        let result = content.run_once(Some("Rust")).await;
        assert!(matches!(result, ContentResult::Failed { .. }));
    }

    #[tokio::test]
    async fn run_iteration_skips_when_too_soon() {
        let now = chrono::Utc::now();
        // 1 hour ago, inside the 14400 s (4 h) window below.
        let last_tweet = now - chrono::Duration::hours(1);
        let storage = Arc::new(MockStorage::new(Some(last_tweet)));

        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage,
            make_topics(),
            14400,
            false,
        );

        let mut recent = Vec::new();
        let mut rng = rand::thread_rng();
        let result = content.run_iteration(&mut recent, 3, &mut rng).await;
        assert!(matches!(result, ContentResult::TooSoon { .. }));
    }

    #[tokio::test]
    async fn run_iteration_posts_when_window_elapsed() {
        let now = chrono::Utc::now();
        // 5 hours ago, past the 14400 s (4 h) window below.
        let last_tweet = now - chrono::Duration::hours(5);
        let storage = Arc::new(MockStorage::new(Some(last_tweet)));

        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "Fresh tweet!".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage.clone(),
            make_topics(),
            14400,
            false,
        );

        let mut recent = Vec::new();
        let mut rng = rand::thread_rng();
        let result = content.run_iteration(&mut recent, 3, &mut rng).await;
        assert!(matches!(result, ContentResult::Posted { .. }));
        assert_eq!(storage.posted_count(), 1);
        // The posted topic is remembered as recent.
        assert_eq!(recent.len(), 1);
    }

    #[tokio::test]
    async fn overlong_tweet_gets_truncated() {
        // 400 bytes: over the 280 limit on both the first attempt and the retry,
        // which forces the word-boundary truncation path.
        let long_text = "a ".repeat(200);
        let content = ContentLoop::new(
            Arc::new(OverlongGenerator {
                first_response: long_text.clone(),
                retry_response: long_text,
                call_count: Mutex::new(0),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            Arc::new(MockStorage::new(None)),
            make_topics(),
            14400,
            true,
        );

        let result = content.run_once(Some("Rust")).await;
        if let ContentResult::Posted { content, .. } = result {
            assert!(content.len() <= 280);
        } else {
            panic!("Expected Posted result");
        }
    }

    #[test]
    fn truncate_at_word_boundary_short() {
        let result = truncate_at_word_boundary("Hello world", 280);
        assert_eq!(result, "Hello world");
    }

    #[test]
    fn truncate_at_word_boundary_long() {
        let text = "The quick brown fox jumps over the lazy dog and more words here";
        let result = truncate_at_word_boundary(text, 30);
        assert!(result.len() <= 30);
        assert!(result.ends_with("..."));
    }

    #[test]
    fn pick_topic_avoids_recent() {
        let topics = make_topics();
        let mut recent = vec!["Rust".to_string(), "CLI tools".to_string()];
        let mut rng = rand::thread_rng();

        // Repeat to make an accidental pass from randomness unlikely.
        for _ in 0..20 {
            let topic = pick_topic(&topics, &mut recent, &mut rng);
            assert_ne!(topic, "Rust");
            assert_ne!(topic, "CLI tools");
        }
    }

    #[test]
    fn pick_topic_clears_when_all_recent() {
        let topics = make_topics();
        let mut recent = topics.clone();
        let mut rng = rand::thread_rng();

        let topic = pick_topic(&topics, &mut recent, &mut rng);
        assert!(topics.contains(&topic));
        // Exhausting every topic resets the recent history.
        assert!(recent.is_empty());
    }

    #[test]
    fn truncate_display_short() {
        assert_eq!(truncate_display("hello", 10), "hello");
    }

    #[test]
    fn truncate_display_long() {
        let result = truncate_display("hello world this is long", 10);
        assert_eq!(result, "hello worl...");
    }

    /// Scorer that always returns the given ranking.
    struct MockTopicScorer {
        top_topics: Vec<String>,
    }

    #[async_trait::async_trait]
    impl TopicScorer for MockTopicScorer {
        async fn get_top_topics(&self, _limit: u32) -> Result<Vec<String>, ContentLoopError> {
            Ok(self.top_topics.clone())
        }
    }

    /// Scorer that always errors, to exercise the random fallback.
    struct FailingTopicScorer;

    #[async_trait::async_trait]
    impl TopicScorer for FailingTopicScorer {
        async fn get_top_topics(&self, _limit: u32) -> Result<Vec<String>, ContentLoopError> {
            Err(ContentLoopError::StorageError("db error".to_string()))
        }
    }

    #[tokio::test]
    async fn epsilon_greedy_exploits_top_topic() {
        let storage = Arc::new(MockStorage::new(None));
        let scorer = Arc::new(MockTopicScorer {
            top_topics: vec!["Rust".to_string()],
        });

        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage,
            make_topics(),
            14400,
            false,
        )
        .with_topic_scorer(scorer);

        let mut recent = Vec::new();
        // Low first roll -> below EXPLOIT_RATIO -> exploit branch.
        let mut rng = FirstCallRng::low_roll();

        let topic = content
            .pick_topic_epsilon_greedy(&mut recent, &mut rng)
            .await;
        assert_eq!(topic, "Rust");
    }

    #[tokio::test]
    async fn epsilon_greedy_explores_when_roll_high() {
        let storage = Arc::new(MockStorage::new(None));
        let scorer = Arc::new(MockTopicScorer {
            top_topics: vec!["Rust".to_string()],
        });

        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage,
            make_topics(),
            14400,
            false,
        )
        .with_topic_scorer(scorer);

        let mut recent = Vec::new();
        // High first roll -> at/above EXPLOIT_RATIO -> explore branch.
        let mut rng = FirstCallRng::high_roll();

        let topic = content
            .pick_topic_epsilon_greedy(&mut recent, &mut rng)
            .await;
        assert!(make_topics().contains(&topic));
    }

    #[tokio::test]
    async fn epsilon_greedy_falls_back_on_scorer_error() {
        let storage = Arc::new(MockStorage::new(None));
        let scorer = Arc::new(FailingTopicScorer);

        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage,
            make_topics(),
            14400,
            false,
        )
        .with_topic_scorer(scorer);

        let mut recent = Vec::new();
        // Force the exploit branch so the failing scorer is actually consulted.
        let mut rng = FirstCallRng::low_roll();

        let topic = content
            .pick_topic_epsilon_greedy(&mut recent, &mut rng)
            .await;
        assert!(make_topics().contains(&topic));
    }

    #[tokio::test]
    async fn epsilon_greedy_without_scorer_picks_random() {
        let storage = Arc::new(MockStorage::new(None));

        let content = ContentLoop::new(
            Arc::new(MockGenerator {
                response: "tweet".to_string(),
            }),
            Arc::new(MockSafety {
                can_tweet: true,
                can_thread: true,
            }),
            storage,
            make_topics(),
            14400,
            false,
        );

        let mut recent = Vec::new();
        let mut rng = rand::thread_rng();

        let topic = content
            .pick_topic_epsilon_greedy(&mut recent, &mut rng)
            .await;
        assert!(make_topics().contains(&topic));
    }

    /// Test RNG whose first `next_u64` call returns a fixed value and then
    /// delegates everything to `ThreadRng`.
    ///
    /// NOTE(review): this assumes `rng.gen::<f64>()` draws its bits from a single
    /// `next_u64` call (true for rand 0.8's `Standard` f64 sampling). Confirm this
    /// when upgrading rand, otherwise the low/high roll tests may stop steering
    /// the epsilon-greedy branch.
    struct FirstCallRng {
        first_u64: Option<u64>,
        inner: rand::rngs::ThreadRng,
    }

    impl FirstCallRng {
        /// First `next_u64` is 0, intended to map to a roll near 0.0 (exploit).
        fn low_roll() -> Self {
            Self {
                first_u64: Some(0),
                inner: rand::thread_rng(),
            }
        }

        /// First `next_u64` is `u64::MAX`, intended to map to a roll near 1.0 (explore).
        fn high_roll() -> Self {
            Self {
                first_u64: Some(u64::MAX),
                inner: rand::thread_rng(),
            }
        }
    }

    impl rand::RngCore for FirstCallRng {
        fn next_u32(&mut self) -> u32 {
            self.inner.next_u32()
        }
        fn next_u64(&mut self) -> u64 {
            // Yield the scripted value once, then behave like ThreadRng.
            if let Some(val) = self.first_u64.take() {
                val
            } else {
                self.inner.next_u64()
            }
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.inner.fill_bytes(dest);
        }
        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
            self.inner.try_fill_bytes(dest)
        }
    }
}