1use crate::config::BusinessProfile;
7use crate::content::frameworks::{ReplyArchetype, ThreadStructure, TweetFormat};
8use crate::error::LlmError;
9use crate::llm::{GenerationParams, LlmProvider};
10
/// Maximum characters allowed in a single generated tweet or reply.
const MAX_TWEET_CHARS: usize = 280;

/// Extra attempts (beyond the first) allowed when thread generation produces
/// an invalid tweet count or over-long tweets; see `generate_thread_with_structure`.
const MAX_THREAD_RETRIES: u32 = 2;
16
/// Generates tweets, replies, and threads through an LLM provider,
/// grounded in a configured business profile.
pub struct ContentGenerator {
    // LLM backend used for all completions; boxed trait object so any provider works.
    provider: Box<dyn LlmProvider>,
    // Product, audience, and voice settings that seed every prompt.
    business: BusinessProfile,
}
22
23impl ContentGenerator {
24 pub fn new(provider: Box<dyn LlmProvider>, business: BusinessProfile) -> Self {
26 Self { provider, business }
27 }
28
29 pub async fn generate_reply(
39 &self,
40 tweet_text: &str,
41 tweet_author: &str,
42 mention_product: bool,
43 ) -> Result<String, LlmError> {
44 self.generate_reply_with_archetype(tweet_text, tweet_author, mention_product, None)
45 .await
46 }
47
48 pub async fn generate_reply_with_archetype(
50 &self,
51 tweet_text: &str,
52 tweet_author: &str,
53 mention_product: bool,
54 archetype: Option<ReplyArchetype>,
55 ) -> Result<String, LlmError> {
56 tracing::debug!(
57 author = %tweet_author,
58 archetype = ?archetype,
59 mention_product = mention_product,
60 "Generating reply",
61 );
62 let voice_section = match &self.business.brand_voice {
63 Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
64 _ => String::new(),
65 };
66 let reply_section = match &self.business.reply_style {
67 Some(s) if !s.is_empty() => format!("\nReply style: {s}"),
68 _ => "\nReply style: Be conversational and helpful, not salesy. Sound like a real person, not a bot.".to_string(),
69 };
70
71 let archetype_section = match archetype {
72 Some(a) => format!("\n{}", a.prompt_fragment()),
73 None => String::new(),
74 };
75
76 let persona_section = self.format_persona_context();
77
78 let product_rule = if mention_product {
79 let product_url = self.business.product_url.as_deref().unwrap_or("");
80 format!(
81 "You are a helpful community member who uses {} ({}).\n\
82 Your target audience is: {}.\n\
83 Product URL: {}\
84 {voice_section}\
85 {reply_section}\
86 {archetype_section}\
87 {persona_section}\n\n\
88 Rules:\n\
89 - Write a reply to the tweet below.\n\
90 - Maximum 3 sentences.\n\
91 - Only mention {} if it is genuinely relevant to the tweet's topic.\n\
92 - Do not use hashtags.\n\
93 - Do not use emojis excessively.",
94 self.business.product_name,
95 self.business.product_description,
96 self.business.target_audience,
97 product_url,
98 self.business.product_name,
99 )
100 } else {
101 format!(
102 "You are a helpful community member.\n\
103 Your target audience is: {}.\
104 {voice_section}\
105 {reply_section}\
106 {archetype_section}\
107 {persona_section}\n\n\
108 Rules:\n\
109 - Write a reply to the tweet below.\n\
110 - Maximum 3 sentences.\n\
111 - Do NOT mention {} or any product. Just be genuinely helpful.\n\
112 - Do not use hashtags.\n\
113 - Do not use emojis excessively.",
114 self.business.target_audience, self.business.product_name,
115 )
116 };
117
118 let system = product_rule;
119 let user_message = format!("Tweet by @{tweet_author}: {tweet_text}");
120
121 let params = GenerationParams {
122 max_tokens: 200,
123 temperature: 0.7,
124 ..Default::default()
125 };
126
127 let resp = self
128 .provider
129 .complete(&system, &user_message, ¶ms)
130 .await?;
131 let text = resp.text.trim().to_string();
132
133 tracing::debug!(chars = text.len(), "Generated reply");
134
135 if validate_length(&text, MAX_TWEET_CHARS) {
136 return Ok(text);
137 }
138
139 let retry_msg = format!(
141 "{user_message}\n\nImportant: Your reply MUST be under 280 characters. Be more concise."
142 );
143 let resp = self.provider.complete(&system, &retry_msg, ¶ms).await?;
144 let text = resp.text.trim().to_string();
145
146 if validate_length(&text, MAX_TWEET_CHARS) {
147 return Ok(text);
148 }
149
150 Ok(truncate_at_sentence(&text, MAX_TWEET_CHARS))
152 }
153
154 pub async fn generate_tweet(&self, topic: &str) -> Result<String, LlmError> {
158 self.generate_tweet_with_format(topic, None).await
159 }
160
161 pub async fn generate_tweet_with_format(
163 &self,
164 topic: &str,
165 format: Option<TweetFormat>,
166 ) -> Result<String, LlmError> {
167 tracing::debug!(
168 topic = %topic,
169 format = ?format,
170 "Generating tweet",
171 );
172 let voice_section = match &self.business.brand_voice {
173 Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
174 _ => String::new(),
175 };
176 let content_section = match &self.business.content_style {
177 Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
178 _ => "\nContent style: Be informative and engaging.".to_string(),
179 };
180
181 let format_section = match format {
182 Some(f) => format!("\n{}", f.prompt_fragment()),
183 None => String::new(),
184 };
185
186 let persona_section = self.format_persona_context();
187
188 let system = format!(
189 "You are {}'s social media voice. {}.\n\
190 Your audience: {}.\
191 {voice_section}\
192 {content_section}\
193 {format_section}\
194 {persona_section}\n\n\
195 Rules:\n\
196 - Write a single educational tweet about the topic below.\n\
197 - Maximum 280 characters.\n\
198 - Do not use hashtags.\n\
199 - Do not mention {} directly unless it is central to the topic.",
200 self.business.product_name,
201 self.business.product_description,
202 self.business.target_audience,
203 self.business.product_name,
204 );
205
206 let user_message = format!("Write a tweet about: {topic}");
207
208 let params = GenerationParams {
209 max_tokens: 150,
210 temperature: 0.8,
211 ..Default::default()
212 };
213
214 let resp = self
215 .provider
216 .complete(&system, &user_message, ¶ms)
217 .await?;
218 let text = resp.text.trim().to_string();
219
220 if validate_length(&text, MAX_TWEET_CHARS) {
221 return Ok(text);
222 }
223
224 let retry_msg = format!(
226 "{user_message}\n\nImportant: Your tweet MUST be under 280 characters. Be more concise."
227 );
228 let resp = self.provider.complete(&system, &retry_msg, ¶ms).await?;
229 let text = resp.text.trim().to_string();
230
231 if validate_length(&text, MAX_TWEET_CHARS) {
232 return Ok(text);
233 }
234
235 Ok(truncate_at_sentence(&text, MAX_TWEET_CHARS))
236 }
237
238 pub async fn generate_thread(&self, topic: &str) -> Result<Vec<String>, LlmError> {
243 self.generate_thread_with_structure(topic, None).await
244 }
245
246 pub async fn generate_thread_with_structure(
248 &self,
249 topic: &str,
250 structure: Option<ThreadStructure>,
251 ) -> Result<Vec<String>, LlmError> {
252 tracing::debug!(
253 topic = %topic,
254 structure = ?structure,
255 "Generating thread",
256 );
257 let voice_section = match &self.business.brand_voice {
258 Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
259 _ => String::new(),
260 };
261 let content_section = match &self.business.content_style {
262 Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
263 _ => "\nContent style: Be informative, not promotional.".to_string(),
264 };
265
266 let structure_section = match structure {
267 Some(s) => format!("\n{}", s.prompt_fragment()),
268 None => String::new(),
269 };
270
271 let persona_section = self.format_persona_context();
272
273 let system = format!(
274 "You are {}'s social media voice. {}.\n\
275 Your audience: {}.\
276 {voice_section}\
277 {content_section}\
278 {structure_section}\
279 {persona_section}\n\n\
280 Rules:\n\
281 - Write an educational thread of 5 to 8 tweets about the topic below.\n\
282 - Separate each tweet with a line containing only \"---\".\n\
283 - Each tweet must be under 280 characters.\n\
284 - The first tweet should hook the reader.\n\
285 - The last tweet should include a call to action or summary.\n\
286 - Do not use hashtags.",
287 self.business.product_name,
288 self.business.product_description,
289 self.business.target_audience,
290 );
291
292 let user_message = format!("Write a thread about: {topic}");
293
294 let params = GenerationParams {
295 max_tokens: 1500,
296 temperature: 0.7,
297 ..Default::default()
298 };
299
300 for attempt in 0..=MAX_THREAD_RETRIES {
301 let msg = if attempt == 0 {
302 user_message.clone()
303 } else {
304 format!(
305 "{user_message}\n\nIMPORTANT: Write exactly 5-8 tweets, \
306 each under 280 characters, separated by lines containing only \"---\"."
307 )
308 };
309
310 let resp = self.provider.complete(&system, &msg, ¶ms).await?;
311 let tweets = parse_thread(&resp.text);
312
313 if (5..=8).contains(&tweets.len())
314 && tweets.iter().all(|t| validate_length(t, MAX_TWEET_CHARS))
315 {
316 return Ok(tweets);
317 }
318 }
319
320 Err(LlmError::GenerationFailed(
321 "Failed to generate valid thread after retries".to_string(),
322 ))
323 }
324
325 fn format_persona_context(&self) -> String {
327 let mut parts = Vec::new();
328
329 if !self.business.persona_opinions.is_empty() {
330 let opinions = self.business.persona_opinions.join("; ");
331 parts.push(format!("Opinions you hold: {opinions}"));
332 }
333
334 if !self.business.persona_experiences.is_empty() {
335 let experiences = self.business.persona_experiences.join("; ");
336 parts.push(format!("Experiences you can reference: {experiences}"));
337 }
338
339 if !self.business.content_pillars.is_empty() {
340 let pillars = self.business.content_pillars.join(", ");
341 parts.push(format!("Content pillars: {pillars}"));
342 }
343
344 if parts.is_empty() {
345 String::new()
346 } else {
347 format!("\n{}", parts.join("\n"))
348 }
349 }
350}
351
/// Splits raw LLM output into individual tweets.
///
/// Primary format: tweets separated by "---" delimiter lines, as requested by
/// the thread prompt. When no "---" appears, falls back to splitting on
/// numbered prefixes such as "1/5", "1." or "1)" at the start of a line,
/// stripping the prefix and joining continuation lines with spaces.
fn parse_thread(text: &str) -> Vec<String> {
    // Preferred path: explicit "---" separators.
    let delimited: Vec<String> = text
        .split("---")
        .map(str::trim)
        .filter(|piece| !piece.is_empty())
        .map(str::to_string)
        .collect();

    if text.contains("---") && !delimited.is_empty() {
        return delimited;
    }

    // Fallback: accumulate lines, starting a new tweet whenever a line opens
    // with a numbered prefix.
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();

    for raw in text.lines() {
        let line = raw.trim();
        let starts_with_digit = line.chars().next().is_some_and(|c| c.is_ascii_digit());
        // Matches "N/M ...", "N. ..." and "N) ..." style prefixes.
        let numbered = starts_with_digit
            && (line.contains('/') || matches!(line.chars().nth(1), Some('.') | Some(')')));

        if numbered && !buf.is_empty() {
            out.push(buf.trim().to_string());
            buf.clear();
        }

        if line.is_empty() {
            continue;
        }
        if !buf.is_empty() {
            buf.push(' ');
        }
        if numbered {
            // Drop the leading counter ("1/5", "2.", ...) and keep the rest.
            let rest = match line
                .find(|c: char| !c.is_ascii_digit() && !matches!(c, '/' | '.' | ')'))
            {
                Some(i) => line[i..].trim_start(),
                None => line,
            };
            buf.push_str(rest);
        } else {
            buf.push_str(line);
        }
    }

    if !buf.trim().is_empty() {
        out.push(buf.trim().to_string());
    }

    out
}
407
/// Returns true when `text` fits within `max_chars` characters.
///
/// Counts Unicode scalar values rather than UTF-8 bytes: the tweet limit is
/// a character budget, and byte length over-counts any non-ASCII output
/// (e.g. "café" is 4 chars but 5 bytes), causing needless retries/truncation.
fn validate_length(text: &str, max_chars: usize) -> bool {
    text.chars().count() <= max_chars
}
412
/// Truncates `text` to at most `max_chars` characters, preferring to cut at
/// the last sentence boundary ('.', '!' or '?'); otherwise cuts at the last
/// word boundary before `max_chars - 3` and appends "...".
///
/// All cuts are made on UTF-8 char boundaries. The previous byte-indexed
/// version (`&text[..max_chars]`) panicked whenever the byte offset landed
/// inside a multi-byte character.
fn truncate_at_sentence(text: &str, max_chars: usize) -> String {
    if text.chars().count() <= max_chars {
        return text.to_string();
    }

    // Byte offset of the first character past the limit — always a boundary.
    let limit = text
        .char_indices()
        .nth(max_chars)
        .map_or(text.len(), |(i, _)| i);
    let search_area = &text[..limit];

    // Prefer ending on a complete sentence.
    let last_sentence_end = search_area
        .rfind('.')
        .max(search_area.rfind('!'))
        .max(search_area.rfind('?'));

    if let Some(pos) = last_sentence_end {
        if pos > 0 {
            // '.', '!' and '?' are ASCII, so pos + 1 is a char boundary.
            return text[..=pos].trim().to_string();
        }
    }

    // No sentence boundary: reserve room for "..." and cut at the last space.
    let keep = max_chars.saturating_sub(3);
    let cut = search_area
        .char_indices()
        .nth(keep)
        .map_or(search_area.len(), |(i, _)| i);
    let word_end = search_area[..cut].rfind(' ').unwrap_or(cut);
    format!("{}...", &text[..word_end])
}
442
443#[cfg(test)]
444mod tests {
445 use super::*;
446 use crate::llm::{LlmResponse, TokenUsage};
447 use std::sync::atomic::{AtomicUsize, Ordering};
448 use std::sync::Arc;
449
450 struct MockProvider {
452 responses: Vec<String>,
453 call_count: Arc<AtomicUsize>,
454 }
455
456 impl MockProvider {
457 fn new(responses: Vec<String>) -> Self {
458 Self {
459 responses,
460 call_count: Arc::new(AtomicUsize::new(0)),
461 }
462 }
463
464 fn single(response: &str) -> Self {
465 Self::new(vec![response.to_string()])
466 }
467 }
468
469 #[async_trait::async_trait]
470 impl LlmProvider for MockProvider {
471 fn name(&self) -> &str {
472 "mock"
473 }
474
475 async fn complete(
476 &self,
477 _system: &str,
478 _user_message: &str,
479 _params: &GenerationParams,
480 ) -> Result<LlmResponse, LlmError> {
481 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
482 let text = self
483 .responses
484 .get(idx)
485 .cloned()
486 .unwrap_or_else(|| self.responses.last().cloned().unwrap_or_default());
487
488 Ok(LlmResponse {
489 text,
490 usage: TokenUsage::default(),
491 model: "mock".to_string(),
492 })
493 }
494
495 async fn health_check(&self) -> Result<(), LlmError> {
496 Ok(())
497 }
498 }
499
500 fn test_business() -> BusinessProfile {
501 BusinessProfile {
502 product_name: "TestApp".to_string(),
503 product_description: "A test application".to_string(),
504 product_url: Some("https://testapp.com".to_string()),
505 target_audience: "developers".to_string(),
506 product_keywords: vec!["test".to_string()],
507 competitor_keywords: vec![],
508 industry_topics: vec!["testing".to_string()],
509 brand_voice: None,
510 reply_style: None,
511 content_style: None,
512 persona_opinions: vec![],
513 persona_experiences: vec![],
514 content_pillars: vec![],
515 }
516 }
517
518 #[test]
521 fn validate_length_under_limit() {
522 assert!(validate_length("short text", 280));
523 }
524
525 #[test]
526 fn validate_length_at_limit() {
527 let text = "a".repeat(280);
528 assert!(validate_length(&text, 280));
529 }
530
531 #[test]
532 fn validate_length_over_limit() {
533 let text = "a".repeat(281);
534 assert!(!validate_length(&text, 280));
535 }
536
537 #[test]
540 fn truncate_at_sentence_under_limit() {
541 let text = "Short sentence.";
542 assert_eq!(truncate_at_sentence(text, 280), "Short sentence.");
543 }
544
545 #[test]
546 fn truncate_at_period() {
547 let text = "First sentence. Second sentence. Third sentence is very long and goes over the limit and more and more text.";
548 let result = truncate_at_sentence(text, 50);
549 assert!(result.len() <= 50);
550 assert!(result.ends_with('.'));
551 }
552
553 #[test]
554 fn truncate_at_question_mark() {
555 let text = "Is this working? I hope so because this text is getting very long and will exceed the character limit.";
556 let result = truncate_at_sentence(text, 20);
557 assert!(result.len() <= 20);
558 assert!(result.ends_with('?'));
559 }
560
561 #[test]
562 fn truncate_no_sentence_boundary() {
563 let text =
564 "This is a very long sentence without any punctuation that keeps going and going";
565 let result = truncate_at_sentence(text, 30);
566 assert!(result.len() <= 30);
567 assert!(result.ends_with("..."));
568 }
569
570 #[test]
573 fn parse_thread_with_dashes() {
574 let text = "Tweet one\n---\nTweet two\n---\nTweet three";
575 let tweets = parse_thread(text);
576 assert_eq!(tweets.len(), 3);
577 assert_eq!(tweets[0], "Tweet one");
578 assert_eq!(tweets[1], "Tweet two");
579 assert_eq!(tweets[2], "Tweet three");
580 }
581
582 #[test]
583 fn parse_thread_with_extra_whitespace() {
584 let text = " Tweet one \n---\n Tweet two \n---\n";
585 let tweets = parse_thread(text);
586 assert_eq!(tweets.len(), 2);
587 assert_eq!(tweets[0], "Tweet one");
588 assert_eq!(tweets[1], "Tweet two");
589 }
590
591 #[test]
592 fn parse_thread_single_block_falls_back_to_numbered() {
593 let text =
594 "1/5 First tweet\n2/5 Second tweet\n3/5 Third tweet\n4/5 Fourth tweet\n5/5 Fifth tweet";
595 let tweets = parse_thread(text);
596 assert!(
597 tweets.len() >= 2,
598 "got {} tweets: {:?}",
599 tweets.len(),
600 tweets
601 );
602 }
603
604 #[test]
605 fn parse_thread_empty_sections_filtered() {
606 let text = "---\n---\nActual tweet\n---\n---";
607 let tweets = parse_thread(text);
608 assert_eq!(tweets.len(), 1);
609 assert_eq!(tweets[0], "Actual tweet");
610 }
611
612 #[tokio::test]
615 async fn generate_reply_success() {
616 let provider =
617 MockProvider::single("Great point about testing! I've found similar results.");
618 let gen = ContentGenerator::new(Box::new(provider), test_business());
619
620 let reply = gen
621 .generate_reply("Testing is important", "devuser", true)
622 .await
623 .expect("reply");
624 assert!(reply.len() <= MAX_TWEET_CHARS);
625 assert!(!reply.is_empty());
626 }
627
628 #[tokio::test]
629 async fn generate_reply_truncates_long_output() {
630 let long_text = "a ".repeat(200); let provider = MockProvider::new(vec![long_text.clone(), long_text]);
632 let gen = ContentGenerator::new(Box::new(provider), test_business());
633
634 let reply = gen
635 .generate_reply("test", "user", true)
636 .await
637 .expect("reply");
638 assert!(reply.len() <= MAX_TWEET_CHARS);
639 }
640
641 #[tokio::test]
642 async fn generate_reply_no_product_mention() {
643 let provider = MockProvider::single("That's a great approach for productivity!");
644 let gen = ContentGenerator::new(Box::new(provider), test_business());
645
646 let reply = gen
647 .generate_reply("How do you stay productive?", "devuser", false)
648 .await
649 .expect("reply");
650 assert!(reply.len() <= MAX_TWEET_CHARS);
651 assert!(!reply.is_empty());
652 }
653
654 #[tokio::test]
657 async fn generate_tweet_success() {
658 let provider =
659 MockProvider::single("Testing your code early saves hours of debugging later.");
660 let gen = ContentGenerator::new(Box::new(provider), test_business());
661
662 let tweet = gen
663 .generate_tweet("testing best practices")
664 .await
665 .expect("tweet");
666 assert!(tweet.len() <= MAX_TWEET_CHARS);
667 assert!(!tweet.is_empty());
668 }
669
670 #[tokio::test]
673 async fn generate_thread_success() {
674 let thread_text = vec![
675 "Hook tweet here",
676 "---",
677 "Second point about testing",
678 "---",
679 "Third point about quality",
680 "---",
681 "Fourth point about CI/CD",
682 "---",
683 "Fifth point about automation",
684 "---",
685 "Summary and call to action",
686 ]
687 .join("\n");
688
689 let provider = MockProvider::single(&thread_text);
690 let gen = ContentGenerator::new(Box::new(provider), test_business());
691
692 let thread = gen.generate_thread("testing").await.expect("thread");
693 assert!(
694 (5..=8).contains(&thread.len()),
695 "got {} tweets",
696 thread.len()
697 );
698 for tweet in &thread {
699 assert!(tweet.len() <= MAX_TWEET_CHARS);
700 }
701 }
702
703 #[tokio::test]
704 async fn generate_thread_retries_on_bad_count() {
705 let bad = "Tweet one\n---\nTweet two";
707 let good = "One\n---\nTwo\n---\nThree\n---\nFour\n---\nFive";
708 let provider = MockProvider::new(vec![bad.into(), bad.into(), good.into()]);
709 let gen = ContentGenerator::new(Box::new(provider), test_business());
710
711 let thread = gen.generate_thread("topic").await.expect("thread");
712 assert_eq!(thread.len(), 5);
713 }
714
715 #[tokio::test]
716 async fn generate_thread_fails_after_max_retries() {
717 let bad = "Tweet one\n---\nTweet two";
718 let provider = MockProvider::new(vec![bad.into(), bad.into(), bad.into()]);
719 let gen = ContentGenerator::new(Box::new(provider), test_business());
720
721 let err = gen.generate_thread("topic").await.unwrap_err();
722 assert!(matches!(err, LlmError::GenerationFailed(_)));
723 }
724
725 #[test]
728 fn generation_params_default() {
729 let params = GenerationParams::default();
730 assert_eq!(params.max_tokens, 512);
731 assert!((params.temperature - 0.7).abs() < f32::EPSILON);
732 assert!(params.system_prompt.is_none());
733 }
734}