1use std::net::SocketAddr;
15use std::path::{Path, PathBuf};
16
17#[cfg(feature = "axum-server")]
18pub mod server;
19
20use anyhow::{Context, Result, anyhow};
21use serde::{Deserialize, Serialize};
22use tokio::net::TcpListener;
23
24pub async fn bind_with_auto_port(addr: SocketAddr, max_attempts: u16) -> Result<TcpListener> {
40 use std::io::ErrorKind;
41 let mut current = addr;
42 for attempt in 0..=max_attempts {
43 match TcpListener::bind(current).await {
44 Ok(l) => return Ok(l),
45 Err(e) if e.kind() == ErrorKind::AddrInUse && attempt < max_attempts => {
46 let next_port = current.port().saturating_add(1);
47 if next_port == 0 {
48 anyhow::bail!("ran out of ports while searching for free slot");
49 }
50 tracing::warn!("port {} in use, trying {}", current.port(), next_port);
51 current.set_port(next_port);
52 }
53 Err(e) => return Err(e.into()),
54 }
55 }
56 anyhow::bail!("could not find free port after {max_attempts} attempts")
57}
58
59pub fn resolve_data_dir(app_name: &str) -> Result<PathBuf> {
71 let base = dirs::data_dir()
72 .or_else(|| dirs::home_dir().map(|h| h.join(format!(".{app_name}"))))
73 .context("could not resolve data directory or home directory")?;
74 let dir = if base.ends_with(format!(".{app_name}")) {
75 base
76 } else {
77 base.join(app_name)
78 };
79 std::fs::create_dir_all(&dir)
80 .with_context(|| format!("create data directory {}", dir.display()))?;
81 Ok(dir)
82}
83
/// File name, inside the directory returned by `resolve_data_dir`, that holds the
/// address string passed to `write_daemon_addr` and returned by `read_daemon_addr`.
const DAEMON_ADDR_FILENAME: &str = "http_addr";
90
91pub fn write_daemon_addr(app_name: &str, addr: &str) -> Result<()> {
105 let dir = resolve_data_dir(app_name)?;
106 let path = dir.join(DAEMON_ADDR_FILENAME);
107 std::fs::write(&path, addr).with_context(|| format!("write daemon addr to {}", path.display()))
108}
109
110pub fn read_daemon_addr(app_name: &str) -> Result<Option<String>> {
121 let dir = resolve_data_dir(app_name)?;
122 let path = dir.join(DAEMON_ADDR_FILENAME);
123 match std::fs::read_to_string(&path) {
124 Ok(s) => Ok(Some(s.trim().to_string())),
125 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
126 Err(e) => Err(anyhow::Error::new(e))
127 .with_context(|| format!("read daemon addr from {}", path.display())),
128 }
129}
130
131pub fn init_tracing(verbose_count: u8) {
144 let default_filter = match verbose_count {
145 0 => "warn",
146 1 => "info",
147 2 => "debug",
148 _ => "trace",
149 };
150 let filter = tracing_subscriber::EnvFilter::try_from_default_env()
151 .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_filter));
152 let _ = tracing_subscriber::fmt()
154 .with_env_filter(filter)
155 .with_writer(std::io::stderr)
156 .with_target(false)
157 .try_init();
158}
159
160pub fn maybe_disable_color(no_color: bool) {
170 let env_says_no =
171 std::env::var("NO_COLOR").is_ok() || std::env::var("TERM").as_deref() == Ok("dumb");
172 if no_color || env_says_no {
173 colored::control::set_override(false);
174 }
175}
176
/// OpenRouter chat-completions endpoint, POSTed to by `openrouter_chat` and
/// `openrouter_chat_stream`.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Value of the `HTTP-Referer` header sent with every OpenRouter request.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value of the `X-Title` header sent with every OpenRouter request.
const X_TITLE: &str = "trusty-common";
/// TCP connect timeout, in seconds, for the OpenRouter HTTP clients.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
183const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120; #[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct ChatMessage {
195 pub role: String,
196 pub content: String,
197}
198
199#[derive(Debug, Serialize)]
200struct ChatRequest<'a> {
201 model: &'a str,
202 messages: &'a [ChatMessage],
203 stream: bool,
204}
205
206#[derive(Debug, Deserialize)]
207struct ChatResponse {
208 choices: Vec<Choice>,
209}
210
211#[derive(Debug, Deserialize)]
212struct Choice {
213 message: ResponseMessage,
214}
215
216#[derive(Debug, Deserialize)]
217struct ResponseMessage {
218 #[serde(default)]
219 content: String,
220}
221
222pub async fn openrouter_chat(
233 api_key: &str,
234 model: &str,
235 messages: Vec<ChatMessage>,
236) -> Result<String> {
237 if api_key.is_empty() {
238 return Err(anyhow!("openrouter api key is empty"));
239 }
240 let client = reqwest::Client::builder()
241 .connect_timeout(std::time::Duration::from_secs(
242 OPENROUTER_CONNECT_TIMEOUT_SECS,
243 ))
244 .timeout(std::time::Duration::from_secs(
245 OPENROUTER_REQUEST_TIMEOUT_SECS,
246 ))
247 .build()
248 .context("build reqwest client for openrouter_chat")?;
249 let body = ChatRequest {
250 model,
251 messages: &messages,
252 stream: false,
253 };
254 let resp = client
255 .post(OPENROUTER_URL)
256 .bearer_auth(api_key)
257 .header("HTTP-Referer", HTTP_REFERER)
258 .header("X-Title", X_TITLE)
259 .json(&body)
260 .send()
261 .await
262 .context("POST openrouter chat completions")?;
263 let status = resp.status();
264 if !status.is_success() {
265 let text = resp.text().await.unwrap_or_default();
266 return Err(anyhow!("openrouter HTTP {status}: {text}"));
267 }
268 let payload: ChatResponse = resp.json().await.context("decode openrouter response")?;
269 payload
270 .choices
271 .into_iter()
272 .next()
273 .map(|c| c.message.content)
274 .ok_or_else(|| anyhow!("openrouter returned no choices"))
275}
276
277pub async fn openrouter_chat_stream(
288 api_key: &str,
289 model: &str,
290 messages: Vec<ChatMessage>,
291 tx: tokio::sync::mpsc::Sender<String>,
292) -> Result<()> {
293 use futures_util::StreamExt;
294
295 if api_key.is_empty() {
296 return Err(anyhow!("openrouter api key is empty"));
297 }
298 let client = reqwest::Client::builder()
299 .connect_timeout(std::time::Duration::from_secs(
300 OPENROUTER_CONNECT_TIMEOUT_SECS,
301 ))
302 .timeout(std::time::Duration::from_secs(
303 OPENROUTER_REQUEST_TIMEOUT_SECS,
304 ))
305 .build()
306 .context("build reqwest client for openrouter_chat_stream")?;
307 let body = ChatRequest {
308 model,
309 messages: &messages,
310 stream: true,
311 };
312 let resp = client
313 .post(OPENROUTER_URL)
314 .bearer_auth(api_key)
315 .header("HTTP-Referer", HTTP_REFERER)
316 .header("X-Title", X_TITLE)
317 .json(&body)
318 .send()
319 .await
320 .context("POST openrouter chat completions (stream)")?;
321 let status = resp.status();
322 if !status.is_success() {
323 let text = resp.text().await.unwrap_or_default();
324 return Err(anyhow!("openrouter HTTP {status}: {text}"));
325 }
326
327 let mut buf = String::new();
328 let mut stream = resp.bytes_stream();
329 while let Some(chunk) = stream.next().await {
330 let bytes = chunk.context("read openrouter stream chunk")?;
331 let text = match std::str::from_utf8(&bytes) {
332 Ok(s) => s,
333 Err(_) => continue,
334 };
335 buf.push_str(text);
336
337 while let Some(idx) = buf.find('\n') {
338 let line: String = buf.drain(..=idx).collect();
339 let line = line.trim();
340 let Some(payload) = line.strip_prefix("data:").map(str::trim) else {
341 continue;
342 };
343 if payload.is_empty() || payload == "[DONE]" {
344 continue;
345 }
346 let v: serde_json::Value = match serde_json::from_str(payload) {
347 Ok(v) => v,
348 Err(_) => continue,
349 };
350 if let Some(delta) = v
351 .get("choices")
352 .and_then(|c| c.get(0))
353 .and_then(|c| c.get("delta"))
354 .and_then(|d| d.get("content"))
355 .and_then(|c| c.as_str())
356 {
357 if !delta.is_empty() && tx.send(delta.to_string()).await.is_err() {
358 return Ok(());
360 }
361 }
362 }
363 }
364 Ok(())
365}
366
/// Report whether `path` names an existing directory.
///
/// Symlinks are followed, because `std::fs::metadata` resolves them. Any error
/// while querying the filesystem (missing path, permission denied, ...) is
/// reported as `false` rather than propagated.
pub fn is_dir(path: &Path) -> bool {
    matches!(std::fs::metadata(path), Ok(meta) if meta.is_dir())
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises tests that rewrite process-wide environment variables.
    ///
    /// `cargo test` runs tests on parallel threads, but `HOME` and
    /// `XDG_DATA_HOME` are shared by the whole process. Without this lock one
    /// test could redirect another's data directory halfway through.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// A throwaway data home that holds `ENV_LOCK` and deletes its directory on drop.
    struct IsolatedHome {
        path: PathBuf,
        // Fields drop after `Drop::drop` runs, so the lock is released only once
        // the directory has been cleaned up.
        _lock: MutexGuard<'static, ()>,
    }

    impl IsolatedHome {
        /// Take the env lock and point `HOME` / `XDG_DATA_HOME` at a fresh temp dir.
        fn new() -> Self {
            // A panicking test poisons the mutex. The `()` it guards cannot be
            // left inconsistent, so recover instead of failing every later test.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let path = tempfile_like_dir();
            // SAFETY: within this test module these variables are only written
            // here, under `ENV_LOCK`. They are only read (through
            // `resolve_data_dir`) by tests that also hold `ENV_LOCK`, so these
            // writes cannot race with those reads.
            // NOTE(review): code outside this module that reads the environment
            // concurrently (e.g. from C via `getenv`) is not covered by the lock.
            unsafe {
                std::env::set_var("HOME", &path);
                std::env::set_var("XDG_DATA_HOME", path.join("share"));
            }
            Self { path, _lock: lock }
        }
    }

    impl Drop for IsolatedHome {
        fn drop(&mut self) {
            // Best effort: a leftover temp dir is harmless.
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    /// An app name unique to this process and moment, so runs never collide.
    fn unique_app_name(prefix: &str) -> String {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        format!("{prefix}-{}-{nanos}", std::process::id())
    }

    #[tokio::test]
    async fn auto_port_walks_forward() {
        // Keep `occupied` alive for the whole test so its port stays taken.
        let occupied = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = occupied.local_addr().unwrap().port();
        let addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap();
        let next = bind_with_auto_port(addr, 8).await.unwrap();
        let got = next.local_addr().unwrap().port();
        assert_ne!(got, port, "expected walk-forward to a different port");
    }

    #[tokio::test]
    async fn auto_port_zero_attempts_still_binds_free() {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let l = bind_with_auto_port(addr, 0).await.unwrap();
        assert!(l.local_addr().unwrap().port() > 0);
    }

    #[test]
    fn resolve_data_dir_creates_directory() {
        let _home = IsolatedHome::new();
        let dir = resolve_data_dir("trusty-test-xyz").unwrap();
        assert!(
            dir.exists(),
            "data dir should be created at {}",
            dir.display()
        );
        assert!(dir.is_dir());
    }

    #[test]
    fn daemon_addr_round_trips() {
        let _home = IsolatedHome::new();
        let app = unique_app_name("trusty-test-daemon");
        write_daemon_addr(&app, "127.0.0.1:12345").unwrap();
        let got = read_daemon_addr(&app).unwrap();
        assert_eq!(got.as_deref(), Some("127.0.0.1:12345"));
    }

    #[test]
    fn read_daemon_addr_missing_returns_none() {
        let _home = IsolatedHome::new();
        let app = unique_app_name("trusty-test-daemon-missing");
        let got = read_daemon_addr(&app).unwrap();
        assert!(got.is_none(), "expected None when file absent, got {got:?}");
    }

    #[test]
    fn is_dir_recognises_directories() {
        let tmp = tempfile_like_dir();
        assert!(is_dir(&tmp));
        assert!(!is_dir(&tmp.join("nope")));
        let _ = std::fs::remove_dir_all(&tmp);
    }

    #[test]
    fn chat_message_round_trips() {
        let m = ChatMessage {
            role: "user".into(),
            content: "hello".into(),
        };
        let s = serde_json::to_string(&m).unwrap();
        let back: ChatMessage = serde_json::from_str(&s).unwrap();
        assert_eq!(back.role, "user");
        assert_eq!(back.content, "hello");
    }

    #[tokio::test]
    async fn openrouter_chat_rejects_empty_key() {
        let err = openrouter_chat("", "x", vec![]).await.unwrap_err();
        assert!(err.to_string().contains("api key"));
    }

    /// Create a fresh, uniquely named directory under the system temp dir.
    ///
    /// The per-process sequence number guarantees uniqueness even when parallel
    /// tests call this within the same clock tick.
    fn tempfile_like_dir() -> PathBuf {
        use std::sync::atomic::{AtomicU64, Ordering};
        static SEQ: AtomicU64 = AtomicU64::new(0);

        let seq = SEQ.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let p = std::env::temp_dir().join(format!("trusty-common-test-{pid}-{nanos}-{seq}"));
        std::fs::create_dir_all(&p).unwrap();
        p
    }
}