Skip to main content

tuitbot_core/storage/
threads.rs

1//! CRUD operations for original tweets and educational threads.
2//!
3//! Provides functions to insert and query original tweets and threads,
4//! supporting the content and thread automation loops.
5
6use super::DbPool;
7use crate::error::StorageError;
8
9/// An educational tweet generated and posted by the agent.
10#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
11pub struct OriginalTweet {
12    /// Internal auto-generated ID.
13    pub id: i64,
14    /// X tweet ID after posting (None if failed).
15    pub tweet_id: Option<String>,
16    /// Tweet text.
17    pub content: String,
18    /// Industry topic this covers.
19    pub topic: Option<String>,
20    /// Which LLM generated this.
21    pub llm_provider: Option<String>,
22    /// ISO-8601 UTC timestamp when tweet was posted.
23    pub created_at: String,
24    /// Status: sent or failed.
25    pub status: String,
26    /// Error details if failed.
27    pub error_message: Option<String>,
28}
29
30/// A series of connected tweets posted as a thread.
31#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
32pub struct Thread {
33    /// Internal auto-generated ID.
34    pub id: i64,
35    /// Thread topic.
36    pub topic: String,
37    /// Number of tweets in thread.
38    pub tweet_count: i64,
39    /// X tweet ID of first tweet.
40    pub root_tweet_id: Option<String>,
41    /// ISO-8601 UTC timestamp when thread was posted.
42    pub created_at: String,
43    /// Status: sent, partial, or failed.
44    pub status: String,
45}
46
47/// An individual tweet within a thread.
48#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
49pub struct ThreadTweet {
50    /// Internal auto-generated ID.
51    pub id: i64,
52    /// Parent thread ID.
53    pub thread_id: i64,
54    /// 0-indexed position in thread.
55    pub position: i64,
56    /// X tweet ID after posting.
57    pub tweet_id: Option<String>,
58    /// Tweet text.
59    pub content: String,
60    /// ISO-8601 UTC timestamp.
61    pub created_at: String,
62}
63
64/// Insert a new original tweet. Returns the auto-generated ID.
65pub async fn insert_original_tweet(
66    pool: &DbPool,
67    tweet: &OriginalTweet,
68) -> Result<i64, StorageError> {
69    let result = sqlx::query(
70        "INSERT INTO original_tweets \
71         (tweet_id, content, topic, llm_provider, created_at, status, error_message) \
72         VALUES (?, ?, ?, ?, ?, ?, ?)",
73    )
74    .bind(&tweet.tweet_id)
75    .bind(&tweet.content)
76    .bind(&tweet.topic)
77    .bind(&tweet.llm_provider)
78    .bind(&tweet.created_at)
79    .bind(&tweet.status)
80    .bind(&tweet.error_message)
81    .execute(pool)
82    .await
83    .map_err(|e| StorageError::Query { source: e })?;
84
85    Ok(result.last_insert_rowid())
86}
87
88/// Get the timestamp of the most recent successfully posted original tweet.
89pub async fn get_last_original_tweet_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
90    let row: Option<(String,)> = sqlx::query_as(
91        "SELECT created_at FROM original_tweets WHERE status = 'sent' \
92         ORDER BY created_at DESC LIMIT 1",
93    )
94    .fetch_optional(pool)
95    .await
96    .map_err(|e| StorageError::Query { source: e })?;
97
98    Ok(row.map(|r| r.0))
99}
100
101/// Count original tweets posted today (UTC).
102pub async fn count_tweets_today(pool: &DbPool) -> Result<i64, StorageError> {
103    let row: (i64,) =
104        sqlx::query_as("SELECT COUNT(*) FROM original_tweets WHERE date(created_at) = date('now')")
105            .fetch_one(pool)
106            .await
107            .map_err(|e| StorageError::Query { source: e })?;
108
109    Ok(row.0)
110}
111
112/// Insert a new thread record. Returns the auto-generated ID.
113pub async fn insert_thread(pool: &DbPool, thread: &Thread) -> Result<i64, StorageError> {
114    let result = sqlx::query(
115        "INSERT INTO threads (topic, tweet_count, root_tweet_id, created_at, status) \
116         VALUES (?, ?, ?, ?, ?)",
117    )
118    .bind(&thread.topic)
119    .bind(thread.tweet_count)
120    .bind(&thread.root_tweet_id)
121    .bind(&thread.created_at)
122    .bind(&thread.status)
123    .execute(pool)
124    .await
125    .map_err(|e| StorageError::Query { source: e })?;
126
127    Ok(result.last_insert_rowid())
128}
129
130/// Insert all tweets for a thread atomically using a transaction.
131///
132/// Either all tweets are inserted or none are (rollback on failure).
133pub async fn insert_thread_tweets(
134    pool: &DbPool,
135    thread_id: i64,
136    tweets: &[ThreadTweet],
137) -> Result<(), StorageError> {
138    let mut tx = pool
139        .begin()
140        .await
141        .map_err(|e| StorageError::Connection { source: e })?;
142
143    for tweet in tweets {
144        sqlx::query(
145            "INSERT INTO thread_tweets (thread_id, position, tweet_id, content, created_at) \
146             VALUES (?, ?, ?, ?, ?)",
147        )
148        .bind(thread_id)
149        .bind(tweet.position)
150        .bind(&tweet.tweet_id)
151        .bind(&tweet.content)
152        .bind(&tweet.created_at)
153        .execute(&mut *tx)
154        .await
155        .map_err(|e| StorageError::Query { source: e })?;
156    }
157
158    tx.commit()
159        .await
160        .map_err(|e| StorageError::Connection { source: e })?;
161
162    Ok(())
163}
164
165/// Get the timestamp of the most recent successfully posted thread.
166pub async fn get_last_thread_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
167    let row: Option<(String,)> = sqlx::query_as(
168        "SELECT created_at FROM threads WHERE status = 'sent' \
169         ORDER BY created_at DESC LIMIT 1",
170    )
171    .fetch_optional(pool)
172    .await
173    .map_err(|e| StorageError::Query { source: e })?;
174
175    Ok(row.map(|r| r.0))
176}
177
178/// Get the timestamps of all successfully posted original tweets today (UTC).
179pub async fn get_todays_tweet_times(pool: &DbPool) -> Result<Vec<String>, StorageError> {
180    let rows: Vec<(String,)> = sqlx::query_as(
181        "SELECT created_at FROM original_tweets \
182         WHERE status = 'sent' AND date(created_at) = date('now') \
183         ORDER BY created_at ASC",
184    )
185    .fetch_all(pool)
186    .await
187    .map_err(|e| StorageError::Query { source: e })?;
188
189    Ok(rows.into_iter().map(|r| r.0).collect())
190}
191
192/// Count threads posted in the current ISO week (Monday-Sunday, UTC).
193pub async fn count_threads_this_week(pool: &DbPool) -> Result<i64, StorageError> {
194    let row: (i64,) = sqlx::query_as(
195        "SELECT COUNT(*) FROM threads \
196         WHERE strftime('%Y-%W', created_at) = strftime('%Y-%W', 'now')",
197    )
198    .fetch_one(pool)
199    .await
200    .map_err(|e| StorageError::Query { source: e })?;
201
202    Ok(row.0)
203}
204
205/// Get the most recent original tweets, newest first.
206pub async fn get_recent_original_tweets(
207    pool: &DbPool,
208    limit: u32,
209) -> Result<Vec<OriginalTweet>, StorageError> {
210    sqlx::query_as::<_, OriginalTweet>(
211        "SELECT * FROM original_tweets ORDER BY created_at DESC LIMIT ?",
212    )
213    .bind(limit)
214    .fetch_all(pool)
215    .await
216    .map_err(|e| StorageError::Query { source: e })
217}
218
219/// Get the most recent threads, newest first.
220pub async fn get_recent_threads(pool: &DbPool, limit: u32) -> Result<Vec<Thread>, StorageError> {
221    sqlx::query_as::<_, Thread>("SELECT * FROM threads ORDER BY created_at DESC LIMIT ?")
222        .bind(limit)
223        .fetch_all(pool)
224        .await
225        .map_err(|e| StorageError::Query { source: e })
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    /// Current UTC time in the `YYYY-MM-DDTHH:MM:SSZ` shape stored by these tests.
    fn timestamp_now() -> String {
        chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string()
    }

    /// A sent original tweet with every optional field populated except the error.
    fn make_original_tweet() -> OriginalTweet {
        OriginalTweet {
            id: 0,
            tweet_id: Some(String::from("ot_123")),
            content: String::from("Educational tweet about Rust"),
            topic: Some(String::from("rust")),
            llm_provider: Some(String::from("openai")),
            created_at: timestamp_now(),
            status: String::from("sent"),
            error_message: None,
        }
    }

    /// A sent three-tweet thread header.
    fn make_thread() -> Thread {
        Thread {
            id: 0,
            topic: String::from("Rust async patterns"),
            tweet_count: 3,
            root_tweet_id: Some(String::from("root_456")),
            created_at: timestamp_now(),
            status: String::from("sent"),
        }
    }

    /// Three thread tweets at positions 0, 1 and 2 for `thread_id`.
    fn make_thread_tweets(thread_id: i64) -> Vec<ThreadTweet> {
        let mut tweets = Vec::with_capacity(3);
        for position in 0..3_i64 {
            tweets.push(ThreadTweet {
                id: 0,
                thread_id,
                position,
                tweet_id: Some(format!("tt_{position}")),
                content: format!("Thread tweet {position}"),
                created_at: timestamp_now(),
            });
        }
        tweets
    }

    #[tokio::test]
    async fn insert_and_query_original_tweet() {
        let pool = init_test_db().await.expect("init db");

        let id = insert_original_tweet(&pool, &make_original_tweet())
            .await
            .expect("insert");
        assert!(id > 0);

        let last = get_last_original_tweet_time(&pool).await.expect("get time");
        assert!(last.is_some());
    }

    #[tokio::test]
    async fn count_tweets_today_works() {
        let pool = init_test_db().await.expect("init db");

        insert_original_tweet(&pool, &make_original_tweet())
            .await
            .expect("insert");

        assert_eq!(count_tweets_today(&pool).await.expect("count"), 1);
    }

    #[tokio::test]
    async fn insert_thread_with_tweets() {
        let pool = init_test_db().await.expect("init db");

        let thread_id = insert_thread(&pool, &make_thread())
            .await
            .expect("insert thread");
        insert_thread_tweets(&pool, thread_id, &make_thread_tweets(thread_id))
            .await
            .expect("insert tweets");

        // Every tweet must be present, in position order.
        let positions: Vec<i64> = sqlx::query_as::<_, (i64,)>(
            "SELECT position FROM thread_tweets WHERE thread_id = ? ORDER BY position",
        )
        .bind(thread_id)
        .fetch_all(&pool)
        .await
        .expect("query")
        .into_iter()
        .map(|(position,)| position)
        .collect();

        assert_eq!(positions, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn thread_tweet_duplicate_position_fails() {
        let pool = init_test_db().await.expect("init db");

        let thread_id = insert_thread(&pool, &make_thread())
            .await
            .expect("insert thread");

        // Both tweets claim position 0, which must trip the UNIQUE constraint.
        let at_position_zero = |tweet_id: &str, content: &str| ThreadTweet {
            id: 0,
            thread_id,
            position: 0,
            tweet_id: Some(tweet_id.to_string()),
            content: content.to_string(),
            created_at: timestamp_now(),
        };
        let duplicates = [at_position_zero("a", "First"), at_position_zero("b", "Second")];

        let outcome = insert_thread_tweets(&pool, thread_id, &duplicates).await;
        assert!(outcome.is_err());

        // The failed batch must leave no partial rows behind.
        let (remaining,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM thread_tweets WHERE thread_id = ?")
                .bind(thread_id)
                .fetch_one(&pool)
                .await
                .expect("query");

        assert_eq!(remaining, 0, "transaction should have rolled back");
    }

    #[tokio::test]
    async fn count_threads_this_week_works() {
        let pool = init_test_db().await.expect("init db");

        insert_thread(&pool, &make_thread()).await.expect("insert");

        assert_eq!(count_threads_this_week(&pool).await.expect("count"), 1);
    }

    #[tokio::test]
    async fn last_thread_time_empty() {
        let pool = init_test_db().await.expect("init db");

        let last = get_last_thread_time(&pool).await.expect("get time");
        assert!(last.is_none());
    }
}