// tuitbot_core/workflow/draft.rs

//! Draft step: fetch discovered tweets, generate LLM reply drafts, run safety checks.
//!
//! This is the second step in the reply pipeline: given scored candidates,
//! produce draft reply text for human or automated review.

use std::sync::Arc;

use crate::content::frameworks::ReplyArchetype;
use crate::context::winning_dna;
use crate::llm::LlmProvider;
use crate::safety::{contains_banned_phrase, DedupChecker};
use crate::storage;
use crate::storage::DbPool;

use super::{make_content_gen, parse_archetype, DraftResult, WorkflowError};
/// Input for the draft step.
#[derive(Debug, Clone)]
pub struct DraftInput {
    /// Tweet IDs to generate drafts for (must exist in discovery DB).
    pub candidate_ids: Vec<String>,
    /// Override archetype for all drafts (e.g., "ask_question").
    /// `None` lets the generator auto-select per draft.
    pub archetype: Option<String>,
    /// Whether to mention the product in the reply.
    pub mention_product: bool,
}
27
28/// Execute the draft step: fetch tweets, generate replies, check safety.
29///
30/// Returns one `DraftResult` per candidate. Individual failures don't
31/// abort the batch — they produce `DraftResult::Error` entries.
32pub async fn execute(
33    db: &DbPool,
34    llm: &Arc<dyn LlmProvider>,
35    config: &crate::config::Config,
36    input: DraftInput,
37) -> Result<Vec<DraftResult>, WorkflowError> {
38    if input.candidate_ids.is_empty() {
39        return Err(WorkflowError::InvalidInput(
40            "candidate_ids must not be empty.".to_string(),
41        ));
42    }
43
44    let archetype_override: Option<ReplyArchetype> =
45        input.archetype.as_deref().and_then(parse_archetype);
46
47    let gen = make_content_gen(llm, &config.business);
48    let dedup = DedupChecker::new(db.clone());
49    let banned = &config.limits.banned_phrases;
50
51    // Build RAG context from winning ancestors + content seeds (one DB call, shared)
52    let topic_keywords = config.business.draft_context_keywords();
53
54    let rag_context = winning_dna::build_draft_context(
55        db,
56        &topic_keywords,
57        winning_dna::MAX_ANCESTORS,
58        winning_dna::RECENCY_HALF_LIFE_DAYS,
59    )
60    .await
61    .ok();
62
63    let rag_prompt = rag_context
64        .as_ref()
65        .map(|ctx| ctx.prompt_block.as_str())
66        .filter(|s| !s.is_empty());
67
68    let mut results = Vec::with_capacity(input.candidate_ids.len());
69
70    for candidate_id in &input.candidate_ids {
71        // Fetch tweet from DB
72        let tweet = match storage::tweets::get_tweet_by_id(db, candidate_id).await {
73            Ok(Some(t)) => t,
74            Ok(None) => {
75                results.push(DraftResult::Error {
76                    candidate_id: candidate_id.clone(),
77                    error_code: "not_found".to_string(),
78                    error_message: format!("Tweet {candidate_id} not found in discovery DB."),
79                });
80                continue;
81            }
82            Err(e) => {
83                results.push(DraftResult::Error {
84                    candidate_id: candidate_id.clone(),
85                    error_code: "db_error".to_string(),
86                    error_message: format!("DB error fetching tweet: {e}"),
87                });
88                continue;
89            }
90        };
91
92        // Generate reply via ContentGenerator with optional RAG context
93        let gen_result = gen
94            .generate_reply_with_context(
95                &tweet.content,
96                &tweet.author_username,
97                input.mention_product,
98                archetype_override,
99                rag_prompt,
100            )
101            .await;
102
103        let output = match gen_result {
104            Ok(o) => o,
105            Err(e) => {
106                results.push(DraftResult::Error {
107                    candidate_id: candidate_id.clone(),
108                    error_code: "llm_error".to_string(),
109                    error_message: format!("LLM generation failed: {e}"),
110                });
111                continue;
112            }
113        };
114
115        let draft_text = output.text;
116        let char_count = draft_text.len();
117
118        // Confidence heuristic
119        let confidence = if char_count < 200 {
120            "high"
121        } else if char_count < 260 {
122            "medium"
123        } else {
124            "low"
125        };
126
127        // Risk checks
128        let mut risks = Vec::new();
129        if let Some(phrase) = contains_banned_phrase(&draft_text, banned) {
130            risks.push(format!("contains_banned_phrase: {phrase}"));
131        }
132        if let Ok(true) = dedup.is_phrasing_similar(&draft_text, 20).await {
133            risks.push("similar_to_recent_reply".to_string());
134        }
135
136        let archetype_name = archetype_override
137            .map(|a| format!("{a:?}"))
138            .unwrap_or_else(|| "auto_selected".to_string());
139
140        results.push(DraftResult::Success {
141            candidate_id: candidate_id.clone(),
142            draft_text,
143            archetype: archetype_name,
144            char_count,
145            confidence: confidence.to_string(),
146            risks,
147        });
148    }
149
150    Ok(results)
151}