// tuitbot_core/workflow/draft.rs
1//! Draft step: fetch discovered tweets, generate LLM reply drafts, run safety checks.
2//!
3//! This is the second step in the reply pipeline: given scored candidates,
4//! produce draft reply text for human or automated review.
5
6use std::sync::Arc;
7
8use crate::content::frameworks::ReplyArchetype;
9use crate::context::winning_dna;
10use crate::llm::LlmProvider;
11use crate::safety::{contains_banned_phrase, DedupChecker};
12use crate::storage;
13use crate::storage::DbPool;
14
15use super::{make_content_gen, parse_archetype, DraftResult, WorkflowError};
16
/// Input for the draft step.
#[derive(Debug, Clone)]
pub struct DraftInput {
    /// Tweet IDs to generate drafts for (must exist in discovery DB).
    /// An empty list is rejected by `execute` with `WorkflowError::InvalidInput`.
    pub candidate_ids: Vec<String>,
    /// Override archetype for all drafts (e.g., "ask_question").
    /// Parsed via `parse_archetype`; `None` or an unrecognized value falls
    /// back to per-draft auto-selection.
    pub archetype: Option<String>,
    /// Whether to mention the product in the reply.
    pub mention_product: bool,
}
27
28/// Execute the draft step: fetch tweets, generate replies, check safety.
29///
30/// Returns one `DraftResult` per candidate. Individual failures don't
31/// abort the batch — they produce `DraftResult::Error` entries.
32pub async fn execute(
33    db: &DbPool,
34    llm: &Arc<dyn LlmProvider>,
35    config: &crate::config::Config,
36    input: DraftInput,
37) -> Result<Vec<DraftResult>, WorkflowError> {
38    if input.candidate_ids.is_empty() {
39        return Err(WorkflowError::InvalidInput(
40            "candidate_ids must not be empty.".to_string(),
41        ));
42    }
43
44    let archetype_override: Option<ReplyArchetype> =
45        input.archetype.as_deref().and_then(parse_archetype);
46
47    let gen = make_content_gen(llm, &config.business);
48    let dedup = DedupChecker::new(db.clone());
49    let banned = &config.limits.banned_phrases;
50
51    // Build RAG context from winning ancestors + content seeds (one DB call, shared)
52    let mut topic_keywords: Vec<String> = config.business.product_keywords.clone();
53    topic_keywords.extend(config.business.competitor_keywords.clone());
54    topic_keywords.extend(config.business.effective_industry_topics().to_vec());
55
56    let rag_context = winning_dna::build_draft_context(
57        db,
58        &topic_keywords,
59        winning_dna::MAX_ANCESTORS,
60        winning_dna::RECENCY_HALF_LIFE_DAYS,
61    )
62    .await
63    .ok();
64
65    let rag_prompt = rag_context
66        .as_ref()
67        .map(|ctx| ctx.prompt_block.as_str())
68        .filter(|s| !s.is_empty());
69
70    let mut results = Vec::with_capacity(input.candidate_ids.len());
71
72    for candidate_id in &input.candidate_ids {
73        // Fetch tweet from DB
74        let tweet = match storage::tweets::get_tweet_by_id(db, candidate_id).await {
75            Ok(Some(t)) => t,
76            Ok(None) => {
77                results.push(DraftResult::Error {
78                    candidate_id: candidate_id.clone(),
79                    error_code: "not_found".to_string(),
80                    error_message: format!("Tweet {candidate_id} not found in discovery DB."),
81                });
82                continue;
83            }
84            Err(e) => {
85                results.push(DraftResult::Error {
86                    candidate_id: candidate_id.clone(),
87                    error_code: "db_error".to_string(),
88                    error_message: format!("DB error fetching tweet: {e}"),
89                });
90                continue;
91            }
92        };
93
94        // Generate reply via ContentGenerator with optional RAG context
95        let gen_result = gen
96            .generate_reply_with_context(
97                &tweet.content,
98                &tweet.author_username,
99                input.mention_product,
100                archetype_override,
101                rag_prompt,
102            )
103            .await;
104
105        let output = match gen_result {
106            Ok(o) => o,
107            Err(e) => {
108                results.push(DraftResult::Error {
109                    candidate_id: candidate_id.clone(),
110                    error_code: "llm_error".to_string(),
111                    error_message: format!("LLM generation failed: {e}"),
112                });
113                continue;
114            }
115        };
116
117        let draft_text = output.text;
118        let char_count = draft_text.len();
119
120        // Confidence heuristic
121        let confidence = if char_count < 200 {
122            "high"
123        } else if char_count < 260 {
124            "medium"
125        } else {
126            "low"
127        };
128
129        // Risk checks
130        let mut risks = Vec::new();
131        if let Some(phrase) = contains_banned_phrase(&draft_text, banned) {
132            risks.push(format!("contains_banned_phrase: {phrase}"));
133        }
134        if let Ok(true) = dedup.is_phrasing_similar(&draft_text, 20).await {
135            risks.push("similar_to_recent_reply".to_string());
136        }
137
138        let archetype_name = archetype_override
139            .map(|a| format!("{a:?}"))
140            .unwrap_or_else(|| "auto_selected".to_string());
141
142        results.push(DraftResult::Success {
143            candidate_id: candidate_id.clone(),
144            draft_text,
145            archetype: archetype_name,
146            char_count,
147            confidence: confidence.to_string(),
148            risks,
149        });
150    }
151
152    Ok(results)
153}