Skip to main content

tuitbot_core/automation/adapters/
llm.rs

1//! LLM adapter implementations.
2
3use std::sync::Arc;
4
5use super::super::loop_helpers::{
6    ContentLoopError, LoopError, ReplyGenerator, ReplyOutput, TweetGenerator,
7};
8use super::super::thread_loop::ThreadGenerator;
9use super::helpers::{llm_to_content_error, llm_to_loop_error};
10use crate::content::ContentGenerator;
11use crate::storage::DbPool;
12
13/// Record LLM usage to the database (fire-and-forget).
14pub(super) async fn record_llm_usage(
15    pool: &DbPool,
16    generation_type: &str,
17    provider: &str,
18    model: &str,
19    input_tokens: u32,
20    output_tokens: u32,
21) {
22    let pricing = crate::llm::pricing::lookup(provider, model);
23    let cost = pricing.compute_cost(input_tokens, output_tokens);
24    if let Err(e) = crate::storage::llm_usage::insert_llm_usage(
25        pool,
26        generation_type,
27        provider,
28        model,
29        input_tokens,
30        output_tokens,
31        cost,
32    )
33    .await
34    {
35        tracing::warn!(error = %e, "Failed to record LLM usage");
36    }
37}
38
39/// Adapts `ContentGenerator` to the `ReplyGenerator` port trait.
40pub struct LlmReplyAdapter {
41    generator: Arc<ContentGenerator>,
42    pool: DbPool,
43}
44
45impl LlmReplyAdapter {
46    pub fn new(generator: Arc<ContentGenerator>, pool: DbPool) -> Self {
47        Self { generator, pool }
48    }
49}
50
51#[async_trait::async_trait]
52impl ReplyGenerator for LlmReplyAdapter {
53    async fn generate_reply(
54        &self,
55        tweet_text: &str,
56        author: &str,
57        mention_product: bool,
58    ) -> Result<String, LoopError> {
59        let output = self
60            .generator
61            .generate_reply(tweet_text, author, mention_product)
62            .await
63            .map_err(llm_to_loop_error)?;
64        record_llm_usage(
65            &self.pool,
66            "reply",
67            &output.provider,
68            &output.model,
69            output.usage.input_tokens,
70            output.usage.output_tokens,
71        )
72        .await;
73        Ok(output.text)
74    }
75}
76
77/// Vault-aware reply adapter that injects pre-built RAG context into replies.
78///
79/// The RAG prompt is built once at construction time (by the server/CLI wiring
80/// layer) and reused for every reply, avoiding per-tweet DB queries.
81pub struct VaultAwareLlmReplyAdapter {
82    generator: Arc<ContentGenerator>,
83    pool: DbPool,
84    /// Pre-built RAG prompt block to inject into every reply.
85    rag_prompt: Option<String>,
86    /// Pre-built vault citations corresponding to the RAG prompt.
87    vault_citations: Vec<crate::context::retrieval::VaultCitation>,
88}
89
90impl VaultAwareLlmReplyAdapter {
91    pub fn new(
92        generator: Arc<ContentGenerator>,
93        pool: DbPool,
94        rag_prompt: Option<String>,
95        vault_citations: Vec<crate::context::retrieval::VaultCitation>,
96    ) -> Self {
97        Self {
98            generator,
99            pool,
100            rag_prompt,
101            vault_citations,
102        }
103    }
104}
105
106#[async_trait::async_trait]
107impl ReplyGenerator for VaultAwareLlmReplyAdapter {
108    async fn generate_reply(
109        &self,
110        tweet_text: &str,
111        author: &str,
112        mention_product: bool,
113    ) -> Result<String, LoopError> {
114        let output = self
115            .generator
116            .generate_reply_with_context(
117                tweet_text,
118                author,
119                mention_product,
120                None,
121                self.rag_prompt.as_deref(),
122            )
123            .await
124            .map_err(llm_to_loop_error)?;
125        record_llm_usage(
126            &self.pool,
127            "reply",
128            &output.provider,
129            &output.model,
130            output.usage.input_tokens,
131            output.usage.output_tokens,
132        )
133        .await;
134        Ok(output.text)
135    }
136
137    async fn generate_reply_with_rag(
138        &self,
139        tweet_text: &str,
140        author: &str,
141        mention_product: bool,
142    ) -> Result<ReplyOutput, LoopError> {
143        let text = self
144            .generate_reply(tweet_text, author, mention_product)
145            .await?;
146        Ok(ReplyOutput {
147            text,
148            vault_citations: self.vault_citations.clone(),
149        })
150    }
151}
152
153/// Adapts `ContentGenerator` to the `TweetGenerator` port trait.
154pub struct LlmTweetAdapter {
155    generator: Arc<ContentGenerator>,
156    pool: DbPool,
157}
158
159impl LlmTweetAdapter {
160    pub fn new(generator: Arc<ContentGenerator>, pool: DbPool) -> Self {
161        Self { generator, pool }
162    }
163}
164
165#[async_trait::async_trait]
166impl TweetGenerator for LlmTweetAdapter {
167    async fn generate_tweet(&self, topic: &str) -> Result<String, ContentLoopError> {
168        let output = self
169            .generator
170            .generate_tweet(topic)
171            .await
172            .map_err(llm_to_content_error)?;
173        record_llm_usage(
174            &self.pool,
175            "tweet",
176            &output.provider,
177            &output.model,
178            output.usage.input_tokens,
179            output.usage.output_tokens,
180        )
181        .await;
182        Ok(output.text)
183    }
184}
185
186/// Adapts `ContentGenerator` to the `ThreadGenerator` port trait.
187pub struct LlmThreadAdapter {
188    generator: Arc<ContentGenerator>,
189    pool: DbPool,
190}
191
192impl LlmThreadAdapter {
193    pub fn new(generator: Arc<ContentGenerator>, pool: DbPool) -> Self {
194        Self { generator, pool }
195    }
196}
197
198#[async_trait::async_trait]
199impl ThreadGenerator for LlmThreadAdapter {
200    async fn generate_thread(
201        &self,
202        topic: &str,
203        _count: Option<usize>,
204    ) -> Result<Vec<String>, ContentLoopError> {
205        let output = self
206            .generator
207            .generate_thread(topic)
208            .await
209            .map_err(llm_to_content_error)?;
210        record_llm_usage(
211            &self.pool,
212            "thread",
213            &output.provider,
214            &output.model,
215            output.usage.input_tokens,
216            output.usage.output_tokens,
217        )
218        .await;
219        Ok(output.tweets)
220    }
221}