1use std::sync::Arc;
7
8use axum::extract::State;
9use axum::http::StatusCode;
10use axum::Json;
11use serde::{Deserialize, Serialize};
12
13use tuitbot_core::content::ContentGenerator;
14use tuitbot_core::context::winning_dna;
15use tuitbot_core::storage;
16
17use crate::account::AccountContext;
18use crate::error::ApiError;
19use crate::state::AppState;
20
21async fn get_generator(
26 state: &AppState,
27 account_id: &str,
28) -> Result<Arc<ContentGenerator>, ApiError> {
29 state
30 .get_or_create_content_generator(account_id)
31 .await
32 .map_err(ApiError::BadRequest)
33}
34
35async fn resolve_composer_rag_context(state: &AppState, account_id: &str) -> Option<String> {
42 let config = match state.load_effective_config(account_id).await {
43 Ok(c) => c,
44 Err(e) => {
45 tracing::warn!("composer RAG: failed to load config: {e}");
46 return None;
47 }
48 };
49
50 let keywords = config.business.draft_context_keywords();
51 if keywords.is_empty() {
52 return None;
53 }
54
55 let draft_context = match winning_dna::build_draft_context(
56 &state.db,
57 &keywords,
58 winning_dna::MAX_ANCESTORS,
59 winning_dna::RECENCY_HALF_LIFE_DAYS,
60 )
61 .await
62 {
63 Ok(ctx) => ctx,
64 Err(e) => {
65 tracing::warn!("composer RAG: failed to build draft context: {e}");
66 return None;
67 }
68 };
69
70 if draft_context.prompt_block.is_empty() {
71 None
72 } else {
73 Some(draft_context.prompt_block)
74 }
75}
76
77#[derive(Deserialize)]
82pub struct AssistTweetRequest {
83 pub topic: String,
84}
85
86#[derive(Serialize)]
87pub struct AssistTweetResponse {
88 pub content: String,
89 pub topic: String,
90}
91
92pub async fn assist_tweet(
93 State(state): State<Arc<AppState>>,
94 ctx: AccountContext,
95 Json(body): Json<AssistTweetRequest>,
96) -> Result<Json<AssistTweetResponse>, ApiError> {
97 let gen = get_generator(&state, &ctx.account_id).await?;
98 let rag_context = resolve_composer_rag_context(&state, &ctx.account_id).await;
99
100 let output = gen
101 .generate_tweet_with_context(&body.topic, None, rag_context.as_deref())
102 .await
103 .map_err(|e| ApiError::Internal(e.to_string()))?;
104
105 Ok(Json(AssistTweetResponse {
106 content: output.text,
107 topic: body.topic,
108 }))
109}
110
111#[derive(Deserialize)]
116pub struct AssistReplyRequest {
117 pub tweet_text: String,
118 pub tweet_author: String,
119 #[serde(default)]
120 pub mention_product: bool,
121}
122
123#[derive(Serialize)]
124pub struct AssistReplyResponse {
125 pub content: String,
126}
127
128pub async fn assist_reply(
129 State(state): State<Arc<AppState>>,
130 ctx: AccountContext,
131 Json(body): Json<AssistReplyRequest>,
132) -> Result<Json<AssistReplyResponse>, ApiError> {
133 let gen = get_generator(&state, &ctx.account_id).await?;
134
135 let output = gen
136 .generate_reply(&body.tweet_text, &body.tweet_author, body.mention_product)
137 .await
138 .map_err(|e| ApiError::Internal(e.to_string()))?;
139
140 Ok(Json(AssistReplyResponse {
141 content: output.text,
142 }))
143}
144
145#[derive(Deserialize)]
150pub struct AssistThreadRequest {
151 pub topic: String,
152}
153
154#[derive(Serialize)]
155pub struct AssistThreadResponse {
156 pub tweets: Vec<String>,
157 pub topic: String,
158}
159
160pub async fn assist_thread(
161 State(state): State<Arc<AppState>>,
162 ctx: AccountContext,
163 Json(body): Json<AssistThreadRequest>,
164) -> Result<Json<AssistThreadResponse>, ApiError> {
165 let gen = get_generator(&state, &ctx.account_id).await?;
166 let rag_context = resolve_composer_rag_context(&state, &ctx.account_id).await;
167
168 let output = gen
169 .generate_thread_with_context(&body.topic, None, rag_context.as_deref())
170 .await
171 .map_err(|e| ApiError::Internal(e.to_string()))?;
172
173 Ok(Json(AssistThreadResponse {
174 tweets: output.tweets,
175 topic: body.topic,
176 }))
177}
178
179#[derive(Deserialize)]
184pub struct AssistImproveRequest {
185 pub draft: String,
186 #[serde(default)]
187 pub context: Option<String>,
188}
189
190#[derive(Serialize)]
191pub struct AssistImproveResponse {
192 pub content: String,
193}
194
195pub async fn assist_improve(
196 State(state): State<Arc<AppState>>,
197 ctx: AccountContext,
198 Json(body): Json<AssistImproveRequest>,
199) -> Result<Json<AssistImproveResponse>, ApiError> {
200 let gen = get_generator(&state, &ctx.account_id).await?;
201 let rag_context = resolve_composer_rag_context(&state, &ctx.account_id).await;
202
203 let output = gen
204 .improve_draft_with_context(&body.draft, body.context.as_deref(), rag_context.as_deref())
205 .await
206 .map_err(|e| ApiError::Internal(e.to_string()))?;
207
208 Ok(Json(AssistImproveResponse {
209 content: output.text,
210 }))
211}
212
213#[derive(Serialize)]
218pub struct AssistTopicsResponse {
219 pub topics: Vec<TopicRecommendation>,
220}
221
222#[derive(Serialize)]
223pub struct TopicRecommendation {
224 pub topic: String,
225 pub score: f64,
226}
227
228pub async fn assist_topics(
229 State(state): State<Arc<AppState>>,
230 ctx: AccountContext,
231) -> Result<Json<AssistTopicsResponse>, ApiError> {
232 let top = storage::analytics::get_top_topics_for(&state.db, &ctx.account_id, 10).await?;
233
234 let topics = top
235 .into_iter()
236 .map(|cs| TopicRecommendation {
237 topic: cs.topic,
238 score: cs.avg_performance,
239 })
240 .collect();
241
242 Ok(Json(AssistTopicsResponse { topics }))
243}
244
245#[derive(Serialize)]
250pub struct OptimalTimesResponse {
251 pub times: Vec<OptimalTime>,
252}
253
254#[derive(Serialize)]
255pub struct OptimalTime {
256 pub hour: u32,
257 pub avg_engagement: f64,
258 pub post_count: i64,
259}
260
261pub async fn assist_optimal_times(
262 State(state): State<Arc<AppState>>,
263 ctx: AccountContext,
264) -> Result<Json<OptimalTimesResponse>, ApiError> {
265 let rows =
266 storage::analytics::get_optimal_posting_times_for(&state.db, &ctx.account_id).await?;
267
268 let times = rows
269 .into_iter()
270 .map(|r| OptimalTime {
271 hour: r.hour as u32,
272 avg_engagement: r.avg_engagement,
273 post_count: r.post_count,
274 })
275 .collect();
276
277 Ok(Json(OptimalTimesResponse { times }))
278}
279
280#[derive(Serialize)]
285pub struct ModeResponse {
286 pub mode: String,
287 pub approval_mode: bool,
288}
289
290pub async fn get_mode(
291 State(state): State<Arc<AppState>>,
292 ctx: AccountContext,
293) -> Result<(StatusCode, Json<ModeResponse>), ApiError> {
294 let config = crate::routes::content::read_effective_config(&state, &ctx.account_id).await?;
295
296 Ok((
297 StatusCode::OK,
298 Json(ModeResponse {
299 mode: config.mode.to_string(),
300 approval_mode: config.effective_approval_mode(),
301 }),
302 ))
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;
    use std::path::PathBuf;

    use tokio::sync::{broadcast, Mutex, RwLock};

    use crate::ws::AccountWsEvent;

    /// Builds a minimal `AppState` over a fresh test database.
    ///
    /// `config_path` does not need to exist; `data_dir` is its parent directory,
    /// falling back to the path itself when there is no parent.
    async fn test_state(config_path: PathBuf) -> AppState {
        let db = tuitbot_core::storage::init_test_db()
            .await
            .expect("init test db");
        let (event_tx, _) = broadcast::channel::<AccountWsEvent>(16);
        let data_dir = config_path
            .parent()
            .unwrap_or(&config_path)
            .to_path_buf();

        AppState {
            db,
            config_path,
            data_dir,
            event_tx,
            api_token: "test-token".to_string(),
            passphrase_hash: RwLock::new(None),
            passphrase_hash_mtime: RwLock::new(None),
            bind_host: "127.0.0.1".to_string(),
            bind_port: 3001,
            login_attempts: Mutex::new(HashMap::new()),
            runtimes: Mutex::new(HashMap::new()),
            content_generators: Mutex::new(HashMap::new()),
            circuit_breaker: None,
            watchtower_cancel: None,
            content_sources: Default::default(),
            connector_config: Default::default(),
            deployment_mode: Default::default(),

            pending_oauth: Mutex::new(HashMap::new()),
            token_managers: Mutex::new(HashMap::new()),
            x_client_id: String::new(),
        }
    }

    /// Writes `contents` to `config.toml` in a fresh temp dir and resolves the
    /// RAG context for `"test-account"` against a state pointing at that file.
    ///
    /// The temp dir stays alive until the resolution has finished.
    async fn resolve_with_config(contents: &str) -> Option<String> {
        let dir = tempfile::tempdir().expect("create temp dir");
        let config_path = dir.path().join("config.toml");
        std::fs::write(&config_path, contents).expect("write config");

        let state = test_state(config_path).await;
        resolve_composer_rag_context(&state, "test-account").await
    }

    #[tokio::test]
    async fn resolve_rag_returns_none_when_config_missing() {
        let state = test_state(PathBuf::from("/nonexistent/config.toml")).await;
        let resolved = resolve_composer_rag_context(&state, "test-account").await;
        assert!(
            resolved.is_none(),
            "should return None when config is missing"
        );
    }

    #[tokio::test]
    async fn resolve_rag_returns_none_when_db_empty() {
        let resolved = resolve_with_config(
            "[business]\nproduct_name = \"TestProduct\"\nproduct_keywords = [\"rust\", \"testing\"]\n",
        )
        .await;
        assert!(
            resolved.is_none(),
            "should return None when DB has no ancestor data"
        );
    }

    #[tokio::test]
    async fn resolve_rag_returns_none_when_no_keywords() {
        let resolved = resolve_with_config("[business]\nproduct_name = \"Empty\"\n").await;
        assert!(
            resolved.is_none(),
            "should return None when keywords are empty"
        );
    }
}