Skip to main content

tuitbot_core/content/
generator.rs

1//! High-level content generation combining LLM providers with business context.
2//!
3//! Produces replies, tweets, and threads that meet X's format requirements
4//! (280 characters per tweet, 5-8 tweets per thread) with retry logic.
5
6use crate::config::BusinessProfile;
7use crate::content::frameworks::{ReplyArchetype, ThreadStructure, TweetFormat};
8use crate::error::LlmError;
9use crate::llm::{GenerationParams, LlmProvider};
10
/// Length limit for one tweet; every generated reply, tweet, and thread entry is
/// checked against it before being returned.
const MAX_TWEET_CHARS: usize = 280;

/// Number of additional thread-generation attempts made after the first one
/// produces output that fails validation (so up to `MAX_THREAD_RETRIES + 1` calls).
const MAX_THREAD_RETRIES: u32 = 2;
16
17/// Content generator that combines an LLM provider with business context.
18pub struct ContentGenerator {
19    provider: Box<dyn LlmProvider>,
20    business: BusinessProfile,
21}
22
23impl ContentGenerator {
24    /// Create a new content generator.
25    pub fn new(provider: Box<dyn LlmProvider>, business: BusinessProfile) -> Self {
26        Self { provider, business }
27    }
28
29    /// Generate a reply to a tweet.
30    ///
31    /// The reply will be conversational, helpful, and under 280 characters.
32    /// When `mention_product` is false, the system prompt explicitly forbids
33    /// mentioning the product name.
34    /// When `archetype` is provided, the prompt includes archetype-specific
35    /// guidance for varied output (e.g., ask a question, share experience).
36    /// Retries once with a stricter prompt if the first attempt is too long,
37    /// then truncates at a sentence boundary as a last resort.
38    pub async fn generate_reply(
39        &self,
40        tweet_text: &str,
41        tweet_author: &str,
42        mention_product: bool,
43    ) -> Result<String, LlmError> {
44        self.generate_reply_with_archetype(tweet_text, tweet_author, mention_product, None)
45            .await
46    }
47
48    /// Generate a reply using a specific archetype for varied output.
49    pub async fn generate_reply_with_archetype(
50        &self,
51        tweet_text: &str,
52        tweet_author: &str,
53        mention_product: bool,
54        archetype: Option<ReplyArchetype>,
55    ) -> Result<String, LlmError> {
56        tracing::debug!(
57            author = %tweet_author,
58            archetype = ?archetype,
59            mention_product = mention_product,
60            "Generating reply",
61        );
62        let voice_section = match &self.business.brand_voice {
63            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
64            _ => String::new(),
65        };
66        let reply_section = match &self.business.reply_style {
67            Some(s) if !s.is_empty() => format!("\nReply style: {s}"),
68            _ => "\nReply style: Be conversational and helpful, not salesy. Sound like a real person, not a bot.".to_string(),
69        };
70
71        let archetype_section = match archetype {
72            Some(a) => format!("\n{}", a.prompt_fragment()),
73            None => String::new(),
74        };
75
76        let persona_section = self.format_persona_context();
77
78        let product_rule = if mention_product {
79            let product_url = self.business.product_url.as_deref().unwrap_or("");
80            format!(
81                "You are a helpful community member who uses {} ({}).\n\
82                 Your target audience is: {}.\n\
83                 Product URL: {}\
84                 {voice_section}\
85                 {reply_section}\
86                 {archetype_section}\
87                 {persona_section}\n\n\
88                 Rules:\n\
89                 - Write a reply to the tweet below.\n\
90                 - Maximum 3 sentences.\n\
91                 - Only mention {} if it is genuinely relevant to the tweet's topic.\n\
92                 - Do not use hashtags.\n\
93                 - Do not use emojis excessively.",
94                self.business.product_name,
95                self.business.product_description,
96                self.business.target_audience,
97                product_url,
98                self.business.product_name,
99            )
100        } else {
101            format!(
102                "You are a helpful community member.\n\
103                 Your target audience is: {}.\
104                 {voice_section}\
105                 {reply_section}\
106                 {archetype_section}\
107                 {persona_section}\n\n\
108                 Rules:\n\
109                 - Write a reply to the tweet below.\n\
110                 - Maximum 3 sentences.\n\
111                 - Do NOT mention {} or any product. Just be genuinely helpful.\n\
112                 - Do not use hashtags.\n\
113                 - Do not use emojis excessively.",
114                self.business.target_audience, self.business.product_name,
115            )
116        };
117
118        let system = product_rule;
119        let user_message = format!("Tweet by @{tweet_author}: {tweet_text}");
120
121        let params = GenerationParams {
122            max_tokens: 200,
123            temperature: 0.7,
124            ..Default::default()
125        };
126
127        let resp = self
128            .provider
129            .complete(&system, &user_message, &params)
130            .await?;
131        let text = resp.text.trim().to_string();
132
133        tracing::debug!(chars = text.len(), "Generated reply");
134
135        if validate_length(&text, MAX_TWEET_CHARS) {
136            return Ok(text);
137        }
138
139        // Retry with stricter instruction
140        let retry_msg = format!(
141            "{user_message}\n\nImportant: Your reply MUST be under 280 characters. Be more concise."
142        );
143        let resp = self.provider.complete(&system, &retry_msg, &params).await?;
144        let text = resp.text.trim().to_string();
145
146        if validate_length(&text, MAX_TWEET_CHARS) {
147            return Ok(text);
148        }
149
150        // Last resort: truncate at sentence boundary
151        Ok(truncate_at_sentence(&text, MAX_TWEET_CHARS))
152    }
153
154    /// Generate a standalone educational tweet.
155    ///
156    /// The tweet will be informative, engaging, and under 280 characters.
157    pub async fn generate_tweet(&self, topic: &str) -> Result<String, LlmError> {
158        self.generate_tweet_with_format(topic, None).await
159    }
160
161    /// Generate a tweet using a specific format for varied structure.
162    pub async fn generate_tweet_with_format(
163        &self,
164        topic: &str,
165        format: Option<TweetFormat>,
166    ) -> Result<String, LlmError> {
167        tracing::debug!(
168            topic = %topic,
169            format = ?format,
170            "Generating tweet",
171        );
172        let voice_section = match &self.business.brand_voice {
173            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
174            _ => String::new(),
175        };
176        let content_section = match &self.business.content_style {
177            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
178            _ => "\nContent style: Be informative and engaging.".to_string(),
179        };
180
181        let format_section = match format {
182            Some(f) => format!("\n{}", f.prompt_fragment()),
183            None => String::new(),
184        };
185
186        let persona_section = self.format_persona_context();
187
188        let system = format!(
189            "You are {}'s social media voice. {}.\n\
190             Your audience: {}.\
191             {voice_section}\
192             {content_section}\
193             {format_section}\
194             {persona_section}\n\n\
195             Rules:\n\
196             - Write a single educational tweet about the topic below.\n\
197             - Maximum 280 characters.\n\
198             - Do not use hashtags.\n\
199             - Do not mention {} directly unless it is central to the topic.",
200            self.business.product_name,
201            self.business.product_description,
202            self.business.target_audience,
203            self.business.product_name,
204        );
205
206        let user_message = format!("Write a tweet about: {topic}");
207
208        let params = GenerationParams {
209            max_tokens: 150,
210            temperature: 0.8,
211            ..Default::default()
212        };
213
214        let resp = self
215            .provider
216            .complete(&system, &user_message, &params)
217            .await?;
218        let text = resp.text.trim().to_string();
219
220        if validate_length(&text, MAX_TWEET_CHARS) {
221            return Ok(text);
222        }
223
224        // Retry with stricter instruction
225        let retry_msg = format!(
226            "{user_message}\n\nImportant: Your tweet MUST be under 280 characters. Be more concise."
227        );
228        let resp = self.provider.complete(&system, &retry_msg, &params).await?;
229        let text = resp.text.trim().to_string();
230
231        if validate_length(&text, MAX_TWEET_CHARS) {
232            return Ok(text);
233        }
234
235        Ok(truncate_at_sentence(&text, MAX_TWEET_CHARS))
236    }
237
238    /// Generate an educational thread of 5-8 tweets.
239    ///
240    /// Each tweet in the thread will be under 280 characters.
241    /// Retries up to 2 times if the LLM produces malformed output.
242    pub async fn generate_thread(&self, topic: &str) -> Result<Vec<String>, LlmError> {
243        self.generate_thread_with_structure(topic, None).await
244    }
245
246    /// Generate a thread using a specific structure for varied content.
247    pub async fn generate_thread_with_structure(
248        &self,
249        topic: &str,
250        structure: Option<ThreadStructure>,
251    ) -> Result<Vec<String>, LlmError> {
252        tracing::debug!(
253            topic = %topic,
254            structure = ?structure,
255            "Generating thread",
256        );
257        let voice_section = match &self.business.brand_voice {
258            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
259            _ => String::new(),
260        };
261        let content_section = match &self.business.content_style {
262            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
263            _ => "\nContent style: Be informative, not promotional.".to_string(),
264        };
265
266        let structure_section = match structure {
267            Some(s) => format!("\n{}", s.prompt_fragment()),
268            None => String::new(),
269        };
270
271        let persona_section = self.format_persona_context();
272
273        let system = format!(
274            "You are {}'s social media voice. {}.\n\
275             Your audience: {}.\
276             {voice_section}\
277             {content_section}\
278             {structure_section}\
279             {persona_section}\n\n\
280             Rules:\n\
281             - Write an educational thread of 5 to 8 tweets about the topic below.\n\
282             - Separate each tweet with a line containing only \"---\".\n\
283             - Each tweet must be under 280 characters.\n\
284             - The first tweet should hook the reader.\n\
285             - The last tweet should include a call to action or summary.\n\
286             - Do not use hashtags.",
287            self.business.product_name,
288            self.business.product_description,
289            self.business.target_audience,
290        );
291
292        let user_message = format!("Write a thread about: {topic}");
293
294        let params = GenerationParams {
295            max_tokens: 1500,
296            temperature: 0.7,
297            ..Default::default()
298        };
299
300        for attempt in 0..=MAX_THREAD_RETRIES {
301            let msg = if attempt == 0 {
302                user_message.clone()
303            } else {
304                format!(
305                    "{user_message}\n\nIMPORTANT: Write exactly 5-8 tweets, \
306                     each under 280 characters, separated by lines containing only \"---\"."
307                )
308            };
309
310            let resp = self.provider.complete(&system, &msg, &params).await?;
311            let tweets = parse_thread(&resp.text);
312
313            if (5..=8).contains(&tweets.len())
314                && tweets.iter().all(|t| validate_length(t, MAX_TWEET_CHARS))
315            {
316                return Ok(tweets);
317            }
318        }
319
320        Err(LlmError::GenerationFailed(
321            "Failed to generate valid thread after retries".to_string(),
322        ))
323    }
324
325    /// Build a persona context section from opinions and experiences.
326    fn format_persona_context(&self) -> String {
327        let mut parts = Vec::new();
328
329        if !self.business.persona_opinions.is_empty() {
330            let opinions = self.business.persona_opinions.join("; ");
331            parts.push(format!("Opinions you hold: {opinions}"));
332        }
333
334        if !self.business.persona_experiences.is_empty() {
335            let experiences = self.business.persona_experiences.join("; ");
336            parts.push(format!("Experiences you can reference: {experiences}"));
337        }
338
339        if !self.business.content_pillars.is_empty() {
340            let pillars = self.business.content_pillars.join(", ");
341            parts.push(format!("Content pillars: {pillars}"));
342        }
343
344        if parts.is_empty() {
345            String::new()
346        } else {
347            format!("\n{}", parts.join("\n"))
348        }
349    }
350}
351
/// Parse a thread response into individual tweets.
///
/// The primary format is tweets separated by `---`. Segments are trimmed and
/// empty ones dropped. If the text has no `---` (or only empty segments), the
/// parser falls back to splitting on thread-numbering markers at the start of a
/// line (`1/8`, `1/`, `1.`, `1)`) and strips the marker. Lines without a marker
/// are joined onto the preceding tweet with a single space.
fn parse_thread(text: &str) -> Vec<String> {
    // Primary: split on "---" delimiter
    if text.contains("---") {
        let tweets: Vec<String> = text
            .split("---")
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .collect();
        if !tweets.is_empty() {
            return tweets;
        }
    }

    // Fallback: numbered lines start a new tweet; other lines continue the current one.
    let mut tweets = Vec::new();
    let mut current = String::new();

    for line in text.lines() {
        let trimmed = line.trim();
        let marker_tail = strip_thread_marker(trimmed);

        if marker_tail.is_some() && !current.is_empty() {
            tweets.push(std::mem::take(&mut current));
        }

        let content = marker_tail.unwrap_or(trimmed);
        if !content.is_empty() {
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(content);
        }
    }

    // `current` only ever holds trimmed fragments joined by single spaces.
    if !current.is_empty() {
        tweets.push(current);
    }

    tweets
}

