Skip to main content

tuitbot_core/automation/adapters/
llm.rs

1//! LLM adapter implementations.
2
3use std::sync::Arc;
4
5use super::super::loop_helpers::{ContentLoopError, LoopError, ReplyGenerator, TweetGenerator};
6use super::super::thread_loop::ThreadGenerator;
7use super::helpers::{llm_to_content_error, llm_to_loop_error};
8use crate::content::ContentGenerator;
9use crate::storage::DbPool;
10
11/// Record LLM usage to the database (fire-and-forget).
12pub(super) async fn record_llm_usage(
13    pool: &DbPool,
14    generation_type: &str,
15    provider: &str,
16    model: &str,
17    input_tokens: u32,
18    output_tokens: u32,
19) {
20    let pricing = crate::llm::pricing::lookup(provider, model);
21    let cost = pricing.compute_cost(input_tokens, output_tokens);
22    if let Err(e) = crate::storage::llm_usage::insert_llm_usage(
23        pool,
24        generation_type,
25        provider,
26        model,
27        input_tokens,
28        output_tokens,
29        cost,
30    )
31    .await
32    {
33        tracing::warn!(error = %e, "Failed to record LLM usage");
34    }
35}
36
37/// Adapts `ContentGenerator` to the `ReplyGenerator` port trait.
38pub struct LlmReplyAdapter {
39    generator: Arc<ContentGenerator>,
40    pool: DbPool,
41}
42
43impl LlmReplyAdapter {
44    pub fn new(generator: Arc<ContentGenerator>, pool: DbPool) -> Self {
45        Self { generator, pool }
46    }
47}
48
49#[async_trait::async_trait]
50impl ReplyGenerator for LlmReplyAdapter {
51    async fn generate_reply(
52        &self,
53        tweet_text: &str,
54        author: &str,
55        mention_product: bool,
56    ) -> Result<String, LoopError> {
57        let output = self
58            .generator
59            .generate_reply(tweet_text, author, mention_product)
60            .await
61            .map_err(llm_to_loop_error)?;
62        record_llm_usage(
63            &self.pool,
64            "reply",
65            &output.provider,
66            &output.model,
67            output.usage.input_tokens,
68            output.usage.output_tokens,
69        )
70        .await;
71        Ok(output.text)
72    }
73}
74
75/// Adapts `ContentGenerator` to the `TweetGenerator` port trait.
76pub struct LlmTweetAdapter {
77    generator: Arc<ContentGenerator>,
78    pool: DbPool,
79}
80
81impl LlmTweetAdapter {
82    pub fn new(generator: Arc<ContentGenerator>, pool: DbPool) -> Self {
83        Self { generator, pool }
84    }
85}
86
87#[async_trait::async_trait]
88impl TweetGenerator for LlmTweetAdapter {
89    async fn generate_tweet(&self, topic: &str) -> Result<String, ContentLoopError> {
90        let output = self
91            .generator
92            .generate_tweet(topic)
93            .await
94            .map_err(llm_to_content_error)?;
95        record_llm_usage(
96            &self.pool,
97            "tweet",
98            &output.provider,
99            &output.model,
100            output.usage.input_tokens,
101            output.usage.output_tokens,
102        )
103        .await;
104        Ok(output.text)
105    }
106}
107
108/// Adapts `ContentGenerator` to the `ThreadGenerator` port trait.
109pub struct LlmThreadAdapter {
110    generator: Arc<ContentGenerator>,
111    pool: DbPool,
112}
113
114impl LlmThreadAdapter {
115    pub fn new(generator: Arc<ContentGenerator>, pool: DbPool) -> Self {
116        Self { generator, pool }
117    }
118}
119
120#[async_trait::async_trait]
121impl ThreadGenerator for LlmThreadAdapter {
122    async fn generate_thread(
123        &self,
124        topic: &str,
125        _count: Option<usize>,
126    ) -> Result<Vec<String>, ContentLoopError> {
127        let output = self
128            .generator
129            .generate_thread(topic)
130            .await
131            .map_err(llm_to_content_error)?;
132        record_llm_usage(
133            &self.pool,
134            "thread",
135            &output.provider,
136            &output.model,
137            output.usage.input_tokens,
138            output.usage.output_tokens,
139        )
140        .await;
141        Ok(output.tweets)
142    }
143}