1use crate::config::BusinessProfile;
7use crate::content::frameworks::{ReplyArchetype, ThreadStructure, TweetFormat};
8use crate::content::length::{truncate_at_sentence, validate_tweet_length, MAX_TWEET_CHARS};
9use crate::error::LlmError;
10use crate::llm::{GenerationParams, LlmProvider, TokenUsage};
11
/// Result of generating a single piece of content (a reply or a tweet).
#[derive(Debug, Clone)]
pub struct GenerationOutput {
    /// The generated text, trimmed; length is enforced by the generator.
    pub text: String,
    /// Token usage accumulated across all LLM calls made (including retries).
    pub usage: TokenUsage,
    /// Model identifier reported by the provider's response.
    pub model: String,
    /// Name of the LLM provider that produced the text (e.g. from `LlmProvider::name`).
    pub provider: String,
}
24
/// Result of generating a multi-tweet thread.
#[derive(Debug, Clone)]
pub struct ThreadGenerationOutput {
    /// The individual tweets of the thread, in posting order.
    pub tweets: Vec<String>,
    /// Token usage accumulated across all attempts (including retries).
    pub usage: TokenUsage,
    /// Model identifier reported by the last provider response used.
    pub model: String,
    /// Name of the LLM provider that produced the thread.
    pub provider: String,
}
37
/// Extra attempts after the first when generating a thread
/// (the loop runs `0..=MAX_THREAD_RETRIES`, i.e. 3 total attempts).
const MAX_THREAD_RETRIES: u32 = 2;
40
/// Generates replies, standalone tweets, and threads through an [`LlmProvider`],
/// building prompts from a [`BusinessProfile`] (product, audience, voice, persona).
pub struct ContentGenerator {
    // Boxed trait object so any provider implementation can be plugged in.
    provider: Box<dyn LlmProvider>,
    business: BusinessProfile,
}
46
impl ContentGenerator {
    /// Create a generator backed by the given provider and business profile.
    pub fn new(provider: Box<dyn LlmProvider>, business: BusinessProfile) -> Self {
        Self { provider, business }
    }

    /// Generate a reply to a tweet with no explicit archetype and no RAG context.
    ///
    /// `mention_product` controls whether the prompt permits mentioning the
    /// configured product. Convenience wrapper over
    /// [`Self::generate_reply_with_archetype`].
    pub async fn generate_reply(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_reply_with_archetype(tweet_text, tweet_author, mention_product, None)
            .await
    }

    /// Generate a reply with an optional [`ReplyArchetype`] and an optional
    /// RAG context block that is injected verbatim into the system prompt.
    pub async fn generate_reply_with_context(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
        archetype: Option<ReplyArchetype>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_reply_inner(
            tweet_text,
            tweet_author,
            mention_product,
            archetype,
            rag_context,
        )
        .await
    }

    /// Generate a reply with an optional [`ReplyArchetype`] but no RAG context.
    pub async fn generate_reply_with_archetype(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
        archetype: Option<ReplyArchetype>,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_reply_inner(tweet_text, tweet_author, mention_product, archetype, None)
            .await
    }

