// tuitbot_core/content/generator.rs

//! High-level content generation combining LLM providers with business context.
//!
//! Produces replies, tweets, and threads that meet X's format requirements
//! (280 characters per tweet, 5-8 tweets per thread) with retry logic.
5
6use crate::config::BusinessProfile;
7use crate::content::frameworks::{ReplyArchetype, ThreadStructure, TweetFormat};
8use crate::content::length::{truncate_at_sentence, validate_tweet_length, MAX_TWEET_CHARS};
9use crate::error::LlmError;
10use crate::llm::{GenerationParams, LlmProvider, TokenUsage};
11
12/// Output from a single-text generation (reply or tweet).
13#[derive(Debug, Clone)]
14pub struct GenerationOutput {
15    /// The generated text.
16    pub text: String,
17    /// Accumulated token usage across all attempts (including retries).
18    pub usage: TokenUsage,
19    /// The model that produced the final response.
20    pub model: String,
21    /// The provider name (e.g., "openai", "anthropic", "ollama").
22    pub provider: String,
23}
24
25/// Output from thread generation.
26#[derive(Debug, Clone)]
27pub struct ThreadGenerationOutput {
28    /// The generated tweets in thread order.
29    pub tweets: Vec<String>,
30    /// Accumulated token usage across all attempts (including retries).
31    pub usage: TokenUsage,
32    /// The model that produced the final response.
33    pub model: String,
34    /// The provider name.
35    pub provider: String,
36}
37
/// Extra attempts made after the first when thread output is malformed
/// (a thread is requested at most `MAX_THREAD_RETRIES + 1` times).
const MAX_THREAD_RETRIES: u32 = 2;
40
41/// Content generator that combines an LLM provider with business context.
42pub struct ContentGenerator {
43    provider: Box<dyn LlmProvider>,
44    business: BusinessProfile,
45}
46
47impl ContentGenerator {
48    /// Create a new content generator.
49    pub fn new(provider: Box<dyn LlmProvider>, business: BusinessProfile) -> Self {
50        Self { provider, business }
51    }
52
53    /// Generate a reply to a tweet.
54    ///
55    /// The reply will be conversational, helpful, and under 280 characters.
56    /// When `mention_product` is false, the system prompt explicitly forbids
57    /// mentioning the product name.
58    /// When `archetype` is provided, the prompt includes archetype-specific
59    /// guidance for varied output (e.g., ask a question, share experience).
60    /// Retries once with a stricter prompt if the first attempt is too long,
61    /// then truncates at a sentence boundary as a last resort.
62    pub async fn generate_reply(
63        &self,
64        tweet_text: &str,
65        tweet_author: &str,
66        mention_product: bool,
67    ) -> Result<GenerationOutput, LlmError> {
68        self.generate_reply_with_archetype(tweet_text, tweet_author, mention_product, None)
69            .await
70    }
71
72    /// Generate a reply with optional RAG context injected into the prompt.
73    ///
74    /// The `rag_context` block is inserted between the persona/voice section
75    /// and the rules section of the system prompt. If `None`, behaves
76    /// identically to `generate_reply_with_archetype`.
77    pub async fn generate_reply_with_context(
78        &self,
79        tweet_text: &str,
80        tweet_author: &str,
81        mention_product: bool,
82        archetype: Option<ReplyArchetype>,
83        rag_context: Option<&str>,
84    ) -> Result<GenerationOutput, LlmError> {
85        self.generate_reply_inner(
86            tweet_text,
87            tweet_author,
88            mention_product,
89            archetype,
90            rag_context,
91        )
92        .await
93    }
94
95    /// Generate a reply using a specific archetype for varied output.
96    pub async fn generate_reply_with_archetype(
97        &self,
98        tweet_text: &str,
99        tweet_author: &str,
100        mention_product: bool,
101        archetype: Option<ReplyArchetype>,
102    ) -> Result<GenerationOutput, LlmError> {
103        self.generate_reply_inner(tweet_text, tweet_author, mention_product, archetype, None)
104            .await
105    }
106
107    /// Internal reply generation with optional RAG context.
108    async fn generate_reply_inner(
109        &self,
110        tweet_text: &str,
111        tweet_author: &str,
112        mention_product: bool,
113        archetype: Option<ReplyArchetype>,
114        rag_context: Option<&str>,
115    ) -> Result<GenerationOutput, LlmError> {
116        tracing::debug!(
117            author = %tweet_author,
118            archetype = ?archetype,
119            mention_product = mention_product,
120            has_rag_context = rag_context.is_some(),
121            "Generating reply",
122        );
123        let voice_section = match &self.business.brand_voice {
124            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
125            _ => String::new(),
126        };
127        let reply_section = match &self.business.reply_style {
128            Some(s) if !s.is_empty() => format!("\nReply style: {s}"),
129            _ => "\nReply style: Be conversational and helpful, not salesy. Sound like a real person, not a bot.".to_string(),
130        };
131
132        let archetype_section = match archetype {
133            Some(a) => format!("\n{}", a.prompt_fragment()),
134            None => String::new(),
135        };
136
137        let persona_section = self.format_persona_context();
138
139        let rag_section = match rag_context {
140            Some(ctx) if !ctx.is_empty() => format!("\n{ctx}"),
141            _ => String::new(),
142        };
143
144        let audience_section = if self.business.target_audience.is_empty() {
145            String::new()
146        } else {
147            format!(
148                "\nYour target audience is: {}.",
149                self.business.target_audience
150            )
151        };
152
153        let product_rule = if mention_product {
154            let product_url = self.business.product_url.as_deref().unwrap_or("");
155            format!(
156                "You are a helpful community member who uses {} ({}).\
157                 {audience_section}\n\
158                 Product URL: {}\
159                 {voice_section}\
160                 {reply_section}\
161                 {archetype_section}\
162                 {persona_section}\
163                 {rag_section}\n\n\
164                 Rules:\n\
165                 - Write a reply to the tweet below.\n\
166                 - Maximum 3 sentences.\n\
167                 - Only mention {} if it is genuinely relevant to the tweet's topic.\n\
168                 - Do not use hashtags.\n\
169                 - Do not use emojis excessively.",
170                self.business.product_name,
171                self.business.product_description,
172                product_url,
173                self.business.product_name,
174            )
175        } else {
176            format!(
177                "You are a helpful community member.\
178                 {audience_section}\
179                 {voice_section}\
180                 {reply_section}\
181                 {archetype_section}\
182                 {persona_section}\
183                 {rag_section}\n\n\
184                 Rules:\n\
185                 - Write a reply to the tweet below.\n\
186                 - Maximum 3 sentences.\n\
187                 - Do NOT mention {} or any product. Just be genuinely helpful.\n\
188                 - Do not use hashtags.\n\
189                 - Do not use emojis excessively.",
190                self.business.product_name,
191            )
192        };
193
194        let system = product_rule;
195        let user_message = format!("Tweet by @{tweet_author}: {tweet_text}");
196
197        let params = GenerationParams {
198            max_tokens: 200,
199            temperature: 0.7,
200            ..Default::default()
201        };
202
203        let resp = self
204            .provider
205            .complete(&system, &user_message, &params)
206            .await?;
207        let mut usage = resp.usage.clone();
208        let provider_name = self.provider.name().to_string();
209        let model = resp.model.clone();
210        let text = resp.text.trim().to_string();
211
212        tracing::debug!(chars = text.len(), "Generated reply");
213
214        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
215            return Ok(GenerationOutput {
216                text,
217                usage,
218                model,
219                provider: provider_name,
220            });
221        }
222
223        // Retry with stricter instruction
224        let retry_msg = format!(
225            "{user_message}\n\nImportant: Your reply MUST be under 280 characters. Be more concise."
226        );
227        let resp = self.provider.complete(&system, &retry_msg, &params).await?;
228        usage.accumulate(&resp.usage);
229        let text = resp.text.trim().to_string();
230
231        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
232            return Ok(GenerationOutput {
233                text,
234                usage,
235                model,
236                provider: provider_name,
237            });
238        }
239
240        // Last resort: truncate at sentence boundary
241        Ok(GenerationOutput {
242            text: truncate_at_sentence(&text, MAX_TWEET_CHARS),
243            usage,
244            model,
245            provider: provider_name,
246        })
247    }
248
249    /// Generate a standalone educational tweet.
250    ///
251    /// The tweet will be informative, engaging, and under 280 characters.
252    pub async fn generate_tweet(&self, topic: &str) -> Result<GenerationOutput, LlmError> {
253        self.generate_tweet_with_format(topic, None).await
254    }
255
256    /// Generate a tweet using a specific format for varied structure.
257    pub async fn generate_tweet_with_format(
258        &self,
259        topic: &str,
260        format: Option<TweetFormat>,
261    ) -> Result<GenerationOutput, LlmError> {
262        tracing::debug!(
263            topic = %topic,
264            format = ?format,
265            "Generating tweet",
266        );
267        let voice_section = match &self.business.brand_voice {
268            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
269            _ => String::new(),
270        };
271        let content_section = match &self.business.content_style {
272            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
273            _ => "\nContent style: Be informative and engaging.".to_string(),
274        };
275
276        let format_section = match format {
277            Some(f) => format!("\n{}", f.prompt_fragment()),
278            None => String::new(),
279        };
280
281        let persona_section = self.format_persona_context();
282
283        let audience_section = if self.business.target_audience.is_empty() {
284            String::new()
285        } else {
286            format!("\nYour audience: {}.", self.business.target_audience)
287        };
288
289        let system = format!(
290            "You are {}'s social media voice. {}.\
291             {audience_section}\
292             {voice_section}\
293             {content_section}\
294             {format_section}\
295             {persona_section}\n\n\
296             Rules:\n\
297             - Write a single educational tweet about the topic below.\n\
298             - Maximum 280 characters.\n\
299             - Do not use hashtags.\n\
300             - Do not mention {} directly unless it is central to the topic.",
301            self.business.product_name,
302            self.business.product_description,
303            self.business.product_name,
304        );
305
306        let user_message = format!("Write a tweet about: {topic}");
307
308        let params = GenerationParams {
309            max_tokens: 150,
310            temperature: 0.8,
311            ..Default::default()
312        };
313
314        let resp = self
315            .provider
316            .complete(&system, &user_message, &params)
317            .await?;
318        let mut usage = resp.usage.clone();
319        let provider_name = self.provider.name().to_string();
320        let model = resp.model.clone();
321        let text = resp.text.trim().to_string();
322
323        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
324            return Ok(GenerationOutput {
325                text,
326                usage,
327                model,
328                provider: provider_name,
329            });
330        }
331
332        // Retry with stricter instruction
333        let retry_msg = format!(
334            "{user_message}\n\nImportant: Your tweet MUST be under 280 characters. Be more concise."
335        );
336        let resp = self.provider.complete(&system, &retry_msg, &params).await?;
337        usage.accumulate(&resp.usage);
338        let text = resp.text.trim().to_string();
339
340        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
341            return Ok(GenerationOutput {
342                text,
343                usage,
344                model,
345                provider: provider_name,
346            });
347        }
348
349        Ok(GenerationOutput {
350            text: truncate_at_sentence(&text, MAX_TWEET_CHARS),
351            usage,
352            model,
353            provider: provider_name,
354        })
355    }
356
357    /// Generate an educational thread of 5-8 tweets.
358    ///
359    /// Each tweet in the thread will be under 280 characters.
360    /// Retries up to 2 times if the LLM produces malformed output.
361    pub async fn generate_thread(&self, topic: &str) -> Result<ThreadGenerationOutput, LlmError> {
362        self.generate_thread_with_structure(topic, None).await
363    }
364
365    /// Generate a thread using a specific structure for varied content.
366    pub async fn generate_thread_with_structure(
367        &self,
368        topic: &str,
369        structure: Option<ThreadStructure>,
370    ) -> Result<ThreadGenerationOutput, LlmError> {
371        tracing::debug!(
372            topic = %topic,
373            structure = ?structure,
374            "Generating thread",
375        );
376        let voice_section = match &self.business.brand_voice {
377            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
378            _ => String::new(),
379        };
380        let content_section = match &self.business.content_style {
381            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
382            _ => "\nContent style: Be informative, not promotional.".to_string(),
383        };
384
385        let structure_section = match structure {
386            Some(s) => format!("\n{}", s.prompt_fragment()),
387            None => String::new(),
388        };
389
390        let persona_section = self.format_persona_context();
391
392        let audience_section = if self.business.target_audience.is_empty() {
393            String::new()
394        } else {
395            format!("\nYour audience: {}.", self.business.target_audience)
396        };
397
398        let system = format!(
399            "You are {}'s social media voice. {}.\
400             {audience_section}\
401             {voice_section}\
402             {content_section}\
403             {structure_section}\
404             {persona_section}\n\n\
405             Rules:\n\
406             - Write an educational thread of 5 to 8 tweets about the topic below.\n\
407             - Separate each tweet with a line containing only \"---\".\n\
408             - Each tweet must be under 280 characters.\n\
409             - The first tweet should hook the reader.\n\
410             - The last tweet should include a call to action or summary.\n\
411             - Do not use hashtags.",
412            self.business.product_name, self.business.product_description,
413        );
414
415        let user_message = format!("Write a thread about: {topic}");
416
417        let params = GenerationParams {
418            max_tokens: 1500,
419            temperature: 0.7,
420            ..Default::default()
421        };
422
423        let mut usage = TokenUsage::default();
424        let provider_name = self.provider.name().to_string();
425        let mut model = String::new();
426
427        for attempt in 0..=MAX_THREAD_RETRIES {
428            let msg = if attempt == 0 {
429                user_message.clone()
430            } else {
431                format!(
432                    "{user_message}\n\nIMPORTANT: Write exactly 5-8 tweets, \
433                     each under 280 characters, separated by lines containing only \"---\"."
434                )
435            };
436
437            let resp = self.provider.complete(&system, &msg, &params).await?;
438            usage.accumulate(&resp.usage);
439            model.clone_from(&resp.model);
440            let tweets = parse_thread(&resp.text);
441
442            if (5..=8).contains(&tweets.len())
443                && tweets
444                    .iter()
445                    .all(|t| validate_tweet_length(t, MAX_TWEET_CHARS))
446            {
447                return Ok(ThreadGenerationOutput {
448                    tweets,
449                    usage,
450                    model,
451                    provider: provider_name,
452                });
453            }
454        }
455
456        Err(LlmError::GenerationFailed(
457            "Failed to generate valid thread after retries".to_string(),
458        ))
459    }
460
461    /// Build a persona context section from opinions and experiences.
462    fn format_persona_context(&self) -> String {
463        let mut parts = Vec::new();
464
465        if !self.business.persona_opinions.is_empty() {
466            let opinions = self.business.persona_opinions.join("; ");
467            parts.push(format!("Opinions you hold: {opinions}"));
468        }
469
470        if !self.business.persona_experiences.is_empty() {
471            let experiences = self.business.persona_experiences.join("; ");
472            parts.push(format!("Experiences you can reference: {experiences}"));
473        }
474
475        if !self.business.content_pillars.is_empty() {
476            let pillars = self.business.content_pillars.join(", ");
477            parts.push(format!("Content pillars: {pillars}"));
478        }
479
480        if parts.is_empty() {
481            String::new()
482        } else {
483            format!("\n{}", parts.join("\n"))
484        }
485    }
486}
487
/// Parse a thread response into individual tweets.
///
/// The primary format is tweets separated by `---` delimiters, which is the
/// format the thread prompt requests. If the response contains no `---`,
/// falls back to splitting on thread-numbering prefixes such as `1/8`, `1.`
/// or `1)`. The prefix is stripped, and any unnumbered lines are joined with
/// a space onto the tweet they follow.
///
/// Empty segments are dropped, so the result may be empty.
fn parse_thread(text: &str) -> Vec<String> {
    // Primary: split on "---" delimiter
    let tweets: Vec<String> = text
        .split("---")
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect();

    if !tweets.is_empty() && text.contains("---") {
        return tweets;
    }

    // Fallback: split on numbered prefixes like "1/8", "2/8" or "1.", "2."
    let mut tweets = Vec::new();
    let mut current = String::new();

    for line in text.lines().map(str::trim).filter(|l| !l.is_empty()) {
        let (is_numbered, content) = match strip_number_prefix(line) {
            Some(rest) => (true, rest),
            None => (false, line),
        };

        // A numbered line starts a new tweet; flush the one in progress.
        if is_numbered && !current.is_empty() {
            tweets.push(std::mem::take(&mut current));
        }

        // A bare marker line such as "2/5" contributes no text of its own.
        if !content.is_empty() {
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(content);
        }
    }

    if !current.is_empty() {
        tweets.push(current);
    }

    tweets
}

