// tuitbot_core/workflow/draft.rs
//! Draft step: fetch discovered tweets, generate LLM reply drafts, run safety checks.
//!
//! This is the second step in the reply pipeline: given scored candidates,
//! produce draft reply text for human or automated review.

use std::sync::Arc;

use crate::content::frameworks::ReplyArchetype;
use crate::context::retrieval::VaultCitation;
use crate::context::winning_dna;
use crate::llm::LlmProvider;
use crate::safety::{contains_banned_phrase, DedupChecker};
use crate::storage;
use crate::storage::DbPool;

use super::{make_content_gen, parse_archetype, DraftResult, WorkflowError};

/// Input for the draft step.
#[derive(Debug, Clone)]
pub struct DraftInput {
    /// Tweet IDs to generate drafts for (must exist in discovery DB).
    pub candidate_ids: Vec<String>,
    /// Override archetype for all drafts (e.g., "ask_question").
    /// `None` lets the generator auto-select an archetype per tweet.
    pub archetype: Option<String>,
    /// Whether to mention the product in the reply.
    pub mention_product: bool,
    /// Account ID for scoping RAG context retrieval.
    /// `None` falls back to the default account.
    pub account_id: Option<String>,
}

31/// Execute the draft step: fetch tweets, generate replies, check safety.
32///
33/// Returns one `DraftResult` per candidate. Individual failures don't
34/// abort the batch — they produce `DraftResult::Error` entries.
35pub async fn execute(
36    db: &DbPool,
37    llm: &Arc<dyn LlmProvider>,
38    config: &crate::config::Config,
39    input: DraftInput,
40) -> Result<Vec<DraftResult>, WorkflowError> {
41    if input.candidate_ids.is_empty() {
42        return Err(WorkflowError::InvalidInput(
43            "candidate_ids must not be empty.".to_string(),
44        ));
45    }
46
47    let archetype_override: Option<ReplyArchetype> =
48        input.archetype.as_deref().and_then(parse_archetype);
49
50    let gen = make_content_gen(llm, &config.business);
51    let dedup = DedupChecker::new(db.clone());
52    let banned = &config.limits.banned_phrases;
53
54    // Build RAG context from winning ancestors + content seeds (one DB call, shared)
55    let topic_keywords = config.business.draft_context_keywords();
56
57    let account_id = input
58        .account_id
59        .as_deref()
60        .unwrap_or(crate::storage::accounts::DEFAULT_ACCOUNT_ID);
61
62    let rag_context = winning_dna::build_draft_context(
63        db,
64        account_id,
65        &topic_keywords,
66        winning_dna::MAX_ANCESTORS,
67        winning_dna::RECENCY_HALF_LIFE_DAYS,
68    )
69    .await
70    .ok();
71
72    let vault_citations: Vec<VaultCitation> = rag_context
73        .as_ref()
74        .map(|ctx| ctx.vault_citations.clone())
75        .unwrap_or_default();
76
77    let rag_prompt = rag_context
78        .as_ref()
79        .map(|ctx| ctx.prompt_block.as_str())
80        .filter(|s| !s.is_empty());
81
82    let mut results = Vec::with_capacity(input.candidate_ids.len());
83
84    for candidate_id in &input.candidate_ids {
85        // Fetch tweet from DB
86        let tweet = match storage::tweets::get_tweet_by_id(db, candidate_id).await {
87            Ok(Some(t)) => t,
88            Ok(None) => {
89                results.push(DraftResult::Error {
90                    candidate_id: candidate_id.clone(),
91                    error_code: "not_found".to_string(),
92                    error_message: format!("Tweet {candidate_id} not found in discovery DB."),
93                });
94                continue;
95            }
96            Err(e) => {
97                results.push(DraftResult::Error {
98                    candidate_id: candidate_id.clone(),
99                    error_code: "db_error".to_string(),
100                    error_message: format!("DB error fetching tweet: {e}"),
101                });
102                continue;
103            }
104        };
105
106        // Generate reply via ContentGenerator with optional RAG context
107        let gen_result = gen
108            .generate_reply_with_context(
109                &tweet.content,
110                &tweet.author_username,
111                input.mention_product,
112                archetype_override,
113                rag_prompt,
114            )
115            .await;
116
117        let output = match gen_result {
118            Ok(o) => o,
119            Err(e) => {
120                results.push(DraftResult::Error {
121                    candidate_id: candidate_id.clone(),
122                    error_code: "llm_error".to_string(),
123                    error_message: format!("LLM generation failed: {e}"),
124                });
125                continue;
126            }
127        };
128
129        let draft_text = output.text;
130        let char_count = draft_text.len();
131
132        // Confidence heuristic
133        let confidence = if char_count < 200 {
134            "high"
135        } else if char_count < 260 {
136            "medium"
137        } else {
138            "low"
139        };
140
141        // Risk checks
142        let mut risks = Vec::new();
143        if let Some(phrase) = contains_banned_phrase(&draft_text, banned) {
144            risks.push(format!("contains_banned_phrase: {phrase}"));
145        }
146        if let Ok(true) = dedup.is_phrasing_similar(&draft_text, 20).await {
147            risks.push("similar_to_recent_reply".to_string());
148        }
149
150        let archetype_name = archetype_override
151            .map(|a| format!("{a:?}"))
152            .unwrap_or_else(|| "auto_selected".to_string());
153
154        results.push(DraftResult::Success {
155            candidate_id: candidate_id.clone(),
156            draft_text,
157            archetype: archetype_name,
158            char_count,
159            confidence: confidence.to_string(),
160            risks,
161            vault_citations: vault_citations.clone(),
162        });
163    }
164
165    Ok(results)
166}