/// If `line` begins with a thread-numbering marker, return the text after it.
///
/// A marker is one or two digits followed by `/` (with an optional one- or
/// two-digit total, e.g. `3/8` or `3/`), `.`, or `)`. It must be followed by
/// whitespace or the end of the line. This keeps prose such as "10x faster",
/// "3.5x", "2024." and "24/7" (index greater than total) from being treated
/// as markers.
fn strip_thread_marker(line: &str) -> Option<&str> {
    let index_len = line.bytes().take_while(|b| b.is_ascii_digit()).count();
    if !(1..=2).contains(&index_len) {
        return None;
    }
    // Digits are single-byte ASCII, so these offsets are char boundaries.
    let (index, rest) = line.split_at(index_len);

    let rest = if let Some(after_slash) = rest.strip_prefix('/') {
        let total_len = after_slash.bytes().take_while(|b| b.is_ascii_digit()).count();
        if total_len > 2 {
            return None;
        }
        let (total, tail) = after_slash.split_at(total_len);
        if !total.is_empty() && index.parse::<u8>().ok()? > total.parse::<u8>().ok()? {
            return None;
        }
        tail
    } else {
        rest.strip_prefix(|c: char| c == '.' || c == ')')?
    };

    if rest.is_empty() || rest.starts_with(char::is_whitespace) {
        Some(rest.trim_start())
    } else {
        None
    }
}
407
/// Return `true` when `text` fits within the `max_chars` limit.
///
/// NOTE: the length is measured in UTF-8 bytes (`str::len`), not characters,
/// so non-ASCII text is judged more strictly than a per-character count would be.
fn validate_length(text: &str, max_chars: usize) -> bool {
    let byte_len = text.len();
    max_chars >= byte_len
}
412
/// Truncate text at the last sentence boundary that fits within the limit.
///
/// The limit is measured in UTF-8 bytes, consistent with `validate_length`. All
/// cut points are snapped down to a char boundary, so multi-byte text never
/// causes a slicing panic.
///
/// Looks for the last `.`, `!`, or `?` within the limit (ignoring one at
/// position 0). If none is found, cuts at the last space that leaves room for a
/// trailing `"..."`. If `max_chars` is too small to hold the ellipsis, it
/// hard-cuts instead, so the result never exceeds `max_chars` bytes.
fn truncate_at_sentence(text: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "...";

    if text.len() <= max_chars {
        return text.to_string();
    }

    let search_area = &text[..char_boundary_at_or_before(text, max_chars)];

    // Find the last sentence-ending punctuation (ASCII, so `..=pos` stays on a boundary).
    let last_sentence_end = search_area
        .rfind(|c: char| matches!(c, '.' | '!' | '?'))
        .filter(|&pos| pos > 0);
    if let Some(pos) = last_sentence_end {
        return text[..=pos].trim().to_string();
    }

    // No sentence boundary found; hard truncate with ellipsis.
    let Some(budget) = max_chars.checked_sub(ELLIPSIS.len()) else {
        // Not even room for the ellipsis: a plain cut is the only way to fit.
        return search_area.to_string();
    };
    let cut = char_boundary_at_or_before(search_area, budget);
    // Prefer ending on a word boundary.
    let word_end = search_area[..cut].rfind(' ').unwrap_or(cut);
    format!("{}{ELLIPSIS}", &text[..word_end])
}

