1pub(crate) mod parser;
7
8#[cfg(test)]
9mod tests;
10
11use crate::config::BusinessProfile;
12use crate::content::frameworks::{ReplyArchetype, ThreadStructure, TweetFormat};
13use crate::content::length::{truncate_at_sentence, validate_tweet_length, MAX_TWEET_CHARS};
14use crate::error::LlmError;
15use crate::llm::{GenerationParams, LlmProvider, TokenUsage};
16
17use parser::parse_thread;
18
/// Output of a single-text generation: a reply, a standalone tweet, or an improved draft.
#[derive(Debug, Clone)]
pub struct GenerationOutput {
    /// The generated text, whitespace-trimmed by the generator.
    pub text: String,
    /// Token usage summed over every provider call made for this generation,
    /// including the length-retry call if one was needed.
    pub usage: TokenUsage,
    /// Model name as reported by the provider's response.
    pub model: String,
    /// Name of the provider that served the request (from `LlmProvider::name`).
    pub provider: String,
}
31
/// Output of a thread generation.
#[derive(Debug, Clone)]
pub struct ThreadGenerationOutput {
    /// The individual tweets of the thread, in the order returned by `parse_thread`.
    pub tweets: Vec<String>,
    /// Token usage summed over every attempt, including rejected ones.
    pub usage: TokenUsage,
    /// Model name reported by the response that produced the accepted thread.
    pub model: String,
    /// Name of the provider that served the request (from `LlmProvider::name`).
    pub provider: String,
}
44
/// Number of extra attempts made after the first one when a generated thread fails
/// validation. The generation loop runs `0..=MAX_THREAD_RETRIES`, so the total number
/// of provider calls is `MAX_THREAD_RETRIES + 1`.
const MAX_THREAD_RETRIES: u32 = 2;
47
/// Builds prompts from a [`BusinessProfile`] and sends them to an LLM provider to
/// produce replies, tweets, threads, and draft rewrites.
pub struct ContentGenerator {
    /// LLM backend used for every completion request.
    provider: Box<dyn LlmProvider>,
    /// Business/persona profile whose fields are interpolated into the system prompts.
    business: BusinessProfile,
}
53
54impl ContentGenerator {
55 pub fn new(provider: Box<dyn LlmProvider>, business: BusinessProfile) -> Self {
57 Self { provider, business }
58 }
59
60 pub fn business(&self) -> &BusinessProfile {
62 &self.business
63 }
64
65 pub async fn generate_reply(
71 &self,
72 tweet_text: &str,
73 tweet_author: &str,
74 mention_product: bool,
75 ) -> Result<GenerationOutput, LlmError> {
76 self.generate_reply_inner(tweet_text, tweet_author, mention_product, None, None)
77 .await
78 }
79
80 pub async fn generate_reply_with_archetype(
82 &self,
83 tweet_text: &str,
84 tweet_author: &str,
85 mention_product: bool,
86 archetype: Option<ReplyArchetype>,
87 ) -> Result<GenerationOutput, LlmError> {
88 self.generate_reply_inner(tweet_text, tweet_author, mention_product, archetype, None)
89 .await
90 }
91
92 pub async fn generate_reply_with_context(
94 &self,
95 tweet_text: &str,
96 tweet_author: &str,
97 mention_product: bool,
98 archetype: Option<ReplyArchetype>,
99 rag_context: Option<&str>,
100 ) -> Result<GenerationOutput, LlmError> {
101 self.generate_reply_inner(
102 tweet_text,
103 tweet_author,
104 mention_product,
105 archetype,
106 rag_context,
107 )
108 .await
109 }
110
111 async fn generate_reply_inner(
113 &self,
114 tweet_text: &str,
115 tweet_author: &str,
116 mention_product: bool,
117 archetype: Option<ReplyArchetype>,
118 rag_context: Option<&str>,
119 ) -> Result<GenerationOutput, LlmError> {
120 tracing::debug!(
121 author = %tweet_author,
122 archetype = ?archetype,
123 mention_product = mention_product,
124 has_rag_context = rag_context.is_some(),
125 "Generating reply",
126 );
127
128 let voice_section = self.format_voice_section();
129 let reply_section = match &self.business.reply_style {
130 Some(s) if !s.is_empty() => format!("\nReply style: {s}"),
131 _ => "\nReply style: Be conversational and helpful, not salesy. Sound like a real person, not a bot.".to_string(),
132 };
133 let archetype_section = match archetype {
134 Some(a) => format!("\n{}", a.prompt_fragment()),
135 None => String::new(),
136 };
137 let persona_section = self.format_persona_context();
138 let rag_section = Self::format_rag_section(rag_context);
139 let audience_section = self.format_audience_section();
140
141 let system = if mention_product {
142 let product_url = self.business.product_url.as_deref().unwrap_or("");
143 format!(
144 "You are a helpful community member who uses {} ({}).\
145 {audience_section}\n\
146 Product URL: {}\
147 {voice_section}\
148 {reply_section}\
149 {archetype_section}\
150 {persona_section}\
151 {rag_section}\n\n\
152 Rules:\n\
153 - Write a reply to the tweet below.\n\
154 - Maximum 3 sentences.\n\
155 - Only mention {} if it is genuinely relevant to the tweet's topic.\n\
156 - Do not use hashtags.\n\
157 - Do not use emojis excessively.",
158 self.business.product_name,
159 self.business.product_description,
160 product_url,
161 self.business.product_name,
162 )
163 } else {
164 format!(
165 "You are a helpful community member.\
166 {audience_section}\
167 {voice_section}\
168 {reply_section}\
169 {archetype_section}\
170 {persona_section}\
171 {rag_section}\n\n\
172 Rules:\n\
173 - Write a reply to the tweet below.\n\
174 - Maximum 3 sentences.\n\
175 - Do NOT mention {} or any product. Just be genuinely helpful.\n\
176 - Do not use hashtags.\n\
177 - Do not use emojis excessively.",
178 self.business.product_name,
179 )
180 };
181
182 let user_message = format!("Tweet by @{tweet_author}: {tweet_text}");
183 let params = GenerationParams {
184 max_tokens: 200,
185 temperature: 0.7,
186 ..Default::default()
187 };
188
189 self.generate_single(&system, &user_message, ¶ms).await
190 }
191
192 pub async fn generate_tweet(&self, topic: &str) -> Result<GenerationOutput, LlmError> {
198 self.generate_tweet_inner(topic, None, None).await
199 }
200
201 pub async fn generate_tweet_with_format(
203 &self,
204 topic: &str,
205 format: Option<TweetFormat>,
206 ) -> Result<GenerationOutput, LlmError> {
207 self.generate_tweet_inner(topic, format, None).await
208 }
209
210 pub async fn generate_tweet_with_context(
212 &self,
213 topic: &str,
214 format: Option<TweetFormat>,
215 rag_context: Option<&str>,
216 ) -> Result<GenerationOutput, LlmError> {
217 self.generate_tweet_inner(topic, format, rag_context).await
218 }
219
220 async fn generate_tweet_inner(
222 &self,
223 topic: &str,
224 format: Option<TweetFormat>,
225 rag_context: Option<&str>,
226 ) -> Result<GenerationOutput, LlmError> {
227 tracing::debug!(
228 topic = %topic,
229 format = ?format,
230 has_rag_context = rag_context.is_some(),
231 "Generating tweet",
232 );
233
234 let voice_section = self.format_voice_section();
235 let content_section = match &self.business.content_style {
236 Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
237 _ => "\nContent style: Be informative and engaging.".to_string(),
238 };
239 let format_section = match format {
240 Some(f) => format!("\n{}", f.prompt_fragment()),
241 None => String::new(),
242 };
243 let persona_section = self.format_persona_context();
244 let rag_section = Self::format_rag_section(rag_context);
245 let audience_section = self.format_audience_section();
246
247 let system = format!(
248 "You are {}'s social media voice. {}.\
249 {audience_section}\
250 {voice_section}\
251 {content_section}\
252 {format_section}\
253 {persona_section}\
254 {rag_section}\n\n\
255 Rules:\n\
256 - Write a single educational tweet about the topic below.\n\
257 - Maximum 280 characters.\n\
258 - Do not use hashtags.\n\
259 - Do not mention {} directly unless it is central to the topic.",
260 self.business.product_name,
261 self.business.product_description,
262 self.business.product_name,
263 );
264
265 let user_message = format!("Write a tweet about: {topic}");
266 let params = GenerationParams {
267 max_tokens: 150,
268 temperature: 0.8,
269 ..Default::default()
270 };
271
272 self.generate_single(&system, &user_message, ¶ms).await
273 }
274
275 pub async fn improve_draft(
281 &self,
282 draft: &str,
283 tone_cue: Option<&str>,
284 ) -> Result<GenerationOutput, LlmError> {
285 self.improve_draft_inner(draft, tone_cue, None).await
286 }
287
288 pub async fn improve_draft_with_context(
291 &self,
292 draft: &str,
293 tone_cue: Option<&str>,
294 rag_context: Option<&str>,
295 ) -> Result<GenerationOutput, LlmError> {
296 self.improve_draft_inner(draft, tone_cue, rag_context).await
297 }
298
299 async fn improve_draft_inner(
301 &self,
302 draft: &str,
303 tone_cue: Option<&str>,
304 rag_context: Option<&str>,
305 ) -> Result<GenerationOutput, LlmError> {
306 tracing::debug!(
307 draft_len = draft.len(),
308 tone_cue = ?tone_cue,
309 has_rag_context = rag_context.is_some(),
310 "Improving draft",
311 );
312
313 let voice_section = self.format_voice_section();
314 let persona_section = self.format_persona_context();
315 let rag_section = Self::format_rag_section(rag_context);
316
317 let tone_instruction = match tone_cue {
318 Some(cue) if !cue.is_empty() => {
319 format!("\n\nTone/style directive (MUST follow): {cue}")
320 }
321 _ => String::new(),
322 };
323
324 let system = format!(
325 "You are {}'s social media voice. {}.\
326 {voice_section}\
327 {persona_section}\
328 {rag_section}\n\n\
329 Task: Rewrite and improve the draft tweet below. \
330 Keep the core message but make it sharper, more engaging, \
331 and better-written.{tone_instruction}\n\n\
332 Rules:\n\
333 - Maximum 280 characters.\n\
334 - Do not use hashtags.\n\
335 - Output only the improved tweet text, nothing else.",
336 self.business.product_name, self.business.product_description,
337 );
338
339 let user_message = format!("Draft to improve:\n{draft}");
340 let params = GenerationParams {
341 max_tokens: 150,
342 temperature: 0.7,
343 ..Default::default()
344 };
345
346 self.generate_single(&system, &user_message, ¶ms).await
347 }
348
349 pub async fn generate_thread(&self, topic: &str) -> Result<ThreadGenerationOutput, LlmError> {
355 self.generate_thread_inner(topic, None, None).await
356 }
357
358 pub async fn generate_thread_with_structure(
360 &self,
361 topic: &str,
362 structure: Option<ThreadStructure>,
363 ) -> Result<ThreadGenerationOutput, LlmError> {
364 self.generate_thread_inner(topic, structure, None).await
365 }
366
367 pub async fn generate_thread_with_context(
369 &self,
370 topic: &str,
371 structure: Option<ThreadStructure>,
372 rag_context: Option<&str>,
373 ) -> Result<ThreadGenerationOutput, LlmError> {
374 self.generate_thread_inner(topic, structure, rag_context)
375 .await
376 }
377
378 async fn generate_thread_inner(
380 &self,
381 topic: &str,
382 structure: Option<ThreadStructure>,
383 rag_context: Option<&str>,
384 ) -> Result<ThreadGenerationOutput, LlmError> {
385 tracing::debug!(
386 topic = %topic,
387 structure = ?structure,
388 has_rag_context = rag_context.is_some(),
389 "Generating thread",
390 );
391
392 let voice_section = self.format_voice_section();
393 let content_section = match &self.business.content_style {
394 Some(s) if !s.is_empty() => format!("\nContent style: {s}"),
395 _ => "\nContent style: Be informative, not promotional.".to_string(),
396 };
397 let structure_section = match structure {
398 Some(s) => format!("\n{}", s.prompt_fragment()),
399 None => String::new(),
400 };
401 let persona_section = self.format_persona_context();
402 let rag_section = Self::format_rag_section(rag_context);
403 let audience_section = self.format_audience_section();
404
405 let system = format!(
406 "You are {}'s social media voice. {}.\
407 {audience_section}\
408 {voice_section}\
409 {content_section}\
410 {structure_section}\
411 {persona_section}\
412 {rag_section}\n\n\
413 Rules:\n\
414 - Write an educational thread of 5 to 8 tweets about the topic below.\n\
415 - Separate each tweet with a line containing only \"---\".\n\
416 - Each tweet must be under 280 characters.\n\
417 - The first tweet should hook the reader.\n\
418 - The last tweet should include a call to action or summary.\n\
419 - Do not use hashtags.",
420 self.business.product_name, self.business.product_description,
421 );
422
423 let user_message = format!("Write a thread about: {topic}");
424 let params = GenerationParams {
425 max_tokens: 1500,
426 temperature: 0.7,
427 ..Default::default()
428 };
429
430 let mut usage = TokenUsage::default();
431 let provider_name = self.provider.name().to_string();
432 let mut model = String::new();
433
434 for attempt in 0..=MAX_THREAD_RETRIES {
435 let msg = if attempt == 0 {
436 user_message.clone()
437 } else {
438 format!(
439 "{user_message}\n\nIMPORTANT: Write exactly 5-8 tweets, \
440 each under 280 characters, separated by lines containing only \"---\"."
441 )
442 };
443
444 let resp = self.provider.complete(&system, &msg, ¶ms).await?;
445 usage.accumulate(&resp.usage);
446 model.clone_from(&resp.model);
447 let tweets = parse_thread(&resp.text);
448
449 if (5..=8).contains(&tweets.len())
450 && tweets
451 .iter()
452 .all(|t| validate_tweet_length(t, MAX_TWEET_CHARS))
453 {
454 return Ok(ThreadGenerationOutput {
455 tweets,
456 usage,
457 model,
458 provider: provider_name,
459 });
460 }
461 }
462
463 Err(LlmError::GenerationFailed(
464 "Failed to generate valid thread after retries".to_string(),
465 ))
466 }
467
468 async fn generate_single(
474 &self,
475 system: &str,
476 user_message: &str,
477 params: &GenerationParams,
478 ) -> Result<GenerationOutput, LlmError> {
479 let resp = self.provider.complete(system, user_message, params).await?;
480 let mut usage = resp.usage.clone();
481 let provider_name = self.provider.name().to_string();
482 let model = resp.model.clone();
483 let text = resp.text.trim().to_string();
484
485 tracing::debug!(chars = text.len(), "Generated content");
486
487 if validate_tweet_length(&text, MAX_TWEET_CHARS) {
488 return Ok(GenerationOutput {
489 text,
490 usage,
491 model,
492 provider: provider_name,
493 });
494 }
495
496 let retry_msg = format!(
498 "{user_message}\n\nImportant: Your response MUST be under 280 characters. Be more concise."
499 );
500 let resp = self.provider.complete(system, &retry_msg, params).await?;
501 usage.accumulate(&resp.usage);
502 let text = resp.text.trim().to_string();
503
504 if validate_tweet_length(&text, MAX_TWEET_CHARS) {
505 return Ok(GenerationOutput {
506 text,
507 usage,
508 model,
509 provider: provider_name,
510 });
511 }
512
513 Ok(GenerationOutput {
515 text: truncate_at_sentence(&text, MAX_TWEET_CHARS),
516 usage,
517 model,
518 provider: provider_name,
519 })
520 }
521
522 fn format_voice_section(&self) -> String {
523 match &self.business.brand_voice {
524 Some(v) if !v.is_empty() => format!("\nVoice & personality: {v}"),
525 _ => String::new(),
526 }
527 }
528
529 fn format_audience_section(&self) -> String {
530 if self.business.target_audience.is_empty() {
531 String::new()
532 } else {
533 format!("\nYour audience: {}.", self.business.target_audience)
534 }
535 }
536
537 fn format_rag_section(rag_context: Option<&str>) -> String {
538 match rag_context {
539 Some(ctx) if !ctx.is_empty() => format!("\n{ctx}"),
540 _ => String::new(),
541 }
542 }
543
544 fn format_persona_context(&self) -> String {
546 let mut parts = Vec::new();
547
548 if !self.business.persona_opinions.is_empty() {
549 let opinions = self.business.persona_opinions.join("; ");
550 parts.push(format!("Opinions you hold: {opinions}"));
551 }
552
553 if !self.business.persona_experiences.is_empty() {
554 let experiences = self.business.persona_experiences.join("; ");
555 parts.push(format!("Experiences you can reference: {experiences}"));
556 }
557
558 if !self.business.content_pillars.is_empty() {
559 let pillars = self.business.content_pillars.join(", ");
560 parts.push(format!("Content pillars: {pillars}"));
561 }
562
563 if parts.is_empty() {
564 String::new()
565 } else {
566 format!("\n{}", parts.join("\n"))
567 }
568 }
569}