    /// Shared implementation behind all reply entry points.
    ///
    /// Builds the system prompt from the business profile plus the optional
    /// archetype/RAG sections, calls the provider, and enforces the tweet
    /// length limit: if the first response is too long, retries once asking
    /// for a shorter reply, then falls back to sentence-boundary truncation.
    ///
    /// # Errors
    /// Propagates any [`LlmError`] from the provider's `complete` calls.
    async fn generate_reply_inner(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
        archetype: Option<ReplyArchetype>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        tracing::debug!(
            author = %tweet_author,
            archetype = ?archetype,
            mention_product = mention_product,
            has_rag_context = rag_context.is_some(),
            "Generating reply",
        );
        // Each optional prompt section renders as an empty string when not
        // configured, so the final prompt contains no placeholder noise.
        let voice_section = match &self.business.brand_voice {
            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
            _ => String::new(),
        };
        // Reply style falls back to a default instruction when unset.
        let reply_section = match &self.business.reply_style {
            Some(s) if !s.is_empty() => format!("\nReply style: {s}"),
            _ => "\nReply style: Be conversational and helpful, not salesy. Sound like a real person, not a bot.".to_string(),
        };

        let archetype_section = match archetype {
            Some(a) => format!("\n{}", a.prompt_fragment()),
            None => String::new(),
        };

        let persona_section = self.format_persona_context();

        // RAG context is inserted verbatim; the caller owns its formatting.
        let rag_section = match rag_context {
            Some(ctx) if !ctx.is_empty() => format!("\n{ctx}"),
            _ => String::new(),
        };

        let audience_section = if self.business.target_audience.is_empty() {
            String::new()
        } else {
            format!(
                "\nYour target audience is: {}.",
                self.business.target_audience
            )
        };

        // Two prompt variants: with product mention allowed (includes product
        // name/description/URL) or explicitly forbidden.
        let product_rule = if mention_product {
            let product_url = self.business.product_url.as_deref().unwrap_or("");
            format!(
                "You are a helpful community member who uses {} ({}).\
                {audience_section}\n\
                Product URL: {}\
                {voice_section}\
                {reply_section}\
                {archetype_section}\
                {persona_section}\
                {rag_section}\n\n\
                Rules:\n\
                - Write a reply to the tweet below.\n\
                - Maximum 3 sentences.\n\
                - Only mention {} if it is genuinely relevant to the tweet's topic.\n\
                - Do not use hashtags.\n\
                - Do not use emojis excessively.",
                self.business.product_name,
                self.business.product_description,
                product_url,
                self.business.product_name,
            )
        } else {
            format!(
                "You are a helpful community member.\
                {audience_section}\
                {voice_section}\
                {reply_section}\
                {archetype_section}\
                {persona_section}\
                {rag_section}\n\n\
                Rules:\n\
                - Write a reply to the tweet below.\n\
                - Maximum 3 sentences.\n\
                - Do NOT mention {} or any product. Just be genuinely helpful.\n\
                - Do not use hashtags.\n\
                - Do not use emojis excessively.",
                self.business.product_name,
            )
        };

        let system = product_rule;
        let user_message = format!("Tweet by @{tweet_author}: {tweet_text}");

        let params = GenerationParams {
            max_tokens: 200,
            temperature: 0.7,
            ..Default::default()
        };

        let resp = self
            .provider
            .complete(&system, &user_message, &params)
            .await?;
        let mut usage = resp.usage.clone();
        let provider_name = self.provider.name().to_string();
        // NOTE(review): `model` is captured from the first response and is not
        // refreshed on the retry below (unlike generate_thread, which updates
        // it per attempt) — presumably the provider uses the same model for
        // both calls; confirm.
        let model = resp.model.clone();
        let text = resp.text.trim().to_string();

        tracing::debug!(chars = text.len(), "Generated reply");

        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
            return Ok(GenerationOutput {
                text,
                usage,
                model,
                provider: provider_name,
            });
        }

