1pub(crate) mod angles;
7pub(crate) mod parser;
8
9#[cfg(test)]
10mod tests;
11
12use crate::config::BusinessProfile;
13use crate::content::frameworks::{ReplyArchetype, ThreadStructure, TweetFormat};
14use crate::content::length::{truncate_at_sentence, validate_tweet_length, MAX_TWEET_CHARS};
15use crate::error::LlmError;
16use crate::llm::{GenerationParams, LlmProvider, TokenUsage};
17
18use parser::{parse_hooks_response, parse_thread};
19
/// Result of a single-text generation call (reply, tweet, or improved
/// draft), together with provider accounting metadata.
#[derive(Debug, Clone)]
pub struct GenerationOutput {
    /// The generated text, trimmed (and truncated to fit if needed).
    pub text: String,
    /// Token usage accumulated across all LLM calls made (including retries).
    pub usage: TokenUsage,
    /// Model identifier reported by the provider.
    pub model: String,
    /// Name of the LLM provider that produced the text.
    pub provider: String,
}
32
/// A single candidate hook tweet produced by hook generation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HookOption {
    /// Hook style label as parsed from the LLM response.
    pub style: String,
    /// The hook tweet text.
    pub text: String,
    /// Reported length of `text`, as computed by `build_hook_options`.
    pub char_count: usize,
    /// "high" when the hook is comfortably under the length limit,
    /// otherwise "medium".
    pub confidence: String,
}
45
/// Result of `generate_hooks`: parsed hook candidates plus provider
/// accounting metadata.
#[derive(Debug, Clone)]
pub struct HookGenerationOutput {
    /// Parsed hook candidates (at most 5).
    pub hooks: Vec<HookOption>,
    /// Token usage accumulated across all LLM calls (including any retry).
    pub usage: TokenUsage,
    /// Model identifier reported by the provider.
    pub model: String,
    /// Name of the LLM provider used.
    pub provider: String,
}
58
/// Result of thread generation: the tweets in posting order plus
/// provider accounting metadata.
#[derive(Debug, Clone)]
pub struct ThreadGenerationOutput {
    /// Tweets in posting order; when an opening hook was supplied it is
    /// included as the first element.
    pub tweets: Vec<String>,
    /// Token usage accumulated across all attempts.
    pub usage: TokenUsage,
    /// Model identifier reported by the provider (from the last attempt).
    pub model: String,
    /// Name of the LLM provider used.
    pub provider: String,
}
71
/// Additional attempts after the first thread-generation call
/// (so up to `MAX_THREAD_RETRIES + 1` LLM calls in total).
const MAX_THREAD_RETRIES: u32 = 2;
74
/// High-level content generator: wraps an [`LlmProvider`] and a
/// [`BusinessProfile`], and builds prompts for replies, tweets, hooks,
/// highlights, and threads in the business's voice.
pub struct ContentGenerator {
    /// Backend used for all completions.
    provider: Box<dyn LlmProvider>,
    /// Business/persona configuration injected into every prompt.
    business: BusinessProfile,
}
80
81impl ContentGenerator {
    /// Create a generator that uses `provider` for completions and
    /// `business` for prompt/persona configuration.
    pub fn new(provider: Box<dyn LlmProvider>, business: BusinessProfile) -> Self {
        Self { provider, business }
    }
86
    /// Borrow the business profile this generator was configured with.
    pub fn business(&self) -> &BusinessProfile {
        &self.business
    }
91
    /// Generate a reply to a tweet with default settings (no archetype,
    /// no RAG context). See `generate_reply_inner` for the full flow.
    pub async fn generate_reply(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_reply_inner(tweet_text, tweet_author, mention_product, None, None)
            .await
    }
106
    /// Generate a reply steered by an optional [`ReplyArchetype`]
    /// (no RAG context). Delegates to `generate_reply_inner`.
    pub async fn generate_reply_with_archetype(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
        archetype: Option<ReplyArchetype>,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_reply_inner(tweet_text, tweet_author, mention_product, archetype, None)
            .await
    }
118
    /// Generate a reply with both an optional archetype and optional
    /// retrieved (RAG) context. Delegates to `generate_reply_inner`.
    pub async fn generate_reply_with_context(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
        archetype: Option<ReplyArchetype>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_reply_inner(
            tweet_text,
            tweet_author,
            mention_product,
            archetype,
            rag_context,
        )
        .await
    }
137
    /// Core reply generation: assembles the system prompt from the
    /// business profile (audience, voice, reply style, optional archetype
    /// fragment, persona, optional RAG context) and asks the provider for
    /// a short reply to the given tweet.
    ///
    /// When `mention_product` is true the prompt permits mentioning the
    /// product (with its URL) if relevant; otherwise the prompt forbids
    /// any product mention. Length validation, one retry, and final
    /// truncation all happen in `generate_single`.
    async fn generate_reply_inner(
        &self,
        tweet_text: &str,
        tweet_author: &str,
        mention_product: bool,
        archetype: Option<ReplyArchetype>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        tracing::debug!(
            author = %tweet_author,
            archetype = ?archetype,
            mention_product = mention_product,
            has_rag_context = rag_context.is_some(),
            "Generating reply",
        );

        let voice_section = self.format_voice_section();
        // Fall back to a generic conversational style when none is configured.
        let reply_section = match &self.business.reply_style {
            Some(s) if !s.is_empty() => format!("\nReply style: {s}"),
            _ => "\nReply style: Be conversational and helpful, not salesy. Sound like a real person, not a bot.".to_string(),
        };
        let archetype_section = match archetype {
            Some(a) => format!("\n{}", a.prompt_fragment()),
            None => String::new(),
        };
        let persona_section = self.format_persona_context();
        let rag_section = Self::format_rag_section(rag_context);
        let audience_section = self.format_audience_section();

        // Two prompt variants: product-mention allowed vs. forbidden.
        let system = if mention_product {
            let product_url = self.business.product_url.as_deref().unwrap_or("");
            format!(
                "You are a helpful community member who uses {} ({}).\
                {audience_section}\n\
                Product URL: {}\
                {voice_section}\
                {reply_section}\
                {archetype_section}\
                {persona_section}\
                {rag_section}\n\n\
                Rules:\n\
                - Write a reply to the tweet below.\n\
                - Maximum 3 sentences.\n\
                - Only mention {} if it is genuinely relevant to the tweet's topic.\n\
                - Do not use hashtags.\n\
                - Do not use emojis excessively.",
                self.business.product_name,
                self.business.product_description,
                product_url,
                self.business.product_name,
            )
        } else {
            format!(
                "You are a helpful community member.\
                {audience_section}\
                {voice_section}\
                {reply_section}\
                {archetype_section}\
                {persona_section}\
                {rag_section}\n\n\
                Rules:\n\
                - Write a reply to the tweet below.\n\
                - Maximum 3 sentences.\n\
                - Do NOT mention {} or any product. Just be genuinely helpful.\n\
                - Do not use hashtags.\n\
                - Do not use emojis excessively.",
                self.business.product_name,
            )
        };

        let user_message = format!("Tweet by @{tweet_author}: {tweet_text}");
        let params = GenerationParams {
            max_tokens: 200,
            temperature: 0.7,
            ..Default::default()
        };

        self.generate_single(&system, &user_message, &params).await
    }
218
    /// Generate a standalone tweet about `topic` with default settings
    /// (no format fragment, no RAG context). See `generate_tweet_inner`.
    pub async fn generate_tweet(&self, topic: &str) -> Result<GenerationOutput, LlmError> {
        self.generate_tweet_inner(topic, None, None).await
    }
227
    /// Generate a tweet steered by an optional [`TweetFormat`]
    /// (no RAG context). Delegates to `generate_tweet_inner`.
    pub async fn generate_tweet_with_format(
        &self,
        topic: &str,
        format: Option<TweetFormat>,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_tweet_inner(topic, format, None).await
    }
236
    /// Generate a tweet with both an optional format and optional
    /// retrieved (RAG) context. Delegates to `generate_tweet_inner`.
    pub async fn generate_tweet_with_context(
        &self,
        topic: &str,
        format: Option<TweetFormat>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        self.generate_tweet_inner(topic, format, rag_context).await
    }
246
    /// Core tweet generation: assembles the system prompt (audience,
    /// voice, content style, optional format fragment, persona, optional
    /// RAG context) and requests a single educational tweet about `topic`.
    /// Length validation, one retry, and final truncation happen in
    /// `generate_single`.
    async fn generate_tweet_inner(
        &self,
        topic: &str,
        format: Option<TweetFormat>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        tracing::debug!(
            topic = %topic,
            format = ?format,
            has_rag_context = rag_context.is_some(),
            "Generating tweet",
        );

        let voice_section = self.format_voice_section();
        // Fall back to a generic content style when none is configured.
        let content_section = match &self.business.content_style {
            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
            _ => "\nContent style: Be informative and engaging.".to_string(),
        };
        let format_section = match format {
            Some(f) => format!("\n{}", f.prompt_fragment()),
            None => String::new(),
        };
        let persona_section = self.format_persona_context();
        let rag_section = Self::format_rag_section(rag_context);
        let audience_section = self.format_audience_section();

        let system = format!(
            "You are {}'s social media voice. {}.\
            {audience_section}\
            {voice_section}\
            {content_section}\
            {format_section}\
            {persona_section}\
            {rag_section}\n\n\
            Rules:\n\
            - Write a single educational tweet about the topic below.\n\
            - Maximum 280 characters.\n\
            - Do not use hashtags.\n\
            - Do not mention {} directly unless it is central to the topic.",
            self.business.product_name,
            self.business.product_description,
            self.business.product_name,
        );

        let user_message = format!("Write a tweet about: {topic}");
        // Higher temperature than replies (0.8 vs 0.7) for more varied output.
        let params = GenerationParams {
            max_tokens: 150,
            temperature: 0.8,
            ..Default::default()
        };

        self.generate_single(&system, &user_message, &params).await
    }
301
    /// Rewrite and sharpen a draft tweet, optionally steered by a tone
    /// cue (no RAG context). Delegates to `improve_draft_inner`.
    pub async fn improve_draft(
        &self,
        draft: &str,
        tone_cue: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        self.improve_draft_inner(draft, tone_cue, None).await
    }
314
    /// Rewrite a draft tweet with both an optional tone cue and optional
    /// retrieved (RAG) context. Delegates to `improve_draft_inner`.
    pub async fn improve_draft_with_context(
        &self,
        draft: &str,
        tone_cue: Option<&str>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        self.improve_draft_inner(draft, tone_cue, rag_context).await
    }
325
    /// Core draft-improvement flow: rewrites `draft` in the business
    /// voice, keeping the core message, optionally steered by a tone
    /// directive and RAG context. Length validation, one retry, and
    /// final truncation happen in `generate_single`.
    async fn improve_draft_inner(
        &self,
        draft: &str,
        tone_cue: Option<&str>,
        rag_context: Option<&str>,
    ) -> Result<GenerationOutput, LlmError> {
        tracing::debug!(
            draft_len = draft.len(),
            tone_cue = ?tone_cue,
            has_rag_context = rag_context.is_some(),
            "Improving draft",
        );

        let voice_section = self.format_voice_section();
        let persona_section = self.format_persona_context();
        let rag_section = Self::format_rag_section(rag_context);

        // An empty tone cue is treated the same as no cue at all.
        let tone_instruction = match tone_cue {
            Some(cue) if !cue.is_empty() => {
                format!("\n\nTone/style directive (MUST follow): {cue}")
            }
            _ => String::new(),
        };

        let system = format!(
            "You are {}'s social media voice. {}.\
            {voice_section}\
            {persona_section}\
            {rag_section}\n\n\
            Task: Rewrite and improve the draft tweet below. \
            Keep the core message but make it sharper, more engaging, \
            and better-written.{tone_instruction}\n\n\
            Rules:\n\
            - Maximum 280 characters.\n\
            - Do not use hashtags.\n\
            - Output only the improved tweet text, nothing else.",
            self.business.product_name, self.business.product_description,
        );

        let user_message = format!("Draft to improve:\n{draft}");
        let params = GenerationParams {
            max_tokens: 150,
            temperature: 0.7,
            ..Default::default()
        };

        self.generate_single(&system, &user_message, &params).await
    }
375
    /// Generate up to five hook tweet candidates for `topic`.
    ///
    /// Styles are chosen by `select_hook_styles` (two fixed + three
    /// random). The model is asked for a strict `STYLE:`/`HOOK:` output
    /// format; if fewer than 3 hooks parse out of the first response, a
    /// single retry is made with a more insistent instruction. Token
    /// usage is accumulated across both calls; `model` is taken from the
    /// first response.
    ///
    /// # Errors
    /// Returns [`LlmError::GenerationFailed`] when no valid hooks could
    /// be parsed, and propagates provider errors.
    pub async fn generate_hooks(
        &self,
        topic: &str,
        rag_context: Option<&str>,
    ) -> Result<HookGenerationOutput, LlmError> {
        tracing::debug!(
            topic = %topic,
            has_rag_context = rag_context.is_some(),
            "Generating hooks",
        );

        // Number the styles 1..5 for the prompt's "Required styles" list.
        let styles = Self::select_hook_styles();
        let style_list = styles
            .iter()
            .enumerate()
            .map(|(i, f)| format!("{}. {}", i + 1, f))
            .collect::<Vec<_>>()
            .join("\n");

        let voice_section = self.format_voice_section();
        let persona_section = self.format_persona_context();
        let rag_section = Self::format_rag_section(rag_context);
        let audience_section = self.format_audience_section();

        let system = format!(
            "You are {}'s social media voice. {}.\
            {audience_section}\
            {voice_section}\
            {persona_section}\
            {rag_section}\n\n\
            Task: Generate exactly 5 hook tweets for the topic below, \
            one per style listed. Each hook must be a standalone tweet \
            (max 280 characters) that grabs attention.\n\n\
            Required styles (one hook per style):\n{style_list}\n\n\
            Output format (strictly follow this, no extra text):\n\
            STYLE: <style_name>\n\
            HOOK: <hook text>\n\
            ---\n\
            (repeat for all 5)",
            self.business.product_name, self.business.product_description,
        );

        let user_message = format!("Generate hooks about: {topic}");
        // High temperature: hooks benefit from variety across the 5 styles.
        let params = GenerationParams {
            max_tokens: 800,
            temperature: 0.9,
            ..Default::default()
        };

        let mut usage = TokenUsage::default();
        let provider_name = self.provider.name().to_string();

        let resp = self
            .provider
            .complete(&system, &user_message, &params)
            .await?;
        usage.accumulate(&resp.usage);
        let model = resp.model.clone();

        tracing::debug!(
            raw_response = %resp.text,
            "Raw LLM response for hook generation"
        );

        let mut hooks = Self::build_hook_options(&parse_hooks_response(&resp.text));

        // Retry once when the response was mostly unparseable.
        if hooks.len() < 3 {
            tracing::debug!(count = hooks.len(), "Too few hooks, retrying");
            let retry_msg = format!(
                "{user_message}\n\nIMPORTANT: Output exactly 5 hooks, \
                each with STYLE: and HOOK: lines, separated by ---."
            );
            let resp = self.provider.complete(&system, &retry_msg, &params).await?;
            usage.accumulate(&resp.usage);

            tracing::debug!(
                raw_response = %resp.text,
                "Raw LLM retry response for hook generation"
            );

            // The retry result replaces (not extends) the first batch.
            hooks = Self::build_hook_options(&parse_hooks_response(&resp.text));
        }

        if hooks.is_empty() {
            return Err(LlmError::GenerationFailed(
                "No valid hooks could be generated".to_string(),
            ));
        }

        // The model may over-produce; cap at the 5 requested hooks.
        hooks.truncate(5);

        Ok(HookGenerationOutput {
            hooks,
            usage,
            model,
            provider: provider_name,
        })
    }
481
482 fn select_hook_styles() -> Vec<TweetFormat> {
485 use rand::seq::SliceRandom;
486
487 let mut styles = vec![TweetFormat::Question, TweetFormat::ContrarianTake];
488 let remaining = [
489 TweetFormat::List,
490 TweetFormat::MostPeopleThinkX,
491 TweetFormat::Storytelling,
492 TweetFormat::BeforeAfter,
493 TweetFormat::Tip,
494 ];
495 let mut rng = rand::rng();
496 let mut pool = remaining.to_vec();
497 pool.shuffle(&mut rng);
498 styles.extend(pool.into_iter().take(3));
499 styles
500 }
501
502 fn build_hook_options(parsed: &[(String, String)]) -> Vec<HookOption> {
505 parsed
506 .iter()
507 .filter(|(_, text)| !text.is_empty() && text.len() <= MAX_TWEET_CHARS)
508 .map(|(style, text)| {
509 let char_count = text.len();
510 let confidence = if char_count <= 240 {
511 "high".to_string()
512 } else {
513 "medium".to_string()
514 };
515 HookOption {
516 style: style.clone(),
517 text: text.clone(),
518 char_count,
519 confidence,
520 }
521 })
522 .collect()
523 }
524
    /// Mine content angles for `topic` from semantically similar neighbor
    /// content. Thin delegation to [`angles::generate_mined_angles`],
    /// passing this generator's provider and business profile.
    pub async fn generate_mined_angles(
        &self,
        topic: &str,
        neighbors: &[crate::content::evidence::NeighborContent],
        selection_context: Option<&str>,
    ) -> Result<crate::content::angles::AngleMiningOutput, LlmError> {
        angles::generate_mined_angles(
            &*self.provider,
            &self.business,
            topic,
            neighbors,
            selection_context,
        )
        .await
    }
545
    /// Extract 3–5 concise, tweetable key insights from `rag_context`.
    ///
    /// The model is asked for a dash-prefixed bullet list; each response
    /// line is stripped of its bullet prefix via `strip_bullet_prefix`
    /// and empty lines are dropped.
    ///
    /// # Errors
    /// Returns [`LlmError::GenerationFailed`] when parsing yields no
    /// highlights, and propagates provider errors.
    pub async fn extract_highlights(&self, rag_context: &str) -> Result<Vec<String>, LlmError> {
        tracing::debug!(context_len = rag_context.len(), "Extracting key highlights",);

        let system = format!(
            "You are {}'s content strategist. {}.\n\n\
            Task: Read the context below and extract 3 to 5 concise, \
            tweetable key insights as bullet points.\n\n\
            Rules:\n\
            - Each bullet should be a single clear insight or idea.\n\
            - Keep each bullet under 200 characters.\n\
            - Output only the bullet list, one per line.\n\
            - Use a dash (-) prefix for each bullet.\n\
            - No numbering, no sub-bullets, no headers.",
            self.business.product_name, self.business.product_description,
        );

        let user_message = format!("Context:\n{rag_context}");
        // Lower temperature: extraction should stay close to the source text.
        let params = GenerationParams {
            max_tokens: 500,
            temperature: 0.5,
            ..Default::default()
        };

        let resp = self
            .provider
            .complete(&system, &user_message, &params)
            .await?;

        tracing::debug!(
            raw_response = %resp.text,
            "Raw LLM response for highlight extraction"
        );

        let highlights: Vec<String> = resp
            .text
            .lines()
            .map(|line| strip_bullet_prefix(line.trim()))
            .filter(|s| !s.is_empty())
            .collect();

        if highlights.is_empty() {
            tracing::warn!(
                raw_response = %resp.text,
                "Highlight extraction produced no results after parsing"
            );
            return Err(LlmError::GenerationFailed(
                "No highlights could be extracted from the provided context".to_string(),
            ));
        }

        Ok(highlights)
    }
607
    /// Generate a thread about `topic` with default settings (no
    /// structure, no RAG context, no pre-written hook).
    /// See `generate_thread_inner` for the full flow.
    pub async fn generate_thread(&self, topic: &str) -> Result<ThreadGenerationOutput, LlmError> {
        self.generate_thread_inner(topic, None, None, None).await
    }
616
    /// Generate a thread steered by an optional [`ThreadStructure`]
    /// (no RAG context, no pre-written hook).
    pub async fn generate_thread_with_structure(
        &self,
        topic: &str,
        structure: Option<ThreadStructure>,
    ) -> Result<ThreadGenerationOutput, LlmError> {
        self.generate_thread_inner(topic, structure, None, None)
            .await
    }
626
    /// Generate a thread with an optional structure and optional
    /// retrieved (RAG) context (no pre-written hook).
    pub async fn generate_thread_with_context(
        &self,
        topic: &str,
        structure: Option<ThreadStructure>,
        rag_context: Option<&str>,
    ) -> Result<ThreadGenerationOutput, LlmError> {
        self.generate_thread_inner(topic, structure, rag_context, None)
            .await
    }
637
    /// Generate a thread whose first tweet is the already-written
    /// `opening_hook`; the model writes only the continuation tweets,
    /// and the hook is prepended to the returned list.
    pub async fn generate_thread_with_hook(
        &self,
        topic: &str,
        opening_hook: &str,
        structure: Option<ThreadStructure>,
        rag_context: Option<&str>,
    ) -> Result<ThreadGenerationOutput, LlmError> {
        self.generate_thread_inner(topic, structure, rag_context, Some(opening_hook))
            .await
    }
652
    /// Core thread generation with validation and retries.
    ///
    /// Builds the system prompt (audience, voice, content style, optional
    /// structure fragment, persona, optional RAG context), then attempts
    /// up to `MAX_THREAD_RETRIES + 1` completions. An attempt succeeds
    /// when the number of model-generated tweets falls in the expected
    /// range (4–7 with a pre-written hook, 5–8 without) and every tweet
    /// passes `validate_tweet_length`. Token usage accumulates across
    /// attempts; `model` reflects the last attempt.
    ///
    /// # Errors
    /// Returns [`LlmError::GenerationFailed`] when no attempt produces a
    /// valid thread, and propagates provider errors.
    async fn generate_thread_inner(
        &self,
        topic: &str,
        structure: Option<ThreadStructure>,
        rag_context: Option<&str>,
        opening_hook: Option<&str>,
    ) -> Result<ThreadGenerationOutput, LlmError> {
        tracing::debug!(
            topic = %topic,
            structure = ?structure,
            has_rag_context = rag_context.is_some(),
            has_opening_hook = opening_hook.is_some(),
            "Generating thread",
        );

        let voice_section = self.format_voice_section();
        // Fall back to a generic content style when none is configured.
        let content_section = match &self.business.content_style {
            Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
            _ => "\nContent style: Be informative, not promotional.".to_string(),
        };
        let structure_section = match structure {
            Some(s) => format!("\n{}", s.prompt_fragment()),
            None => String::new(),
        };
        let persona_section = self.format_persona_context();
        let rag_section = Self::format_rag_section(rag_context);
        let audience_section = self.format_audience_section();

        // With a pre-written hook the model writes only the continuation,
        // so the prompt rule and expected tweet count both change.
        let (hook_rule, tweet_count_rule) = match opening_hook {
            Some(hook) => (
                format!(
                    "\n- The first tweet of the thread is ALREADY WRITTEN. \
                    Do NOT include it in your output.\n\
                    - Here is the first tweet (for context only): \"{hook}\"\n\
                    - Write 4 to 7 ADDITIONAL tweets that continue from that opening."
                ),
                "4 to 7",
            ),
            None => (
                "\n- The first tweet should hook the reader.".to_string(),
                "5 to 8",
            ),
        };

        let system = format!(
            "You are {}'s social media voice. {}.\
            {audience_section}\
            {voice_section}\
            {content_section}\
            {structure_section}\
            {persona_section}\
            {rag_section}\n\n\
            Rules:\n\
            - Write an educational thread of {tweet_count_rule} tweets about the topic below.\n\
            - Separate each tweet with a line containing only \"---\".\n\
            - Each tweet must be under 280 characters.{hook_rule}\n\
            - The last tweet should include a call to action or summary.\n\
            - Do not use hashtags.",
            self.business.product_name, self.business.product_description,
        );

        let user_message = format!("Write a thread about: {topic}");
        let params = GenerationParams {
            max_tokens: 1500,
            temperature: 0.7,
            ..Default::default()
        };

        let mut usage = TokenUsage::default();
        let provider_name = self.provider.name().to_string();
        let mut model = String::new();

        // Expected count of MODEL-generated tweets (hook excluded).
        let (min_gen, max_gen) = if opening_hook.is_some() {
            (4, 7)
        } else {
            (5, 8)
        };

        for attempt in 0..=MAX_THREAD_RETRIES {
            // Retries carry an explicit reminder of the required shape.
            let msg = if attempt == 0 {
                user_message.clone()
            } else {
                format!(
                    "{user_message}\n\nIMPORTANT: Write exactly {tweet_count_rule} tweets, \
                    each under 280 characters, separated by lines containing only \"---\"."
                )
            };

            let resp = self.provider.complete(&system, &msg, &params).await?;
            usage.accumulate(&resp.usage);
            model.clone_from(&resp.model);
            let mut tweets = parse_thread(&resp.text);

            // Prepend the pre-written hook so callers get the full thread.
            if let Some(hook) = opening_hook {
                tweets.insert(0, hook.to_string());
            }

            // Validate the generated portion only (hook is trusted as-is
            // for the count, but still length-checked below).
            let gen_count = tweets.len() - if opening_hook.is_some() { 1 } else { 0 };
            if (min_gen..=max_gen).contains(&gen_count)
                && tweets
                    .iter()
                    .all(|t| validate_tweet_length(t, MAX_TWEET_CHARS))
            {
                return Ok(ThreadGenerationOutput {
                    tweets,
                    usage,
                    model,
                    provider: provider_name,
                });
            }
        }

        Err(LlmError::GenerationFailed(
            "Failed to generate valid thread after retries".to_string(),
        ))
    }
772
773 async fn generate_single(
779 &self,
780 system: &str,
781 user_message: &str,
782 params: &GenerationParams,
783 ) -> Result<GenerationOutput, LlmError> {
784 let resp = self.provider.complete(system, user_message, params).await?;
785 let mut usage = resp.usage.clone();
786 let provider_name = self.provider.name().to_string();
787 let model = resp.model.clone();
788 let text = resp.text.trim().to_string();
789
790 tracing::debug!(chars = text.len(), "Generated content");
791
792 if validate_tweet_length(&text, MAX_TWEET_CHARS) {
793 return Ok(GenerationOutput {
794 text,
795 usage,
796 model,
797 provider: provider_name,
798 });
799 }
800
801 let retry_msg = format!(
803 "{user_message}\n\nImportant: Your response MUST be under 280 characters. Be more concise."
804 );
805 let resp = self.provider.complete(system, &retry_msg, params).await?;
806 usage.accumulate(&resp.usage);
807 let text = resp.text.trim().to_string();
808
809 if validate_tweet_length(&text, MAX_TWEET_CHARS) {
810 return Ok(GenerationOutput {
811 text,
812 usage,
813 model,
814 provider: provider_name,
815 });
816 }
817
818 Ok(GenerationOutput {
820 text: truncate_at_sentence(&text, MAX_TWEET_CHARS),
821 usage,
822 model,
823 provider: provider_name,
824 })
825 }
826
827 fn format_voice_section(&self) -> String {
828 match &self.business.brand_voice {
829 Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
830 _ => String::new(),
831 }
832 }
833
834 fn format_audience_section(&self) -> String {
835 if self.business.target_audience.is_empty() {
836 String::new()
837 } else {
838 format!("\nYour audience: {}.", self.business.target_audience)
839 }
840 }
841
842 fn format_rag_section(rag_context: Option<&str>) -> String {
843 match rag_context {
844 Some(ctx) if !ctx.is_empty() => format!("\n{ctx}"),
845 _ => String::new(),
846 }
847 }
848
849 fn format_persona_context(&self) -> String {
851 let mut parts = Vec::new();
852
853 if !self.business.persona_opinions.is_empty() {
854 let opinions = self.business.persona_opinions.join("; ");
855 parts.push(format!("Opinions you hold: {opinions}"));
856 }
857
858 if !self.business.persona_experiences.is_empty() {
859 let experiences = self.business.persona_experiences.join("; ");
860 parts.push(format!("Experiences you can reference: {experiences}"));
861 }
862
863 if !self.business.content_pillars.is_empty() {
864 let pillars = self.business.content_pillars.join(", ");
865 parts.push(format!("Content pillars: {pillars}"));
866 }
867
868 if parts.is_empty() {
869 String::new()
870 } else {
871 format!("\n{}", parts.join("\n"))
872 }
873 }
874}
875
/// Strip list markers (dashes, asterisks, bullets, numbering like
/// "1." or "(2)") from the start of `line`, returning the bare text.
///
/// Leading digits are removed only when a list delimiter follows them
/// ("1.", "2)", "(3)" …); previously digits were stripped
/// unconditionally, so a highlight that genuinely starts with a number
/// (e.g. "3 tips for onboarding") lost its leading number.
fn strip_bullet_prefix(line: &str) -> String {
    const MARKERS: [char; 7] = ['.', ')', ':', '-', '*', '•', '—'];

    // Drop an optional opening parenthesis plus surrounding whitespace.
    let s = line.trim_start_matches(|c: char| c == '(' || c.is_ascii_whitespace());

    // Treat leading digits as list numbering only when a marker follows.
    let after_digits = s.trim_start_matches(|c: char| c.is_ascii_digit());
    let s = if after_digits.len() < s.len() && after_digits.starts_with(MARKERS) {
        after_digits
    } else {
        s
    };

    s.trim_start_matches(MARKERS).trim().to_string()
}