/// Largest byte index `<= index` that lies on a char boundary of `s`
/// (clamped to `s.len()`).
fn char_boundary_at_or_before(s: &str, index: usize) -> usize {
    if index >= s.len() {
        return s.len();
    }
    (0..=index).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0)
}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmResponse, TokenUsage};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Mock LLM provider that returns canned responses and records every request.
    ///
    /// Response `i` is returned for call `i`. Once the list is exhausted, the
    /// last response is repeated.
    struct MockProvider {
        responses: Vec<String>,
        call_count: Arc<AtomicUsize>,
        /// Every `(system, user_message)` pair received, in call order.
        prompts: Arc<Mutex<Vec<(String, String)>>>,
    }

    impl MockProvider {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Arc::new(AtomicUsize::new(0)),
                prompts: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn single(response: &str) -> Self {
            Self::new(vec![response.to_string()])
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn complete(
            &self,
            system: &str,
            user_message: &str,
            _params: &GenerationParams,
        ) -> Result<LlmResponse, LlmError> {
            self.prompts
                .lock()
                .expect("prompt log poisoned")
                .push((system.to_string(), user_message.to_string()));

            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let text = self
                .responses
                .get(idx)
                .cloned()
                .unwrap_or_else(|| self.responses.last().cloned().unwrap_or_default());

            Ok(LlmResponse {
                text,
                usage: TokenUsage::default(),
                model: "mock".to_string(),
            })
        }

        async fn health_check(&self) -> Result<(), LlmError> {
            Ok(())
        }
    }

    fn test_business() -> BusinessProfile {
        BusinessProfile {
            product_name: "TestApp".to_string(),
            product_description: "A test application".to_string(),
            product_url: Some("https://testapp.com".to_string()),
            target_audience: "developers".to_string(),
            product_keywords: vec!["test".to_string()],
            competitor_keywords: vec![],
            industry_topics: vec!["testing".to_string()],
            brand_voice: None,
            reply_style: None,
            content_style: None,
            persona_opinions: vec![],
            persona_experiences: vec![],
            content_pillars: vec![],
        }
    }

    // --- validate_length tests ---

    #[test]
    fn validate_length_under_limit() {
        assert!(validate_length("short text", 280));
    }

    #[test]
    fn validate_length_at_limit() {
        let text = "a".repeat(280);
        assert!(validate_length(&text, 280));
    }

    #[test]
    fn validate_length_over_limit() {
        let text = "a".repeat(281);
        assert!(!validate_length(&text, 280));
    }

    // --- truncate_at_sentence tests ---

    #[test]
    fn truncate_at_sentence_under_limit() {
        let text = "Short sentence.";
        assert_eq!(truncate_at_sentence(text, 280), "Short sentence.");
    }

    #[test]
    fn truncate_at_period() {
        let text = "First sentence. Second sentence. Third sentence is very long and goes over the limit and more and more text.";
        let result = truncate_at_sentence(text, 50);
        assert!(result.len() <= 50);
        assert!(result.ends_with('.'));
    }

    #[test]
    fn truncate_at_question_mark() {
        let text = "Is this working? I hope so because this text is getting very long and will exceed the character limit.";
        let result = truncate_at_sentence(text, 20);
        assert!(result.len() <= 20);
        assert!(result.ends_with('?'));
    }

    #[test]
    fn truncate_no_sentence_boundary() {
        let text =
            "This is a very long sentence without any punctuation that keeps going and going";
        let result = truncate_at_sentence(text, 30);
        assert!(result.len() <= 30);
        assert!(result.ends_with("..."));
    }

    // --- parse_thread tests ---

    #[test]
    fn parse_thread_with_dashes() {
        let text = "Tweet one\n---\nTweet two\n---\nTweet three";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 3);
        assert_eq!(tweets[0], "Tweet one");
        assert_eq!(tweets[1], "Tweet two");
        assert_eq!(tweets[2], "Tweet three");
    }

    #[test]
    fn parse_thread_with_extra_whitespace() {
        let text = "  Tweet one  \n---\n  Tweet two  \n---\n";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 2);
        assert_eq!(tweets[0], "Tweet one");
        assert_eq!(tweets[1], "Tweet two");
    }

    #[test]
    fn parse_thread_single_block_falls_back_to_numbered() {
        let text =
            "1/5 First tweet\n2/5 Second tweet\n3/5 Third tweet\n4/5 Fourth tweet\n5/5 Fifth tweet";
        let tweets = parse_thread(text);
        // Exact output: five tweets with their "n/5" markers stripped.
        assert_eq!(
            tweets,
            vec![
                "First tweet",
                "Second tweet",
                "Third tweet",
                "Fourth tweet",
                "Fifth tweet"
            ]
        );
    }

    #[test]
    fn parse_thread_numbered_with_dot_and_paren() {
        let tweets = parse_thread("1. First point\n2) Second point");
        assert_eq!(tweets, vec!["First point", "Second point"]);
    }

    #[test]
    fn parse_thread_empty_sections_filtered() {
        let text = "---\n---\nActual tweet\n---\n---";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 1);
        assert_eq!(tweets[0], "Actual tweet");
    }

    // --- generate_reply tests ---

    #[tokio::test]
    async fn generate_reply_success() {
        let provider =
            MockProvider::single("Great point about testing! I've found similar results.");
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let reply = generator
            .generate_reply("Testing is important", "devuser", true)
            .await
            .expect("reply");
        assert!(reply.len() <= MAX_TWEET_CHARS);
        assert!(!reply.is_empty());
    }

    #[tokio::test]
    async fn generate_reply_truncates_long_output() {
        let long_text = "a ".repeat(200); // 400 chars
        let provider = MockProvider::new(vec![long_text.clone(), long_text]);
        let calls = Arc::clone(&provider.call_count);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let reply = generator
            .generate_reply("test", "user", true)
            .await
            .expect("reply");
        assert!(reply.len() <= MAX_TWEET_CHARS);
        // One initial attempt plus exactly one stricter retry.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn generate_reply_uses_retry_when_short_enough() {
        let provider = MockProvider::new(vec!["a ".repeat(200), "Short and sweet.".to_string()]);
        let calls = Arc::clone(&provider.call_count);
        let prompts = Arc::clone(&provider.prompts);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let reply = generator
            .generate_reply("test", "user", true)
            .await
            .expect("reply");
        assert_eq!(reply, "Short and sweet.");
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        let log = prompts.lock().expect("prompt log");
        assert!(log[1]
            .1
            .contains("Important: Your reply MUST be under 280 characters."));
    }

    #[tokio::test]
    async fn generate_reply_no_product_mention() {
        let provider = MockProvider::single("That's a great approach for productivity!");
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let reply = generator
            .generate_reply("How do you stay productive?", "devuser", false)
            .await
            .expect("reply");
        assert!(reply.len() <= MAX_TWEET_CHARS);
        assert!(!reply.is_empty());
    }

    #[tokio::test]
    async fn generate_reply_prompt_respects_mention_product() {
        let provider = MockProvider::single("Nice.");
        let prompts = Arc::clone(&provider.prompts);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        generator.generate_reply("t", "u", true).await.expect("reply");
        generator.generate_reply("t", "u", false).await.expect("reply");

        let log = prompts.lock().expect("prompt log");
        assert!(log[0].0.contains("Product URL: https://testapp.com"));
        assert!(log[0]
            .0
            .contains("Only mention TestApp if it is genuinely relevant to the tweet's topic."));
        assert!(log[1]
            .0
            .contains("Do NOT mention TestApp or any product. Just be genuinely helpful."));
        assert!(!log[1].0.contains("Product URL"));
        assert_eq!(log[1].1, "Tweet by @u: t");
    }

    // --- generate_tweet tests ---

    #[tokio::test]
    async fn generate_tweet_success() {
        let provider =
            MockProvider::single("Testing your code early saves hours of debugging later.");
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let tweet = generator
            .generate_tweet("testing best practices")
            .await
            .expect("tweet");
        assert!(tweet.len() <= MAX_TWEET_CHARS);
        assert!(!tweet.is_empty());
    }

    #[tokio::test]
    async fn generate_tweet_truncates_after_failed_retry() {
        let long_text = "a ".repeat(200);
        let provider = MockProvider::new(vec![long_text.clone(), long_text]);
        let calls = Arc::clone(&provider.call_count);
        let prompts = Arc::clone(&provider.prompts);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let tweet = generator.generate_tweet("topic").await.expect("tweet");
        assert!(tweet.len() <= MAX_TWEET_CHARS);
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        let log = prompts.lock().expect("prompt log");
        assert_eq!(log[0].1, "Write a tweet about: topic");
        assert!(log[1]
            .1
            .contains("Important: Your tweet MUST be under 280 characters."));
    }

    // --- generate_thread tests ---

    #[tokio::test]
    async fn generate_thread_success() {
        let thread_text = [
            "Hook tweet here",
            "---",
            "Second point about testing",
            "---",
            "Third point about quality",
            "---",
            "Fourth point about CI/CD",
            "---",
            "Fifth point about automation",
            "---",
            "Summary and call to action",
        ]
        .join("\n");

        let provider = MockProvider::single(&thread_text);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let thread = generator.generate_thread("testing").await.expect("thread");
        assert!(
            (5..=8).contains(&thread.len()),
            "got {} tweets",
            thread.len()
        );
        for tweet in &thread {
            assert!(tweet.len() <= MAX_TWEET_CHARS);
        }
    }

    #[tokio::test]
    async fn generate_thread_retries_on_bad_count() {
        // First attempt: too few tweets. Second: still too few. Third: valid.
        let bad = "Tweet one\n---\nTweet two";
        let good = "One\n---\nTwo\n---\nThree\n---\nFour\n---\nFive";
        let provider = MockProvider::new(vec![bad.into(), bad.into(), good.into()]);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let thread = generator.generate_thread("topic").await.expect("thread");
        assert_eq!(thread.len(), 5);
    }

    #[tokio::test]
    async fn generate_thread_rejects_overlong_tweet() {
        let too_long = format!(
            "{}\n---\nTwo\n---\nThree\n---\nFour\n---\nFive",
            "a".repeat(300)
        );
        let good = "One\n---\nTwo\n---\nThree\n---\nFour\n---\nFive";
        let provider = MockProvider::new(vec![too_long, good.into()]);
        let calls = Arc::clone(&provider.call_count);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let thread = generator.generate_thread("topic").await.expect("thread");
        assert_eq!(thread[0], "One");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn generate_thread_fails_after_max_retries() {
        let bad = "Tweet one\n---\nTweet two";
        let provider = MockProvider::new(vec![bad.into(), bad.into(), bad.into()]);
        let calls = Arc::clone(&provider.call_count);
        let generator = ContentGenerator::new(Box::new(provider), test_business());

        let err = generator.generate_thread("topic").await.unwrap_err();
        assert!(matches!(err, LlmError::GenerationFailed(_)));
        // One initial attempt plus MAX_THREAD_RETRIES retries.
        assert_eq!(
            calls.load(Ordering::SeqCst),
            MAX_THREAD_RETRIES as usize + 1
        );
    }

    // --- GenerationParams tests ---

    #[test]
    fn generation_params_default() {
        let params = GenerationParams::default();
        assert_eq!(params.max_tokens, 512);
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert!(params.system_prompt.is_none());
    }
}