1use super::DbPool;
7use crate::error::StorageError;
8
9#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
11pub struct OriginalTweet {
12 pub id: i64,
14 pub tweet_id: Option<String>,
16 pub content: String,
18 pub topic: Option<String>,
20 pub llm_provider: Option<String>,
22 pub created_at: String,
24 pub status: String,
26 pub error_message: Option<String>,
28}
29
30#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
32pub struct Thread {
33 pub id: i64,
35 pub topic: String,
37 pub tweet_count: i64,
39 pub root_tweet_id: Option<String>,
41 pub created_at: String,
43 pub status: String,
45}
46
47#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize)]
49pub struct ThreadTweet {
50 pub id: i64,
52 pub thread_id: i64,
54 pub position: i64,
56 pub tweet_id: Option<String>,
58 pub content: String,
60 pub created_at: String,
62}
63
64pub async fn insert_original_tweet(
66 pool: &DbPool,
67 tweet: &OriginalTweet,
68) -> Result<i64, StorageError> {
69 let result = sqlx::query(
70 "INSERT INTO original_tweets \
71 (tweet_id, content, topic, llm_provider, created_at, status, error_message) \
72 VALUES (?, ?, ?, ?, ?, ?, ?)",
73 )
74 .bind(&tweet.tweet_id)
75 .bind(&tweet.content)
76 .bind(&tweet.topic)
77 .bind(&tweet.llm_provider)
78 .bind(&tweet.created_at)
79 .bind(&tweet.status)
80 .bind(&tweet.error_message)
81 .execute(pool)
82 .await
83 .map_err(|e| StorageError::Query { source: e })?;
84
85 Ok(result.last_insert_rowid())
86}
87
88pub async fn get_last_original_tweet_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
90 let row: Option<(String,)> = sqlx::query_as(
91 "SELECT created_at FROM original_tweets WHERE status = 'sent' \
92 ORDER BY created_at DESC LIMIT 1",
93 )
94 .fetch_optional(pool)
95 .await
96 .map_err(|e| StorageError::Query { source: e })?;
97
98 Ok(row.map(|r| r.0))
99}
100
101pub async fn count_tweets_today(pool: &DbPool) -> Result<i64, StorageError> {
103 let row: (i64,) =
104 sqlx::query_as("SELECT COUNT(*) FROM original_tweets WHERE date(created_at) = date('now')")
105 .fetch_one(pool)
106 .await
107 .map_err(|e| StorageError::Query { source: e })?;
108
109 Ok(row.0)
110}
111
112pub async fn insert_thread(pool: &DbPool, thread: &Thread) -> Result<i64, StorageError> {
114 let result = sqlx::query(
115 "INSERT INTO threads (topic, tweet_count, root_tweet_id, created_at, status) \
116 VALUES (?, ?, ?, ?, ?)",
117 )
118 .bind(&thread.topic)
119 .bind(thread.tweet_count)
120 .bind(&thread.root_tweet_id)
121 .bind(&thread.created_at)
122 .bind(&thread.status)
123 .execute(pool)
124 .await
125 .map_err(|e| StorageError::Query { source: e })?;
126
127 Ok(result.last_insert_rowid())
128}
129
130pub async fn insert_thread_tweets(
134 pool: &DbPool,
135 thread_id: i64,
136 tweets: &[ThreadTweet],
137) -> Result<(), StorageError> {
138 let mut tx = pool
139 .begin()
140 .await
141 .map_err(|e| StorageError::Connection { source: e })?;
142
143 for tweet in tweets {
144 sqlx::query(
145 "INSERT INTO thread_tweets (thread_id, position, tweet_id, content, created_at) \
146 VALUES (?, ?, ?, ?, ?)",
147 )
148 .bind(thread_id)
149 .bind(tweet.position)
150 .bind(&tweet.tweet_id)
151 .bind(&tweet.content)
152 .bind(&tweet.created_at)
153 .execute(&mut *tx)
154 .await
155 .map_err(|e| StorageError::Query { source: e })?;
156 }
157
158 tx.commit()
159 .await
160 .map_err(|e| StorageError::Connection { source: e })?;
161
162 Ok(())
163}
164
165pub async fn get_last_thread_time(pool: &DbPool) -> Result<Option<String>, StorageError> {
167 let row: Option<(String,)> = sqlx::query_as(
168 "SELECT created_at FROM threads WHERE status = 'sent' \
169 ORDER BY created_at DESC LIMIT 1",
170 )
171 .fetch_optional(pool)
172 .await
173 .map_err(|e| StorageError::Query { source: e })?;
174
175 Ok(row.map(|r| r.0))
176}
177
178pub async fn get_todays_tweet_times(pool: &DbPool) -> Result<Vec<String>, StorageError> {
180 let rows: Vec<(String,)> = sqlx::query_as(
181 "SELECT created_at FROM original_tweets \
182 WHERE status = 'sent' AND date(created_at) = date('now') \
183 ORDER BY created_at ASC",
184 )
185 .fetch_all(pool)
186 .await
187 .map_err(|e| StorageError::Query { source: e })?;
188
189 Ok(rows.into_iter().map(|r| r.0).collect())
190}
191
192pub async fn count_threads_this_week(pool: &DbPool) -> Result<i64, StorageError> {
194 let row: (i64,) = sqlx::query_as(
195 "SELECT COUNT(*) FROM threads \
196 WHERE strftime('%Y-%W', created_at) = strftime('%Y-%W', 'now')",
197 )
198 .fetch_one(pool)
199 .await
200 .map_err(|e| StorageError::Query { source: e })?;
201
202 Ok(row.0)
203}
204
205pub async fn get_recent_original_tweets(
207 pool: &DbPool,
208 limit: u32,
209) -> Result<Vec<OriginalTweet>, StorageError> {
210 sqlx::query_as::<_, OriginalTweet>(
211 "SELECT * FROM original_tweets ORDER BY created_at DESC LIMIT ?",
212 )
213 .bind(limit)
214 .fetch_all(pool)
215 .await
216 .map_err(|e| StorageError::Query { source: e })
217}
218
219pub async fn get_recent_threads(pool: &DbPool, limit: u32) -> Result<Vec<Thread>, StorageError> {
221 sqlx::query_as::<_, Thread>("SELECT * FROM threads ORDER BY created_at DESC LIMIT ?")
222 .bind(limit)
223 .fetch_all(pool)
224 .await
225 .map_err(|e| StorageError::Query { source: e })
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    /// Current UTC time in the `YYYY-MM-DDTHH:MM:SSZ` form stored in `created_at`.
    fn now_iso() -> String {
        chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string()
    }

    fn sample_original_tweet() -> OriginalTweet {
        OriginalTweet {
            id: 0,
            tweet_id: Some("ot_123".to_string()),
            content: "Educational tweet about Rust".to_string(),
            topic: Some("rust".to_string()),
            llm_provider: Some("openai".to_string()),
            created_at: now_iso(),
            status: "sent".to_string(),
            error_message: None,
        }
    }

    fn sample_thread() -> Thread {
        Thread {
            id: 0,
            topic: "Rust async patterns".to_string(),
            tweet_count: 3,
            root_tweet_id: Some("root_456".to_string()),
            created_at: now_iso(),
            status: "sent".to_string(),
        }
    }

    fn sample_thread_tweets(thread_id: i64) -> Vec<ThreadTweet> {
        (0..3)
            .map(|i| ThreadTweet {
                id: 0,
                thread_id,
                position: i,
                tweet_id: Some(format!("tt_{i}")),
                content: format!("Thread tweet {i}"),
                created_at: now_iso(),
            })
            .collect()
    }

    #[tokio::test]
    async fn insert_and_query_original_tweet() {
        let pool = init_test_db().await.expect("init db");
        let tweet = sample_original_tweet();

        let id = insert_original_tweet(&pool, &tweet).await.expect("insert");
        assert!(id > 0);

        let time = get_last_original_tweet_time(&pool).await.expect("get time");
        assert!(time.is_some());
    }

    #[tokio::test]
    async fn count_tweets_today_works() {
        let pool = init_test_db().await.expect("init db");
        let tweet = sample_original_tweet();

        insert_original_tweet(&pool, &tweet).await.expect("insert");
        let count = count_tweets_today(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn insert_thread_with_tweets() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        let thread_id = insert_thread(&pool, &thread).await.expect("insert thread");
        let tweets = sample_thread_tweets(thread_id);
        insert_thread_tweets(&pool, thread_id, &tweets)
            .await
            .expect("insert tweets");

        let rows: Vec<(i64,)> = sqlx::query_as(
            "SELECT position FROM thread_tweets WHERE thread_id = ? ORDER BY position",
        )
        .bind(thread_id)
        .fetch_all(&pool)
        .await
        .expect("query");

        let positions: Vec<i64> = rows.into_iter().map(|(position,)| position).collect();
        assert_eq!(positions, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn thread_tweet_duplicate_position_fails() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        let thread_id = insert_thread(&pool, &thread).await.expect("insert thread");

        let duplicate_tweets = vec![
            ThreadTweet {
                id: 0,
                thread_id,
                position: 0,
                tweet_id: Some("a".to_string()),
                content: "First".to_string(),
                created_at: now_iso(),
            },
            ThreadTweet {
                id: 0,
                thread_id,
                position: 0,
                tweet_id: Some("b".to_string()),
                content: "Second".to_string(),
                created_at: now_iso(),
            },
        ];

        let result = insert_thread_tweets(&pool, thread_id, &duplicate_tweets).await;
        assert!(result.is_err());

        // COUNT(*) always yields exactly one row, so fetch_one is the right accessor.
        let (count,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM thread_tweets WHERE thread_id = ?")
                .bind(thread_id)
                .fetch_one(&pool)
                .await
                .expect("query");

        assert_eq!(count, 0, "transaction should have rolled back");
    }

    #[tokio::test]
    async fn count_threads_this_week_works() {
        let pool = init_test_db().await.expect("init db");
        let thread = sample_thread();

        insert_thread(&pool, &thread).await.expect("insert");
        let count = count_threads_this_week(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn count_threads_this_week_ignores_old_threads() {
        let pool = init_test_db().await.expect("init db");
        let old = Thread {
            root_tweet_id: Some("root_old".to_string()),
            created_at: "2000-01-05T12:00:00Z".to_string(),
            ..sample_thread()
        };

        insert_thread(&pool, &old).await.expect("insert old");
        insert_thread(&pool, &sample_thread())
            .await
            .expect("insert current");

        let count = count_threads_this_week(&pool).await.expect("count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn todays_tweet_times_excludes_other_days() {
        let pool = init_test_db().await.expect("init db");
        let today = sample_original_tweet();
        let long_ago = OriginalTweet {
            tweet_id: Some("ot_old".to_string()),
            created_at: "2000-01-05T12:00:00Z".to_string(),
            ..sample_original_tweet()
        };

        insert_original_tweet(&pool, &today).await.expect("insert today");
        insert_original_tweet(&pool, &long_ago)
            .await
            .expect("insert old");

        let times = get_todays_tweet_times(&pool).await.expect("times");
        assert_eq!(times, vec![today.created_at.clone()]);
    }

    #[tokio::test]
    async fn recent_original_tweets_newest_first_with_limit() {
        let pool = init_test_db().await.expect("init db");
        for i in 0..3 {
            let tweet = OriginalTweet {
                tweet_id: Some(format!("ot_{i}")),
                created_at: format!("2024-01-01T00:00:0{i}Z"),
                ..sample_original_tweet()
            };
            insert_original_tweet(&pool, &tweet).await.expect("insert");
        }

        let recent = get_recent_original_tweets(&pool, 2).await.expect("recent");
        let ids: Vec<Option<&str>> = recent.iter().map(|t| t.tweet_id.as_deref()).collect();
        assert_eq!(ids, vec![Some("ot_2"), Some("ot_1")]);
    }

    #[tokio::test]
    async fn last_thread_time_empty() {
        let pool = init_test_db().await.expect("init db");
        let time = get_last_thread_time(&pool).await.expect("get time");
        assert!(time.is_none());
    }
}