1use super::accounts::DEFAULT_ACCOUNT_ID;
7use super::DbPool;
8use crate::error::StorageError;
9
10#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
12pub struct OriginalTweet {
13 pub id: i64,
15 pub tweet_id: Option<String>,
17 pub content: String,
19 pub topic: Option<String>,
21 pub llm_provider: Option<String>,
23 pub created_at: String,
25 pub status: String,
27 pub error_message: Option<String>,
29}
30
31#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
33pub struct Thread {
34 pub id: i64,
36 pub topic: String,
38 pub tweet_count: i64,
40 pub root_tweet_id: Option<String>,
42 pub created_at: String,
44 pub status: String,
46}
47
48#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
50pub struct ThreadTweet {
51 pub id: i64,
53 pub thread_id: i64,
55 pub position: i64,
57 pub tweet_id: Option<String>,
59 pub content: String,
61 pub created_at: String,
63}
64
65pub async fn insert_original_tweet_for(
67 pool: &DbPool,
68 account_id: &str,
69 tweet: &OriginalTweet,
70) -> Result<i64, StorageError> {
71 let result = sqlx::query(
72 "INSERT INTO original_tweets \
73 (account_id, tweet_id, content, topic, llm_provider, created_at, status, error_message) \
74 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
75 )
76 .bind(account_id)
77 .bind(&tweet.tweet_id)
78 .bind(&tweet.content)
79 .bind(&tweet.topic)
80 .bind(&tweet.llm_provider)
81 .bind(&tweet.created_at)
82 .bind(&tweet.status)
83 .bind(&tweet.error_message)
84 .execute(pool)
85 .await
86 .map_err(|e| StorageError::Query { source: e })?;
87
88 Ok(result.last_insert_rowid())
89}
90
91pub async fn insert_original_tweet(
93 pool: &DbPool,
94 tweet: &OriginalTweet,
95) -> Result<i64, StorageError> {
96 insert_original_tweet_for(pool, DEFAULT_ACCOUNT_ID, tweet).await
97}
98
99pub async fn get_last_original_tweet_time_for(
101 pool: &DbPool,
102 account_id: &str,
103) -> Result<Option<String>, StorageError> {
104 let row: Option<(String,)> = sqlx::query_as(
105 "SELECT created_at FROM original_tweets WHERE account_id = ? AND status = 'sent' \
106 ORDER BY created_at DESC LIMIT 1",
107 )
108 .bind(account_id)
109 .fetch_optional(pool)
110 .await
111 .map_err(|e| StorageError::Query { source: e })?;
112
113 Ok(row.map(|r| r.0))
114}
115
116pub async fn get_last_original_tweet_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
118 get_last_original_tweet_time_for(pool, DEFAULT_ACCOUNT_ID).await
119}
120
121pub async fn count_tweets_today_for(pool: &DbPool, account_id: &str) -> Result<i64, StorageError> {
123 let row: (i64,) = sqlx::query_as(
124 "SELECT COUNT(*) FROM original_tweets WHERE account_id = ? AND date(created_at) = date('now')",
125 )
126 .bind(account_id)
127 .fetch_one(pool)
128 .await
129 .map_err(|e| StorageError::Query { source: e })?;
130
131 Ok(row.0)
132}
133
134pub async fn count_tweets_today(pool: &DbPool) -> Result<i64, StorageError> {
136 count_tweets_today_for(pool, DEFAULT_ACCOUNT_ID).await
137}
138
139pub async fn insert_thread_for(
141 pool: &DbPool,
142 account_id: &str,
143 thread: &Thread,
144) -> Result<i64, StorageError> {
145 let result = sqlx::query(
146 "INSERT INTO threads (account_id, topic, tweet_count, root_tweet_id, created_at, status) \
147 VALUES (?, ?, ?, ?, ?, ?)",
148 )
149 .bind(account_id)
150 .bind(&thread.topic)
151 .bind(thread.tweet_count)
152 .bind(&thread.root_tweet_id)
153 .bind(&thread.created_at)
154 .bind(&thread.status)
155 .execute(pool)
156 .await
157 .map_err(|e| StorageError::Query { source: e })?;
158
159 Ok(result.last_insert_rowid())
160}
161
162pub async fn insert_thread(pool: &DbPool, thread: &Thread) -> Result<i64, StorageError> {
164 insert_thread_for(pool, DEFAULT_ACCOUNT_ID, thread).await
165}
166
167pub async fn insert_thread_tweets_for(
171 pool: &DbPool,
172 account_id: &str,
173 thread_id: i64,
174 tweets: &[ThreadTweet],
175) -> Result<(), StorageError> {
176 let mut tx = pool
177 .begin()
178 .await
179 .map_err(|e| StorageError::Connection { source: e })?;
180
181 for tweet in tweets {
182 sqlx::query(
183 "INSERT INTO thread_tweets \
184 (account_id, thread_id, position, tweet_id, content, created_at) \
185 VALUES (?, ?, ?, ?, ?, ?)",
186 )
187 .bind(account_id)
188 .bind(thread_id)
189 .bind(tweet.position)
190 .bind(&tweet.tweet_id)
191 .bind(&tweet.content)
192 .bind(&tweet.created_at)
193 .execute(&mut *tx)
194 .await
195 .map_err(|e| StorageError::Query { source: e })?;
196 }
197
198 tx.commit()
199 .await
200 .map_err(|e| StorageError::Connection { source: e })?;
201
202 Ok(())
203}
204
205pub async fn insert_thread_tweets(
209 pool: &DbPool,
210 thread_id: i64,
211 tweets: &[ThreadTweet],
212) -> Result<(), StorageError> {
213 insert_thread_tweets_for(pool, DEFAULT_ACCOUNT_ID, thread_id, tweets).await
214}
215
216pub async fn get_last_thread_time_for(
218 pool: &DbPool,
219 account_id: &str,
220) -> Result<Option<String>, StorageError> {
221 let row: Option<(String,)> = sqlx::query_as(
222 "SELECT created_at FROM threads WHERE account_id = ? AND status = 'sent' \
223 ORDER BY created_at DESC LIMIT 1",
224 )
225 .bind(account_id)
226 .fetch_optional(pool)
227 .await
228 .map_err(|e| StorageError::Query { source: e })?;
229
230 Ok(row.map(|r| r.0))
231}
232
233pub async fn get_last_thread_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
235 get_last_thread_time_for(pool, DEFAULT_ACCOUNT_ID).await
236}
237
238pub async fn get_todays_tweet_times_for(
240 pool: &DbPool,
241 account_id: &str,
242) -> Result<Vec<String>, StorageError> {
243 let rows: Vec<(String,)> = sqlx::query_as(
244 "SELECT created_at FROM original_tweets \
245 WHERE account_id = ? AND status = 'sent' AND date(created_at) = date('now') \
246 ORDER BY created_at ASC",
247 )
248 .bind(account_id)
249 .fetch_all(pool)
250 .await
251 .map_err(|e| StorageError::Query { source: e })?;
252
253 Ok(rows.into_iter().map(|r| r.0).collect())
254}
255
256pub async fn get_todays_tweet_times(pool: &DbPool) -> Result<Vec<String>, StorageError> {
258 get_todays_tweet_times_for(pool, DEFAULT_ACCOUNT_ID).await
259}
260
261pub async fn count_threads_this_week_for(
263 pool: &DbPool,
264 account_id: &str,
265) -> Result<i64, StorageError> {
266 let row: (i64,) = sqlx::query_as(
267 "SELECT COUNT(*) FROM threads \
268 WHERE account_id = ? AND strftime('%Y-%W', created_at) = strftime('%Y-%W', 'now')",
269 )
270 .bind(account_id)
271 .fetch_one(pool)
272 .await
273 .map_err(|e| StorageError::Query { source: e })?;
274
275 Ok(row.0)
276}
277
278pub async fn count_threads_this_week(pool: &DbPool) -> Result<i64, StorageError> {
280 count_threads_this_week_for(pool, DEFAULT_ACCOUNT_ID).await
281}
282
283pub async fn get_tweets_in_range_for(
285 pool: &DbPool,
286 account_id: &str,
287 from: &str,
288 to: &str,
289) -> Result<Vec<OriginalTweet>, StorageError> {
290 sqlx::query_as::<_, OriginalTweet>(
291 "SELECT * FROM original_tweets \
292 WHERE account_id = ? AND created_at BETWEEN ? AND ? \
293 ORDER BY created_at ASC",
294 )
295 .bind(account_id)
296 .bind(from)
297 .bind(to)
298 .fetch_all(pool)
299 .await
300 .map_err(|e| StorageError::Query { source: e })
301}
302
303pub async fn get_tweets_in_range(
305 pool: &DbPool,
306 from: &str,
307 to: &str,
308) -> Result<Vec<OriginalTweet>, StorageError> {
309 get_tweets_in_range_for(pool, DEFAULT_ACCOUNT_ID, from, to).await
310}
311
312pub async fn get_threads_in_range_for(
314 pool: &DbPool,
315 account_id: &str,
316 from: &str,
317 to: &str,
318) -> Result<Vec<Thread>, StorageError> {
319 sqlx::query_as::<_, Thread>(
320 "SELECT * FROM threads \
321 WHERE account_id = ? AND created_at BETWEEN ? AND ? \
322 ORDER BY created_at ASC",
323 )
324 .bind(account_id)
325 .bind(from)
326 .bind(to)
327 .fetch_all(pool)
328 .await
329 .map_err(|e| StorageError::Query { source: e })
330}
331
332pub async fn get_threads_in_range(
334 pool: &DbPool,
335 from: &str,
336 to: &str,
337) -> Result<Vec<Thread>, StorageError> {
338 get_threads_in_range_for(pool, DEFAULT_ACCOUNT_ID, from, to).await
339}
340
341pub async fn get_recent_original_tweets_for(
343 pool: &DbPool,
344 account_id: &str,
345 limit: u32,
346) -> Result<Vec<OriginalTweet>, StorageError> {
347 sqlx::query_as::<_, OriginalTweet>(
348 "SELECT * FROM original_tweets WHERE account_id = ? ORDER BY created_at DESC LIMIT ?",
349 )
350 .bind(account_id)
351 .bind(limit)
352 .fetch_all(pool)
353 .await
354 .map_err(|e| StorageError::Query { source: e })
355}
356
357pub async fn get_recent_original_tweets(
359 pool: &DbPool,
360 limit: u32,
361) -> Result<Vec<OriginalTweet>, StorageError> {
362 get_recent_original_tweets_for(pool, DEFAULT_ACCOUNT_ID, limit).await
363}
364
365pub async fn get_recent_threads_for(
367 pool: &DbPool,
368 account_id: &str,
369 limit: u32,
370) -> Result<Vec<Thread>, StorageError> {
371 sqlx::query_as::<_, Thread>(
372 "SELECT * FROM threads WHERE account_id = ? ORDER BY created_at DESC LIMIT ?",
373 )
374 .bind(account_id)
375 .bind(limit)
376 .fetch_all(pool)
377 .await
378 .map_err(|e| StorageError::Query { source: e })
379}
380
381pub async fn get_recent_threads(pool: &DbPool, limit: u32) -> Result<Vec<Thread>, StorageError> {
383 get_recent_threads_for(pool, DEFAULT_ACCOUNT_ID, limit).await
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    /// Current UTC time in the `%Y-%m-%dT%H:%M:%SZ` layout used for `created_at`.
    fn now_iso() -> String {
        chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string()
    }

    fn sample_original_tweet() -> OriginalTweet {
        OriginalTweet {
            id: 0,
            tweet_id: Some("ot_123".to_string()),
            content: "Educational tweet about Rust".to_string(),
            topic: Some("rust".to_string()),
            llm_provider: Some("openai".to_string()),
            created_at: now_iso(),
            status: "sent".to_string(),
            error_message: None,
        }
    }

    fn sample_thread() -> Thread {
        Thread {
            id: 0,
            topic: "Rust async patterns".to_string(),
            tweet_count: 3,
            root_tweet_id: Some("root_456".to_string()),
            created_at: now_iso(),
            status: "sent".to_string(),
        }
    }

    fn sample_thread_tweets(thread_id: i64) -> Vec<ThreadTweet> {
        (0..3)
            .map(|i| ThreadTweet {
                id: 0,
                thread_id,
                position: i,
                tweet_id: Some(format!("tt_{i}")),
                content: format!("Thread tweet {i}"),
                created_at: now_iso(),
            })
            .collect()
    }

    #[tokio::test]
    async fn insert_and_query_original_tweet() {
        let pool = init_test_db().await.expect("init db");
        let tweet = sample_original_tweet();

        let id = insert_original_tweet(&pool, &tweet).await.expect("insert");
        assert!(id > 0);

        let time = get_last_original_tweet_time(&pool).await.expect("get time");
        assert_eq!(time, Some(tweet.created_at));
    }

    #[tokio::test]
    async fn count_tweets_today_works() {
        let pool = init_test_db().await.expect("init db");
        let tweet = sample_original_tweet();

        insert_original_tweet(&pool, &tweet).await.expect("insert");
        let count = count_tweets_today(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn insert_thread_with_tweets() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        let thread_id = insert_thread(&pool, &thread).await.expect("insert thread");
        let tweets = sample_thread_tweets(thread_id);
        insert_thread_tweets(&pool, thread_id, &tweets)
            .await
            .expect("insert tweets");

        let rows: Vec<(i64,)> = sqlx::query_as(
            "SELECT position FROM thread_tweets WHERE thread_id = ? ORDER BY position",
        )
        .bind(thread_id)
        .fetch_all(&pool)
        .await
        .expect("query");

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0].0, 0);
        assert_eq!(rows[1].0, 1);
        assert_eq!(rows[2].0, 2);
    }

    #[tokio::test]
    async fn thread_tweet_duplicate_position_fails() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        let thread_id = insert_thread(&pool, &thread).await.expect("insert thread");

        let duplicate_tweets = vec![
            ThreadTweet {
                id: 0,
                thread_id,
                position: 0,
                tweet_id: Some("a".to_string()),
                content: "First".to_string(),
                created_at: now_iso(),
            },
            ThreadTweet {
                id: 0,
                thread_id,
                position: 0,
                tweet_id: Some("b".to_string()),
                content: "Second".to_string(),
                created_at: now_iso(),
            },
        ];

        let result = insert_thread_tweets(&pool, thread_id, &duplicate_tweets).await;
        assert!(result.is_err());

        // COUNT(*) always yields exactly one row, so fetch_one is the natural fit.
        let (remaining,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM thread_tweets WHERE thread_id = ?")
                .bind(thread_id)
                .fetch_one(&pool)
                .await
                .expect("query");

        assert_eq!(remaining, 0, "transaction should have rolled back");
    }

    #[tokio::test]
    async fn count_threads_this_week_works() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        insert_thread(&pool, &thread).await.expect("insert");
        let count = count_threads_this_week(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn count_threads_this_week_ignores_old_threads() {
        let pool = init_test_db().await.expect("init db");

        let mut old = sample_thread();
        old.created_at = "2000-01-03T10:00:00Z".to_string();
        insert_thread(&pool, &old).await.expect("insert old");
        assert_eq!(count_threads_this_week(&pool).await.expect("count"), 0);

        insert_thread(&pool, &sample_thread())
            .await
            .expect("insert current");
        assert_eq!(count_threads_this_week(&pool).await.expect("count"), 1);
    }

    #[tokio::test]
    async fn last_thread_time_empty() {
        let pool = init_test_db().await.expect("init db");
        let time = get_last_thread_time(&pool).await.expect("get time");
        assert!(time.is_none());
    }

    #[tokio::test]
    async fn get_tweets_in_range_filters() {
        let pool = init_test_db().await.expect("init db");

        let mut tweet = sample_original_tweet();
        tweet.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_original_tweet(&pool, &tweet).await.expect("insert");

        let mut tweet2 = sample_original_tweet();
        tweet2.created_at = "2026-02-25T10:00:00Z".to_string();
        tweet2.tweet_id = Some("ot_456".to_string());
        insert_original_tweet(&pool, &tweet2).await.expect("insert");

        let in_range = get_tweets_in_range(&pool, "2026-02-19T00:00:00Z", "2026-02-21T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(in_range.len(), 1);
        assert_eq!(in_range[0].tweet_id, Some("ot_123".to_string()));

        let all = get_tweets_in_range(&pool, "2026-02-01T00:00:00Z", "2026-02-28T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn get_threads_in_range_filters() {
        let pool = init_test_db().await.expect("init db");

        let mut thread = sample_thread();
        thread.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_thread(&pool, &thread).await.expect("insert");

        let mut thread2 = sample_thread();
        thread2.created_at = "2026-02-25T10:00:00Z".to_string();
        insert_thread(&pool, &thread2).await.expect("insert");

        let in_range = get_threads_in_range(&pool, "2026-02-19T00:00:00Z", "2026-02-21T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(in_range.len(), 1);

        let all = get_threads_in_range(&pool, "2026-02-01T00:00:00Z", "2026-02-28T00:00:00Z")
            .await
            .expect("range");
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn recent_original_tweets_are_newest_first_and_limited() {
        let pool = init_test_db().await.expect("init db");

        let mut older = sample_original_tweet();
        older.created_at = "2026-02-20T10:00:00Z".to_string();
        insert_original_tweet(&pool, &older).await.expect("insert");

        let mut newer = sample_original_tweet();
        newer.created_at = "2026-02-25T10:00:00Z".to_string();
        newer.tweet_id = Some("ot_456".to_string());
        insert_original_tweet(&pool, &newer).await.expect("insert");

        let recent = get_recent_original_tweets(&pool, 1).await.expect("recent");
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].tweet_id, Some("ot_456".to_string()));
    }

    #[tokio::test]
    async fn todays_tweet_times_only_include_sent() {
        let pool = init_test_db().await.expect("init db");

        let sent = sample_original_tweet();
        insert_original_tweet(&pool, &sent).await.expect("insert sent");

        let mut failed = sample_original_tweet();
        failed.tweet_id = None;
        failed.status = "failed".to_string();
        failed.error_message = Some("rate limited".to_string());
        insert_original_tweet(&pool, &failed)
            .await
            .expect("insert failed");

        let times = get_todays_tweet_times(&pool).await.expect("times");
        assert_eq!(times, vec![sent.created_at]);
    }
}