/// If `line` starts with a thread-numbering prefix, return the rest of the
/// line with leading whitespace removed.
///
/// Accepted prefixes are `N/M` with `N <= M`, optionally followed by `.`, `)`
/// or `:`, and `N.` or `N)`. The prefix must be followed by whitespace or the
/// end of the line. This rejects ordinary text such as `3.5 million`,
/// `2024 was`, `3 tips for CI/CD` and `24/7`.
fn strip_number_prefix(line: &str) -> Option<&str> {
    let (index, rest) = split_leading_digits(line)?;

    let rest = if let Some(after_slash) = rest.strip_prefix('/') {
        let (total, rest) = split_leading_digits(after_slash)?;
        // "24/7" is prose, not "tweet 24 of 7".
        if index.parse::<u32>().ok()? > total.parse::<u32>().ok()? {
            return None;
        }
        rest.strip_prefix(|c: char| c == '.' || c == ')' || c == ':')
            .unwrap_or(rest)
    } else {
        rest.strip_prefix(|c: char| c == '.' || c == ')')?
    };

    if rest.is_empty() || rest.starts_with(char::is_whitespace) {
        Some(rest.trim_start())
    } else {
        None
    }
}

/// Split `s` into its leading run of ASCII digits and the remainder, or
/// `None` if `s` does not start with a digit.
fn split_leading_digits(s: &str) -> Option<(&str, &str)> {
    let end = s.find(|c: char| !c.is_ascii_digit()).unwrap_or(s.len());
    (end > 0).then(|| s.split_at(end))
}
543
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmResponse, TokenUsage};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Mock LLM provider that returns canned responses in order, repeating the
    /// last one once the list is exhausted.
    struct MockProvider {
        responses: Vec<String>,
        call_count: Arc<AtomicUsize>,
    }

    impl MockProvider {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        fn single(response: &str) -> Self {
            Self::new(vec![response.to_string()])
        }

        /// Shared handle to the call counter, usable after the provider has
        /// been boxed into a `ContentGenerator`.
        fn calls(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.call_count)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn complete(
            &self,
            _system: &str,
            _user_message: &str,
            _params: &GenerationParams,
        ) -> Result<LlmResponse, LlmError> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let text = self
                .responses
                .get(idx)
                .cloned()
                .unwrap_or_else(|| self.responses.last().cloned().unwrap_or_default());

            Ok(LlmResponse {
                text,
                usage: TokenUsage::default(),
                model: "mock".to_string(),
            })
        }

        async fn health_check(&self) -> Result<(), LlmError> {
            Ok(())
        }
    }

    fn test_business() -> BusinessProfile {
        BusinessProfile {
            product_name: "TestApp".to_string(),
            product_description: "A test application".to_string(),
            product_url: Some("https://testapp.com".to_string()),
            target_audience: "developers".to_string(),
            product_keywords: vec!["test".to_string()],
            competitor_keywords: vec![],
            industry_topics: vec!["testing".to_string()],
            brand_voice: None,
            reply_style: None,
            content_style: None,
            persona_opinions: vec![],
            persona_experiences: vec![],
            content_pillars: vec![],
        }
    }

    // --- parse_thread tests ---

    #[test]
    fn parse_thread_with_dashes() {
        let text = "Tweet one\n---\nTweet two\n---\nTweet three";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 3);
        assert_eq!(tweets[0], "Tweet one");
        assert_eq!(tweets[1], "Tweet two");
        assert_eq!(tweets[2], "Tweet three");
    }

    #[test]
    fn parse_thread_with_extra_whitespace() {
        let text = "  Tweet one  \n---\n  Tweet two  \n---\n";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 2);
        assert_eq!(tweets[0], "Tweet one");
        assert_eq!(tweets[1], "Tweet two");
    }

    #[test]
    fn parse_thread_single_block_falls_back_to_numbered() {
        let text =
            "1/5 First tweet\n2/5 Second tweet\n3/5 Third tweet\n4/5 Fourth tweet\n5/5 Fifth tweet";
        let tweets = parse_thread(text);
        assert_eq!(
            tweets,
            vec![
                "First tweet",
                "Second tweet",
                "Third tweet",
                "Fourth tweet",
                "Fifth tweet"
            ]
        );
    }

    #[test]
    fn parse_thread_empty_sections_filtered() {
        let text = "---\n---\nActual tweet\n---\n---";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 1);
        assert_eq!(tweets[0], "Actual tweet");
    }

    // --- generate_reply tests ---

    #[tokio::test]
    async fn generate_reply_success() {
        let provider =
            MockProvider::single("Great point about testing! I've found similar results.");
        let calls = provider.calls();
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("Testing is important", "devuser", true)
            .await
            .expect("reply");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert!(!output.text.is_empty());
        assert_eq!(output.provider, "mock");
        // A reply that fits on the first attempt must not trigger a retry.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn generate_reply_truncates_long_output() {
        let long_text = "a ".repeat(200); // 400 chars
        let provider = MockProvider::new(vec![long_text.clone(), long_text]);
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("test", "user", true)
            .await
            .expect("reply");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
    }

    #[tokio::test]
    async fn generate_reply_uses_retry_when_first_attempt_too_long() {
        let long_text = "word ".repeat(80); // 400 chars
        let short = "Concise and helpful.";
        let provider = MockProvider::new(vec![long_text, short.to_string()]);
        let calls = provider.calls();
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("test", "user", true)
            .await
            .expect("reply");
        assert_eq!(output.text, short);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn generate_reply_no_product_mention() {
        let provider = MockProvider::single("That's a great approach for productivity!");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("How do you stay productive?", "devuser", false)
            .await
            .expect("reply");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert!(!output.text.is_empty());
    }

    // --- generate_tweet tests ---

    #[tokio::test]
    async fn generate_tweet_success() {
        let provider =
            MockProvider::single("Testing your code early saves hours of debugging later.");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_tweet("testing best practices")
            .await
            .expect("tweet");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert!(!output.text.is_empty());
    }

    #[tokio::test]
    async fn generate_tweet_truncates_long_output() {
        let long_text = "a ".repeat(200); // 400 chars
        let provider = MockProvider::new(vec![long_text.clone(), long_text]);
        let calls = provider.calls();
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen.generate_tweet("topic").await.expect("tweet");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    // --- generate_thread tests ---

    #[tokio::test]
    async fn generate_thread_success() {
        let thread_text = vec![
            "Hook tweet here",
            "---",
            "Second point about testing",
            "---",
            "Third point about quality",
            "---",
            "Fourth point about CI/CD",
            "---",
            "Fifth point about automation",
            "---",
            "Summary and call to action",
        ]
        .join("\n");

        let provider = MockProvider::single(&thread_text);
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen.generate_thread("testing").await.expect("thread");
        assert_eq!(output.tweets.len(), 6, "got {:?}", output.tweets);
        assert_eq!(output.tweets[0], "Hook tweet here");
        assert_eq!(output.tweets[5], "Summary and call to action");
        for tweet in &output.tweets {
            assert!(tweet.len() <= MAX_TWEET_CHARS);
        }
    }

    #[tokio::test]
    async fn generate_thread_retries_on_bad_count() {
        // First attempt: too few tweets. Second: still too few. Third: valid.
        let bad = "Tweet one\n---\nTweet two";
        let good = "One\n---\nTwo\n---\nThree\n---\nFour\n---\nFive";
        let provider = MockProvider::new(vec![bad.into(), bad.into(), good.into()]);
        let calls = provider.calls();
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen.generate_thread("topic").await.expect("thread");
        assert_eq!(output.tweets.len(), 5);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn generate_thread_fails_after_max_retries() {
        let bad = "Tweet one\n---\nTweet two";
        let provider = MockProvider::new(vec![bad.into(), bad.into(), bad.into()]);
        let calls = provider.calls();
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let err = gen.generate_thread("topic").await.unwrap_err();
        assert!(matches!(err, LlmError::GenerationFailed(_)));
        // One initial attempt plus MAX_THREAD_RETRIES retries.
        let made = u32::try_from(calls.load(Ordering::SeqCst)).expect("call count fits in u32");
        assert_eq!(made, MAX_THREAD_RETRIES + 1);
    }

    // --- generate_reply_with_context tests ---

    #[tokio::test]
    async fn generate_reply_with_context_injects_rag() {
        let provider = MockProvider::single("Great insight about testing patterns!");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let rag_block = "Winning patterns:\n1. [tip] (tweet): \"Great advice\"";
        let output = gen
            .generate_reply_with_context("Test tweet", "user", false, None, Some(rag_block))
            .await
            .expect("reply");

        assert!(!output.text.is_empty());
        assert!(output.text.len() <= MAX_TWEET_CHARS);
    }

    #[tokio::test]
    async fn generate_reply_with_context_none_matches_archetype() {
        // With rag_context=None, should behave like generate_reply_with_archetype
        let provider = MockProvider::single("Agreed, great point!");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply_with_context("Test tweet", "user", false, None, None)
            .await
            .expect("reply");
        assert!(!output.text.is_empty());
    }

    // --- GenerationParams tests ---

    #[test]
    fn generation_params_default() {
        let params = GenerationParams::default();
        assert_eq!(params.max_tokens, 512);
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert!(params.system_prompt.is_none());
    }
}