Skip to main content

tuitbot_core/storage/
threads.rs

//! Insert and query operations for original tweets and educational threads.
//!
//! Provides functions to insert and query original tweets, threads, and the
//! individual tweets of a thread, supporting the content and thread automation loops.
5
6use super::accounts::DEFAULT_ACCOUNT_ID;
7use super::DbPool;
8use crate::error::StorageError;
9
10/// An educational tweet generated and posted by the agent.
11#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
12pub struct OriginalTweet {
13    /// Internal auto-generated ID.
14    pub id: i64,
15    /// X tweet ID after posting (None if failed).
16    pub tweet_id: Option<String>,
17    /// Tweet text.
18    pub content: String,
19    /// Industry topic this covers.
20    pub topic: Option<String>,
21    /// Which LLM generated this.
22    pub llm_provider: Option<String>,
23    /// ISO-8601 UTC timestamp when tweet was posted.
24    pub created_at: String,
25    /// Status: sent or failed.
26    pub status: String,
27    /// Error details if failed.
28    pub error_message: Option<String>,
29}
30
31/// A series of connected tweets posted as a thread.
32#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
33pub struct Thread {
34    /// Internal auto-generated ID.
35    pub id: i64,
36    /// Thread topic.
37    pub topic: String,
38    /// Number of tweets in thread.
39    pub tweet_count: i64,
40    /// X tweet ID of first tweet.
41    pub root_tweet_id: Option<String>,
42    /// ISO-8601 UTC timestamp when thread was posted.
43    pub created_at: String,
44    /// Status: sent, partial, or failed.
45    pub status: String,
46}
47
48/// An individual tweet within a thread.
49#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
50pub struct ThreadTweet {
51    /// Internal auto-generated ID.
52    pub id: i64,
53    /// Parent thread ID.
54    pub thread_id: i64,
55    /// 0-indexed position in thread.
56    pub position: i64,
57    /// X tweet ID after posting.
58    pub tweet_id: Option<String>,
59    /// Tweet text.
60    pub content: String,
61    /// ISO-8601 UTC timestamp.
62    pub created_at: String,
63}
64
65/// Insert a new original tweet for a specific account. Returns the auto-generated ID.
66pub async fn insert_original_tweet_for(
67    pool: &DbPool,
68    account_id: &str,
69    tweet: &OriginalTweet,
70) -> Result<i64, StorageError> {
71    let result = sqlx::query(
72        "INSERT INTO original_tweets \
73         (account_id, tweet_id, content, topic, llm_provider, created_at, status, error_message) \
74         VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
75    )
76    .bind(account_id)
77    .bind(&tweet.tweet_id)
78    .bind(&tweet.content)
79    .bind(&tweet.topic)
80    .bind(&tweet.llm_provider)
81    .bind(&tweet.created_at)
82    .bind(&tweet.status)
83    .bind(&tweet.error_message)
84    .execute(pool)
85    .await
86    .map_err(|e| StorageError::Query { source: e })?;
87
88    Ok(result.last_insert_rowid())
89}
90
91/// Insert a new original tweet. Returns the auto-generated ID.
92pub async fn insert_original_tweet(
93    pool: &DbPool,
94    tweet: &OriginalTweet,
95) -> Result<i64, StorageError> {
96    insert_original_tweet_for(pool, DEFAULT_ACCOUNT_ID, tweet).await
97}
98
99/// Get the timestamp of the most recent successfully posted original tweet for a specific account.
100pub async fn get_last_original_tweet_time_for(
101    pool: &DbPool,
102    account_id: &str,
103) -> Result<Option<String>, StorageError> {
104    let row: Option<(String,)> = sqlx::query_as(
105        "SELECT created_at FROM original_tweets WHERE account_id = ? AND status = 'sent' \
106         ORDER BY created_at DESC LIMIT 1",
107    )
108    .bind(account_id)
109    .fetch_optional(pool)
110    .await
111    .map_err(|e| StorageError::Query { source: e })?;
112
113    Ok(row.map(|r| r.0))
114}
115
116/// Get the timestamp of the most recent successfully posted original tweet.
117pub async fn get_last_original_tweet_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
118    get_last_original_tweet_time_for(pool, DEFAULT_ACCOUNT_ID).await
119}
120
121/// Count original tweets posted today (UTC) for a specific account.
122pub async fn count_tweets_today_for(pool: &DbPool, account_id: &str) -> Result<i64, StorageError> {
123    let row: (i64,) = sqlx::query_as(
124        "SELECT COUNT(*) FROM original_tweets WHERE account_id = ? AND date(created_at) = date('now')",
125    )
126    .bind(account_id)
127    .fetch_one(pool)
128    .await
129    .map_err(|e| StorageError::Query { source: e })?;
130
131    Ok(row.0)
132}
133
134/// Count original tweets posted today (UTC).
135pub async fn count_tweets_today(pool: &DbPool) -> Result<i64, StorageError> {
136    count_tweets_today_for(pool, DEFAULT_ACCOUNT_ID).await
137}
138
139/// Insert a new thread record for a specific account. Returns the auto-generated ID.
140pub async fn insert_thread_for(
141    pool: &DbPool,
142    account_id: &str,
143    thread: &Thread,
144) -> Result<i64, StorageError> {
145    let result = sqlx::query(
146        "INSERT INTO threads (account_id, topic, tweet_count, root_tweet_id, created_at, status) \
147         VALUES (?, ?, ?, ?, ?, ?)",
148    )
149    .bind(account_id)
150    .bind(&thread.topic)
151    .bind(thread.tweet_count)
152    .bind(&thread.root_tweet_id)
153    .bind(&thread.created_at)
154    .bind(&thread.status)
155    .execute(pool)
156    .await
157    .map_err(|e| StorageError::Query { source: e })?;
158
159    Ok(result.last_insert_rowid())
160}
161
162/// Insert a new thread record. Returns the auto-generated ID.
163pub async fn insert_thread(pool: &DbPool, thread: &Thread) -> Result<i64, StorageError> {
164    insert_thread_for(pool, DEFAULT_ACCOUNT_ID, thread).await
165}
166
167/// Insert all tweets for a thread atomically using a transaction for a specific account.
168///
169/// Either all tweets are inserted or none are (rollback on failure).
170pub async fn insert_thread_tweets_for(
171    pool: &DbPool,
172    account_id: &str,
173    thread_id: i64,
174    tweets: &[ThreadTweet],
175) -> Result<(), StorageError> {
176    let mut tx = pool
177        .begin()
178        .await
179        .map_err(|e| StorageError::Connection { source: e })?;
180
181    for tweet in tweets {
182        sqlx::query(
183            "INSERT INTO thread_tweets \
184             (account_id, thread_id, position, tweet_id, content, created_at) \
185             VALUES (?, ?, ?, ?, ?, ?)",
186        )
187        .bind(account_id)
188        .bind(thread_id)
189        .bind(tweet.position)
190        .bind(&tweet.tweet_id)
191        .bind(&tweet.content)
192        .bind(&tweet.created_at)
193        .execute(&mut *tx)
194        .await
195        .map_err(|e| StorageError::Query { source: e })?;
196    }
197
198    tx.commit()
199        .await
200        .map_err(|e| StorageError::Connection { source: e })?;
201
202    Ok(())
203}
204
205/// Insert all tweets for a thread atomically using a transaction.
206///
207/// Either all tweets are inserted or none are (rollback on failure).
208pub async fn insert_thread_tweets(
209    pool: &DbPool,
210    thread_id: i64,
211    tweets: &[ThreadTweet],
212) -> Result<(), StorageError> {
213    insert_thread_tweets_for(pool, DEFAULT_ACCOUNT_ID, thread_id, tweets).await
214}
215
216/// Get the timestamp of the most recent successfully posted thread for a specific account.
217pub async fn get_last_thread_time_for(
218    pool: &DbPool,
219    account_id: &str,
220) -> Result<Option<String>, StorageError> {
221    let row: Option<(String,)> = sqlx::query_as(
222        "SELECT created_at FROM threads WHERE account_id = ? AND status = 'sent' \
223         ORDER BY created_at DESC LIMIT 1",
224    )
225    .bind(account_id)
226    .fetch_optional(pool)
227    .await
228    .map_err(|e| StorageError::Query { source: e })?;
229
230    Ok(row.map(|r| r.0))
231}
232
233/// Get the timestamp of the most recent successfully posted thread.
234pub async fn get_last_thread_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
235    get_last_thread_time_for(pool, DEFAULT_ACCOUNT_ID).await
236}
237
238/// Get the timestamps of all successfully posted original tweets today (UTC) for a specific account.
239pub async fn get_todays_tweet_times_for(
240    pool: &DbPool,
241    account_id: &str,
242) -> Result<Vec<String>, StorageError> {
243    let rows: Vec<(String,)> = sqlx::query_as(
244        "SELECT created_at FROM original_tweets \
245         WHERE account_id = ? AND status = 'sent' AND date(created_at) = date('now') \
246         ORDER BY created_at ASC",
247    )
248    .bind(account_id)
249    .fetch_all(pool)
250    .await
251    .map_err(|e| StorageError::Query { source: e })?;
252
253    Ok(rows.into_iter().map(|r| r.0).collect())
254}
255
256/// Get the timestamps of all successfully posted original tweets today (UTC).
257pub async fn get_todays_tweet_times(pool: &DbPool) -> Result<Vec<String>, StorageError> {
258    get_todays_tweet_times_for(pool, DEFAULT_ACCOUNT_ID).await
259}
260
261/// Count threads posted in the current ISO week (Monday-Sunday, UTC) for a specific account.
262pub async fn count_threads_this_week_for(
263    pool: &DbPool,
264    account_id: &str,
265) -> Result<i64, StorageError> {
266    let row: (i64,) = sqlx::query_as(
267        "SELECT COUNT(*) FROM threads \
268         WHERE account_id = ? AND strftime('%Y-%W', created_at) = strftime('%Y-%W', 'now')",
269    )
270    .bind(account_id)
271    .fetch_one(pool)
272    .await
273    .map_err(|e| StorageError::Query { source: e })?;
274
275    Ok(row.0)
276}
277
278/// Count threads posted in the current ISO week (Monday-Sunday, UTC).
279pub async fn count_threads_this_week(pool: &DbPool) -> Result<i64, StorageError> {
280    count_threads_this_week_for(pool, DEFAULT_ACCOUNT_ID).await
281}
282
283/// Get original tweets within a date range for a specific account, ordered by creation time.
284pub async fn get_tweets_in_range_for(
285    pool: &DbPool,
286    account_id: &str,
287    from: &str,
288    to: &str,
289) -> Result<Vec<OriginalTweet>, StorageError> {
290    sqlx::query_as::<_, OriginalTweet>(
291        "SELECT * FROM original_tweets \
292         WHERE account_id = ? AND created_at BETWEEN ? AND ? \
293         ORDER BY created_at ASC",
294    )
295    .bind(account_id)
296    .bind(from)
297    .bind(to)
298    .fetch_all(pool)
299    .await
300    .map_err(|e| StorageError::Query { source: e })
301}
302
303/// Get original tweets within a date range, ordered by creation time.
304pub async fn get_tweets_in_range(
305    pool: &DbPool,
306    from: &str,
307    to: &str,
308) -> Result<Vec<OriginalTweet>, StorageError> {
309    get_tweets_in_range_for(pool, DEFAULT_ACCOUNT_ID, from, to).await
310}
311
312/// Get threads within a date range for a specific account, ordered by creation time.
313pub async fn get_threads_in_range_for(
314    pool: &DbPool,
315    account_id: &str,
316    from: &str,
317    to: &str,
318) -> Result<Vec<Thread>, StorageError> {
319    sqlx::query_as::<_, Thread>(
320        "SELECT * FROM threads \
321         WHERE account_id = ? AND created_at BETWEEN ? AND ? \
322         ORDER BY created_at ASC",
323    )
324    .bind(account_id)
325    .bind(from)
326    .bind(to)
327    .fetch_all(pool)
328    .await
329    .map_err(|e| StorageError::Query { source: e })
330}
331
332/// Get threads within a date range, ordered by creation time.
333pub async fn get_threads_in_range(
334    pool: &DbPool,
335    from: &str,
336    to: &str,
337) -> Result<Vec<Thread>, StorageError> {
338    get_threads_in_range_for(pool, DEFAULT_ACCOUNT_ID, from, to).await
339}
340
341/// Get the most recent original tweets for a specific account, newest first.
342pub async fn get_recent_original_tweets_for(
343    pool: &DbPool,
344    account_id: &str,
345    limit: u32,
346) -> Result<Vec<OriginalTweet>, StorageError> {
347    sqlx::query_as::<_, OriginalTweet>(
348        "SELECT * FROM original_tweets WHERE account_id = ? ORDER BY created_at DESC LIMIT ?",
349    )
350    .bind(account_id)
351    .bind(limit)
352    .fetch_all(pool)
353    .await
354    .map_err(|e| StorageError::Query { source: e })
355}
356
357/// Get the most recent original tweets, newest first.
358pub async fn get_recent_original_tweets(
359    pool: &DbPool,
360    limit: u32,
361) -> Result<Vec<OriginalTweet>, StorageError> {
362    get_recent_original_tweets_for(pool, DEFAULT_ACCOUNT_ID, limit).await
363}
364
365/// Get the most recent threads for a specific account, newest first.
366pub async fn get_recent_threads_for(
367    pool: &DbPool,
368    account_id: &str,
369    limit: u32,
370) -> Result<Vec<Thread>, StorageError> {
371    sqlx::query_as::<_, Thread>(
372        "SELECT * FROM threads WHERE account_id = ? ORDER BY created_at DESC LIMIT ?",
373    )
374    .bind(account_id)
375    .bind(limit)
376    .fetch_all(pool)
377    .await
378    .map_err(|e| StorageError::Query { source: e })
379}
380
381/// Get the most recent threads, newest first.
382pub async fn get_recent_threads(pool: &DbPool, limit: u32) -> Result<Vec<Thread>, StorageError> {
383    get_recent_threads_for(pool, DEFAULT_ACCOUNT_ID, limit).await
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    /// Current UTC time in the ISO-8601 format the storage layer stores.
    fn now_iso() -> String {
        chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string()
    }

    /// UTC time `days` days in the past, in the same format as [`now_iso`].
    fn days_ago_iso(days: i64) -> String {
        (chrono::Utc::now() - chrono::Duration::days(days))
            .format("%Y-%m-%dT%H:%M:%SZ")
            .to_string()
    }

    fn sample_original_tweet() -> OriginalTweet {
        OriginalTweet {
            id: 0,
            tweet_id: Some("ot_123".to_string()),
            content: "Educational tweet about Rust".to_string(),
            topic: Some("rust".to_string()),
            llm_provider: Some("openai".to_string()),
            created_at: now_iso(),
            status: "sent".to_string(),
            error_message: None,
        }
    }

    fn sample_thread() -> Thread {
        Thread {
            id: 0,
            topic: "Rust async patterns".to_string(),
            tweet_count: 3,
            root_tweet_id: Some("root_456".to_string()),
            created_at: now_iso(),
            status: "sent".to_string(),
        }
    }

    fn sample_thread_tweets(thread_id: i64) -> Vec<ThreadTweet> {
        (0..3)
            .map(|i| ThreadTweet {
                id: 0,
                thread_id,
                position: i,
                tweet_id: Some(format!("tt_{i}")),
                content: format!("Thread tweet {i}"),
                created_at: now_iso(),
            })
            .collect()
    }

    /// Number of `thread_tweets` rows attached to `thread_id`.
    async fn count_thread_tweets(pool: &DbPool, thread_id: i64) -> i64 {
        let (count,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM thread_tweets WHERE thread_id = ?")
                .bind(thread_id)
                .fetch_one(pool)
                .await
                .expect("query");
        count
    }

    #[tokio::test]
    async fn insert_and_query_original_tweet() {
        let pool = init_test_db().await.expect("init db");
        let tweet = sample_original_tweet();

        let id = insert_original_tweet(&pool, &tweet).await.expect("insert");
        assert!(id > 0);

        let time = get_last_original_tweet_time(&pool).await.expect("get time");
        assert!(time.is_some());
    }

    #[tokio::test]
    async fn last_original_tweet_time_ignores_failed() {
        let pool = init_test_db().await.expect("init db");

        let mut sent = sample_original_tweet();
        sent.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_original_tweet(&pool, &sent).await.expect("insert sent");

        // A later, failed attempt must not be reported as the last post.
        let mut failed = sample_original_tweet();
        failed.created_at = "2026-02-21T10:00:00Z".to_string();
        failed.tweet_id = None;
        failed.status = "failed".to_string();
        failed.error_message = Some("rate limited".to_string());
        insert_original_tweet(&pool, &failed).await.expect("insert failed");

        let time = get_last_original_tweet_time(&pool).await.expect("get time");
        assert_eq!(time.as_deref(), Some("2026-02-20T10:00:00Z"));
    }

    #[tokio::test]
    async fn count_tweets_today_works() {
        let pool = init_test_db().await.expect("init db");
        let tweet = sample_original_tweet();

        insert_original_tweet(&pool, &tweet).await.expect("insert");
        let count = count_tweets_today(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn todays_tweet_times_only_include_sent() {
        let pool = init_test_db().await.expect("init db");

        insert_original_tweet(&pool, &sample_original_tweet())
            .await
            .expect("insert sent");

        let mut failed = sample_original_tweet();
        failed.tweet_id = None;
        failed.status = "failed".to_string();
        insert_original_tweet(&pool, &failed).await.expect("insert failed");

        let times = get_todays_tweet_times(&pool).await.expect("times");
        assert_eq!(times.len(), 1);
    }

    #[tokio::test]
    async fn insert_thread_with_tweets() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        let thread_id = insert_thread(&pool, &thread).await.expect("insert thread");
        let tweets = sample_thread_tweets(thread_id);
        insert_thread_tweets(&pool, thread_id, &tweets)
            .await
            .expect("insert tweets");

        // Verify all tweets were inserted, in position order.
        let rows: Vec<(i64,)> = sqlx::query_as(
            "SELECT position FROM thread_tweets WHERE thread_id = ? ORDER BY position",
        )
        .bind(thread_id)
        .fetch_all(&pool)
        .await
        .expect("query");

        let positions: Vec<i64> = rows.into_iter().map(|(p,)| p).collect();
        assert_eq!(positions, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn insert_empty_thread_tweets_is_ok() {
        let pool = init_test_db().await.expect("init db");
        let thread_id = insert_thread(&pool, &sample_thread())
            .await
            .expect("insert thread");

        insert_thread_tweets(&pool, thread_id, &[])
            .await
            .expect("insert no tweets");

        assert_eq!(count_thread_tweets(&pool, thread_id).await, 0);
    }

    #[tokio::test]
    async fn thread_tweet_duplicate_position_fails() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        let thread_id = insert_thread(&pool, &thread).await.expect("insert thread");

        // Two tweets with same position should fail the UNIQUE constraint
        let duplicate_tweets = vec![
            ThreadTweet {
                id: 0,
                thread_id,
                position: 0,
                tweet_id: Some("a".to_string()),
                content: "First".to_string(),
                created_at: now_iso(),
            },
            ThreadTweet {
                id: 0,
                thread_id,
                position: 0, // duplicate position
                tweet_id: Some("b".to_string()),
                content: "Second".to_string(),
                created_at: now_iso(),
            },
        ];

        let result = insert_thread_tweets(&pool, thread_id, &duplicate_tweets).await;
        assert!(result.is_err());

        // Verify transaction rolled back (no partial data)
        assert_eq!(
            count_thread_tweets(&pool, thread_id).await,
            0,
            "transaction should have rolled back"
        );
    }

    #[tokio::test]
    async fn count_threads_this_week_works() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        insert_thread(&pool, &thread).await.expect("insert");
        let count = count_threads_this_week(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn count_threads_this_week_ignores_old_threads() {
        let pool = init_test_db().await.expect("init db");

        // Two weeks back is always in a different Monday–Sunday week.
        let mut old = sample_thread();
        old.created_at = days_ago_iso(14);
        insert_thread(&pool, &old).await.expect("insert");

        let count = count_threads_this_week(&pool).await.expect("count");
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn last_thread_time_empty() {
        let pool = init_test_db().await.expect("init db");
        let time = get_last_thread_time(&pool).await.expect("get time");
        assert!(time.is_none());
    }

    #[tokio::test]
    async fn last_thread_time_returns_latest_sent() {
        let pool = init_test_db().await.expect("init db");

        let mut earlier = sample_thread();
        earlier.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_thread(&pool, &earlier).await.expect("insert");

        let mut later = sample_thread();
        later.created_at = "2026-02-25T10:00:00Z".to_string();
        insert_thread(&pool, &later).await.expect("insert");

        let time = get_last_thread_time(&pool).await.expect("get time");
        assert_eq!(time.as_deref(), Some("2026-02-25T10:00:00Z"));
    }

    #[tokio::test]
    async fn get_tweets_in_range_filters() {
        let pool = init_test_db().await.expect("init db");

        let mut tweet = sample_original_tweet();
        tweet.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_original_tweet(&pool, &tweet).await.expect("insert");

        let mut tweet2 = sample_original_tweet();
        tweet2.created_at = "2026-02-25T10:00:00Z".to_string();
        tweet2.tweet_id = Some("ot_456".to_string());
        insert_original_tweet(&pool, &tweet2).await.expect("insert");

        let in_range = get_tweets_in_range(&pool, "2026-02-19T00:00:00Z", "2026-02-21T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(in_range.len(), 1);
        assert_eq!(in_range[0].tweet_id, Some("ot_123".to_string()));

        let all = get_tweets_in_range(&pool, "2026-02-01T00:00:00Z", "2026-02-28T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn get_threads_in_range_filters() {
        let pool = init_test_db().await.expect("init db");

        let mut thread = sample_thread();
        thread.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_thread(&pool, &thread).await.expect("insert");

        let mut thread2 = sample_thread();
        thread2.created_at = "2026-02-25T10:00:00Z".to_string();
        insert_thread(&pool, &thread2).await.expect("insert");

        let in_range = get_threads_in_range(&pool, "2026-02-19T00:00:00Z", "2026-02-21T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(in_range.len(), 1);

        let all = get_threads_in_range(&pool, "2026-02-01T00:00:00Z", "2026-02-28T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn recent_original_tweets_are_newest_first_and_limited() {
        let pool = init_test_db().await.expect("init db");

        let stamps = [
            "2026-02-20T10:00:00Z",
            "2026-02-21T10:00:00Z",
            "2026-02-22T10:00:00Z",
        ];
        for (i, ts) in stamps.iter().enumerate() {
            let mut tweet = sample_original_tweet();
            tweet.tweet_id = Some(format!("ot_{i}"));
            tweet.created_at = ts.to_string();
            insert_original_tweet(&pool, &tweet).await.expect("insert");
        }

        let recent = get_recent_original_tweets(&pool, 2).await.expect("recent");
        let ids: Vec<Option<&str>> = recent.iter().map(|t| t.tweet_id.as_deref()).collect();
        assert_eq!(ids, vec![Some("ot_2"), Some("ot_1")]);
    }
}