        // Too long: retry once with an explicit length instruction appended.
        let retry_msg = format!(
            "{user_message}\n\nImportant: Your reply MUST be under 280 characters. Be more concise."
        );
        let resp = self.provider.complete(&system, &retry_msg, &params).await?;
        usage.accumulate(&resp.usage);
        let text = resp.text.trim().to_string();

        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
            return Ok(GenerationOutput {
                text,
                usage,
                model,
                provider: provider_name,
            });
        }

        // Last resort: hard-truncate at a sentence boundary.
        Ok(GenerationOutput {
            text: truncate_at_sentence(&text, MAX_TWEET_CHARS),
            usage,
            model,
            provider: provider_name,
        })
    }

    /// Generate a standalone tweet about `topic` with no explicit format.
    pub async fn generate_tweet(&self, topic: &str) -> Result<GenerationOutput, LlmError> {
        self.generate_tweet_with_format(topic, None).await
    }

    /// Generate a standalone tweet about `topic`, optionally shaped by a
    /// [`TweetFormat`]. Uses the same retry-then-truncate length enforcement
    /// as reply generation.
    ///
    /// # Errors
    /// Propagates any [`LlmError`] from the provider's `complete` calls.
    pub async fn generate_tweet_with_format(
        &self,
        topic: &str,
        format: Option<TweetFormat>,
    ) -> Result<GenerationOutput, LlmError> {
        tracing::debug!(
            topic = %topic,
            format = ?format,
            "Generating tweet",
        );
        let voice_section = match &self.business.brand_voice {
            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
            _ => String::new(),
        };
        let content_section = match &self.business.content_style {
            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
            _ => "\nContent style: Be informative and engaging.".to_string(),
        };

        let format_section = match format {
            Some(f) => format!("\n{}", f.prompt_fragment()),
            None => String::new(),
        };

        let persona_section = self.format_persona_context();

        let audience_section = if self.business.target_audience.is_empty() {
            String::new()
        } else {
            format!("\nYour audience: {}.", self.business.target_audience)
        };

        let system = format!(
            "You are {}'s social media voice. {}.\
            {audience_section}\
            {voice_section}\
            {content_section}\
            {format_section}\
            {persona_section}\n\n\
            Rules:\n\
            - Write a single educational tweet about the topic below.\n\
            - Maximum 280 characters.\n\
            - Do not use hashtags.\n\
            - Do not mention {} directly unless it is central to the topic.",
            self.business.product_name,
            self.business.product_description,
            self.business.product_name,
        );

        let user_message = format!("Write a tweet about: {topic}");

        let params = GenerationParams {
            max_tokens: 150,
            temperature: 0.8,
            ..Default::default()
        };

        let resp = self
            .provider
            .complete(&system, &user_message, &params)
            .await?;
        let mut usage = resp.usage.clone();
        let provider_name = self.provider.name().to_string();
        // NOTE(review): as with replies, `model` comes from the first response
        // only and is not refreshed by the retry call — confirm intent.
        let model = resp.model.clone();
        let text = resp.text.trim().to_string();

        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
            return Ok(GenerationOutput {
                text,
                usage,
                model,
                provider: provider_name,
            });
        }

        // Too long: retry once with an explicit length instruction appended.
        let retry_msg = format!(
            "{user_message}\n\nImportant: Your tweet MUST be under 280 characters. Be more concise."
        );
        let resp = self.provider.complete(&system, &retry_msg, &params).await?;
        usage.accumulate(&resp.usage);
        let text = resp.text.trim().to_string();

        if validate_tweet_length(&text, MAX_TWEET_CHARS) {
            return Ok(GenerationOutput {
                text,
                usage,
                model,
                provider: provider_name,
            });
        }

        // Last resort: hard-truncate at a sentence boundary.
        Ok(GenerationOutput {
            text: truncate_at_sentence(&text, MAX_TWEET_CHARS),
            usage,
            model,
            provider: provider_name,
        })
    }

    /// Generate a 5–8 tweet thread about `topic` with no explicit structure.
    pub async fn generate_thread(&self, topic: &str) -> Result<ThreadGenerationOutput, LlmError> {
        self.generate_thread_with_structure(topic, None).await
    }

    /// Generate a 5–8 tweet thread about `topic`, optionally shaped by a
    /// [`ThreadStructure`].
    ///
    /// The model is asked to separate tweets with "---" lines; the response
    /// is parsed by [`parse_thread`] and validated (5–8 tweets, each within
    /// the length limit). Invalid responses are retried up to
    /// [`MAX_THREAD_RETRIES`] extra times with a stricter instruction.
    ///
    /// # Errors
    /// Returns [`LlmError::GenerationFailed`] when no attempt yields a valid
    /// thread; otherwise propagates provider errors.
    pub async fn generate_thread_with_structure(
        &self,
        topic: &str,
        structure: Option<ThreadStructure>,
    ) -> Result<ThreadGenerationOutput, LlmError> {
        tracing::debug!(
            topic = %topic,
            structure = ?structure,
            "Generating thread",
        );
        let voice_section = match &self.business.brand_voice {
            Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
            _ => String::new(),
        };
        let content_section = match &self.business.content_style {
            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
            _ => "\nContent style: Be informative, not promotional.".to_string(),
        };

        let structure_section = match structure {
            Some(s) => format!("\n{}", s.prompt_fragment()),
            None => String::new(),
        };

        let persona_section = self.format_persona_context();

        let audience_section = if self.business.target_audience.is_empty() {
            String::new()
        } else {
            format!("\nYour audience: {}.", self.business.target_audience)
        };

        let system = format!(
            "You are {}'s social media voice. {}.\
            {audience_section}\
            {voice_section}\
            {content_section}\
            {structure_section}\
            {persona_section}\n\n\
            Rules:\n\
            - Write an educational thread of 5 to 8 tweets about the topic below.\n\
            - Separate each tweet with a line containing only \"---\".\n\
            - Each tweet must be under 280 characters.\n\
            - The first tweet should hook the reader.\n\
            - The last tweet should include a call to action or summary.\n\
            - Do not use hashtags.",
            self.business.product_name, self.business.product_description,
        );

        let user_message = format!("Write a thread about: {topic}");

        let params = GenerationParams {
            max_tokens: 1500,
            temperature: 0.7,
            ..Default::default()
        };

        // Usage accumulates across every attempt; model tracks the latest one.
        let mut usage = TokenUsage::default();
        let provider_name = self.provider.name().to_string();
        let mut model = String::new();

        for attempt in 0..=MAX_THREAD_RETRIES {
            // After the first attempt, restate the formatting constraints.
            let msg = if attempt == 0 {
                user_message.clone()
            } else {
                format!(
                    "{user_message}\n\nIMPORTANT: Write exactly 5-8 tweets, \
                    each under 280 characters, separated by lines containing only \"---\"."
                )
            };

            let resp = self.provider.complete(&system, &msg, &params).await?;
            usage.accumulate(&resp.usage);
            model.clone_from(&resp.model);
            let tweets = parse_thread(&resp.text);

            // Accept only well-formed threads: 5-8 tweets, each within limit.
            if (5..=8).contains(&tweets.len())
                && tweets
                    .iter()
                    .all(|t| validate_tweet_length(t, MAX_TWEET_CHARS))
            {
                return Ok(ThreadGenerationOutput {
                    tweets,
                    usage,
                    model,
                    provider: provider_name,
                });
            }
        }

        Err(LlmError::GenerationFailed(
            "Failed to generate valid thread after retries".to_string(),
        ))
    }

    /// Render the persona parts of the business profile (opinions,
    /// experiences, content pillars) as a prompt section. Returns an empty
    /// string when none are configured; otherwise a leading-newline block
    /// with one line per configured part.
    fn format_persona_context(&self) -> String {
        let mut parts = Vec::new();

        if !self.business.persona_opinions.is_empty() {
            let opinions = self.business.persona_opinions.join("; ");
            parts.push(format!("Opinions you hold: {opinions}"));
        }

        if !self.business.persona_experiences.is_empty() {
            let experiences = self.business.persona_experiences.join("; ");
            parts.push(format!("Experiences you can reference: {experiences}"));
        }

        if !self.business.content_pillars.is_empty() {
            let pillars = self.business.content_pillars.join(", ");
            parts.push(format!("Content pillars: {pillars}"));
        }

        if parts.is_empty() {
            String::new()
        } else {
            format!("\n{}", parts.join("\n"))
        }
    }
}
487
/// Split raw LLM thread output into individual tweets.
///
/// Primary format: tweets separated by "---" delimiter lines, as requested
/// by the thread prompt. If the text contains no usable "---" sections, fall
/// back to detecting numbered tweet prefixes ("1/5 ...", "1. ...", "2) ...")
/// and group the lines between numbers into tweets, stripping the prefix.
///
/// Fix: the numbered-line detector previously only looked at the character
/// at index 1 (`nth(1)`), so multi-digit prefixes such as "10." or "12)"
/// were not recognized even though the prefix-stripping code below already
/// handles them; it now skips all leading digits before checking for
/// '.' or ')'.
fn parse_thread(text: &str) -> Vec<String> {
    // Primary path: "---"-separated sections, trimmed, empties dropped.
    let tweets: Vec<String> = text
        .split("---")
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect();

    // Only trust the split when the delimiter was actually present;
    // otherwise the whole text came back as one "section".
    if !tweets.is_empty() && text.contains("---") {
        return tweets;
    }

    // Fallback: group lines by numbered prefixes.
    let mut tweets = Vec::new();
    let mut current = String::new();

    for line in text.lines() {
        let trimmed = line.trim();
        // Number of leading ASCII digits; 0 means the line cannot be numbered.
        let digits = trimmed
            .chars()
            .take_while(|c| c.is_ascii_digit())
            .count();
        // A numbered line starts with digits and either contains a '/'
        // ("1/5 style") or follows the digits with '.' or ')' ("1."/"10)").
        let is_numbered = digits > 0
            && (trimmed.contains('/')
                || trimmed
                    .chars()
                    .nth(digits)
                    .is_some_and(|c| c == '.' || c == ')'));

        // A new numbered line terminates the tweet being accumulated.
        if is_numbered && !current.is_empty() {
            tweets.push(current.trim().to_string());
            current = String::new();
        }

        if !trimmed.is_empty() {
            if !current.is_empty() {
                current.push(' ');
            }
            if is_numbered {
                // Strip the numeric prefix ("1/5", "2.", "10)") plus
                // any whitespace that follows it.
                let content = trimmed
                    .find(|c: char| !c.is_ascii_digit() && c != '/' && c != '.' && c != ')')
                    .map(|i| trimmed[i..].trim_start())
                    .unwrap_or(trimmed);
                current.push_str(content);
            } else {
                current.push_str(trimmed);
            }
        }
    }

    // Flush the final accumulated tweet, if any.
    if !current.trim().is_empty() {
        tweets.push(current.trim().to_string());
    }

    tweets
}
543
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmResponse, TokenUsage};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Scripted LLM provider for tests: returns canned responses in call
    /// order, repeating the last response once the script is exhausted.
    struct MockProvider {
        responses: Vec<String>,
        call_count: Arc<AtomicUsize>,
    }

    impl MockProvider {
        // Build a provider with an ordered script of responses.
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        // Convenience: a provider that always returns the same response.
        fn single(response: &str) -> Self {
            Self::new(vec![response.to_string()])
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn complete(
            &self,
            _system: &str,
            _user_message: &str,
            _params: &GenerationParams,
        ) -> Result<LlmResponse, LlmError> {
            // Hand out responses in order; past the end, fall back to the
            // last scripted response (or an empty string if none exist).
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let text = self
                .responses
                .get(idx)
                .cloned()
                .unwrap_or_else(|| self.responses.last().cloned().unwrap_or_default());

            Ok(LlmResponse {
                text,
                usage: TokenUsage::default(),
                model: "mock".to_string(),
            })
        }

        async fn health_check(&self) -> Result<(), LlmError> {
            Ok(())
        }
    }

    /// Minimal business profile shared by all generator tests.
    fn test_business() -> BusinessProfile {
        BusinessProfile {
            product_name: "TestApp".to_string(),
            product_description: "A test application".to_string(),
            product_url: Some("https://testapp.com".to_string()),
            target_audience: "developers".to_string(),
            product_keywords: vec!["test".to_string()],
            competitor_keywords: vec![],
            industry_topics: vec!["testing".to_string()],
            brand_voice: None,
            reply_style: None,
            content_style: None,
            persona_opinions: vec![],
            persona_experiences: vec![],
            content_pillars: vec![],
        }
    }

    // --- parse_thread ---

    #[test]
    fn parse_thread_with_dashes() {
        let text = "Tweet one\n---\nTweet two\n---\nTweet three";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 3);
        assert_eq!(tweets[0], "Tweet one");
        assert_eq!(tweets[1], "Tweet two");
        assert_eq!(tweets[2], "Tweet three");
    }

    #[test]
    fn parse_thread_with_extra_whitespace() {
        let text = "  Tweet one  \n---\n  Tweet two  \n---\n";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 2);
        assert_eq!(tweets[0], "Tweet one");
        assert_eq!(tweets[1], "Tweet two");
    }

    #[test]
    fn parse_thread_single_block_falls_back_to_numbered() {
        let text =
            "1/5 First tweet\n2/5 Second tweet\n3/5 Third tweet\n4/5 Fourth tweet\n5/5 Fifth tweet";
        let tweets = parse_thread(text);
        assert!(
            tweets.len() >= 2,
            "got {} tweets: {:?}",
            tweets.len(),
            tweets
        );
    }

    #[test]
    fn parse_thread_empty_sections_filtered() {
        let text = "---\n---\nActual tweet\n---\n---";
        let tweets = parse_thread(text);
        assert_eq!(tweets.len(), 1);
        assert_eq!(tweets[0], "Actual tweet");
    }

    // --- reply generation ---

    #[tokio::test]
    async fn generate_reply_success() {
        let provider =
            MockProvider::single("Great point about testing! I've found similar results.");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("Testing is important", "devuser", true)
            .await
            .expect("reply");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert!(!output.text.is_empty());
        assert_eq!(output.provider, "mock");
    }

    #[tokio::test]
    async fn generate_reply_truncates_long_output() {
        // Both the first attempt and the retry are over-long, forcing the
        // sentence-boundary truncation fallback.
        let long_text = "a ".repeat(200);
        let provider = MockProvider::new(vec![long_text.clone(), long_text]);
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("test", "user", true)
            .await
            .expect("reply");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
    }

    #[tokio::test]
    async fn generate_reply_no_product_mention() {
        let provider = MockProvider::single("That's a great approach for productivity!");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply("How do you stay productive?", "devuser", false)
            .await
            .expect("reply");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert!(!output.text.is_empty());
    }

    // --- tweet generation ---

    #[tokio::test]
    async fn generate_tweet_success() {
        let provider =
            MockProvider::single("Testing your code early saves hours of debugging later.");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_tweet("testing best practices")
            .await
            .expect("tweet");
        assert!(output.text.len() <= MAX_TWEET_CHARS);
        assert!(!output.text.is_empty());
    }

    // --- thread generation ---

    #[tokio::test]
    async fn generate_thread_success() {
        let thread_text = vec![
            "Hook tweet here",
            "---",
            "Second point about testing",
            "---",
            "Third point about quality",
            "---",
            "Fourth point about CI/CD",
            "---",
            "Fifth point about automation",
            "---",
            "Summary and call to action",
        ]
        .join("\n");

        let provider = MockProvider::single(&thread_text);
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen.generate_thread("testing").await.expect("thread");
        assert!(
            (5..=8).contains(&output.tweets.len()),
            "got {} tweets",
            output.tweets.len()
        );
        for tweet in &output.tweets {
            assert!(tweet.len() <= MAX_TWEET_CHARS);
        }
    }

    #[tokio::test]
    async fn generate_thread_retries_on_bad_count() {
        // Two invalid (too short) threads, then a valid 5-tweet thread on
        // the final allowed attempt.
        let bad = "Tweet one\n---\nTweet two";
        let good = "One\n---\nTwo\n---\nThree\n---\nFour\n---\nFive";
        let provider = MockProvider::new(vec![bad.into(), bad.into(), good.into()]);
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen.generate_thread("topic").await.expect("thread");
        assert_eq!(output.tweets.len(), 5);
    }

    #[tokio::test]
    async fn generate_thread_fails_after_max_retries() {
        // Every attempt returns an invalid thread, so generation must fail.
        let bad = "Tweet one\n---\nTweet two";
        let provider = MockProvider::new(vec![bad.into(), bad.into(), bad.into()]);
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let err = gen.generate_thread("topic").await.unwrap_err();
        assert!(matches!(err, LlmError::GenerationFailed(_)));
    }

    // --- RAG context plumbing ---

    #[tokio::test]
    async fn generate_reply_with_context_injects_rag() {
        let provider = MockProvider::single("Great insight about testing patterns!");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let rag_block = "Winning patterns:\n1. [tip] (tweet): \"Great advice\"";
        let output = gen
            .generate_reply_with_context("Test tweet", "user", false, None, Some(rag_block))
            .await
            .expect("reply");

        assert!(!output.text.is_empty());
        assert!(output.text.len() <= MAX_TWEET_CHARS);
    }

    #[tokio::test]
    async fn generate_reply_with_context_none_matches_archetype() {
        // With no archetype and no RAG context this path should behave like
        // the plain generate_reply entry point.
        let provider = MockProvider::single("Agreed, great point!");
        let gen = ContentGenerator::new(Box::new(provider), test_business());

        let output = gen
            .generate_reply_with_context("Test tweet", "user", false, None, None)
            .await
            .expect("reply");
        assert!(!output.text.is_empty());
    }

    // --- params ---

    #[test]
    fn generation_params_default() {
        let params = GenerationParams::default();
        assert_eq!(params.max_tokens, 512);
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert!(params.system_prompt.is_none());
    }
}