1use axum::{
2 extract::State,
3 http::StatusCode,
4 response::IntoResponse,
5 routing::{get, post},
6 Json, Router,
7};
8use clap::Parser;
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15use uuid::Uuid;
16
17use crate::api::{Claude, Session};
18
// Command-line options for the server, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used in this struct on purpose. clap's derive
// macro turns `///` doc comments into help text, so doc comments here would
// change the CLI output.
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about = "OpenAI-compatible Claude API server")]
pub struct Args {
    // TCP port the HTTP server listens on (`--port` / `-p`).
    #[clap(long, short, default_value = "3000")]
    pub port: u16,

    // `--debug`: when set, `run_with_args` installs the default tracing
    // subscriber instead of one capped at WARN.
    #[clap(long)]
    pub debug: bool,

    // `--model`: model name handed to `Claude::new` at startup.
    #[clap(long, default_value = "claude-3-7-sonnet-latest")]
    pub model: String,

    // `--host`: address the server binds to (interpreted by `run_with_args`).
    #[clap(long, default_value = "0.0.0.0")]
    pub host: String,
}
39
/// Shared state injected into every axum handler.
///
/// Both fields are behind `Arc`, so cloning the state per request is cheap.
#[derive(Clone)]
pub struct AppState {
    /// Client used to create chats and send messages upstream.
    pub claude: Arc<Claude>,
    /// Known conversations, keyed by the string produced by `conversation_key`.
    /// Each value is `(conversation id, message history)`.
    pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
}
47
/// Response body returned by the chat-completion endpoints.
///
/// Fields are serialized in declaration order.
#[derive(Serialize, Debug)]
pub struct OpenAIResponse {
    /// Identifier for this response; the handler fills it with a fresh UUID.
    id: String,
    /// Object type tag; the handler sets it to `"chat.completion"`.
    object: &'static str,
    /// Creation time as a Unix timestamp in seconds.
    created: u64,
    /// Model name reported back to the client (the result of `get_model`).
    model: String,
    /// Fingerprint string; the handler uses a fixed placeholder value.
    system_fingerprint: String,
    /// Completion choices; the handler always returns exactly one.
    choices: Vec<Choice>,
    /// Token accounting for the request.
    usage: Usage,
}
59
/// Token usage section of a chat-completion response.
///
/// The handler fills these with estimates derived from character counts
/// (characters divided by four), not with real tokenizer output.
#[derive(Serialize, Debug)]
pub struct Usage {
    /// Estimated tokens in the prompt.
    prompt_tokens: u32,
    /// Estimated tokens in the completion.
    completion_tokens: u32,
    /// Estimated tokens for prompt and completion together.
    total_tokens: u32,
}
67
/// A single chat message, used both for incoming requests and for responses.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Message {
    /// Role of the sender as sent by the client (e.g. `"user"`, `"assistant"`,
    /// `"system"`, `"developer"`); stored verbatim.
    role: String,
    /// Message payload in any of the shapes accepted by `ContentValue`.
    content: ContentValue,
}
74
/// The shapes a message's `content` field may take.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that deserializes successfully; the order below is
/// therefore significant.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum ContentValue {
    /// A plain string.
    String(String),
    /// An array of plain strings.
    Array(Vec<String>),
    /// A single JSON object (see `ContentObject`).
    Object(ContentObject),
    /// An array of JSON objects (see `ContentBlock`).
    ContentBlocks(Vec<ContentBlock>),
}
84
/// A single-object message content, e.g. `{"type": "text", "text": "..."}`.
///
/// Both named fields are optional, so any JSON object deserializes into this
/// type.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ContentObject {
    /// Text payload, if the object has a `text` key.
    #[serde(default)]
    pub text: Option<String>,
    /// Value of the object's `type` key, if present.
    #[serde(default)]
    pub r#type: Option<String>,
    /// Every other key of the object, kept verbatim via `flatten`.
    #[serde(flatten)]
    pub other: std::collections::HashMap<String, serde_json::Value>,
}
96
/// One element of an array-of-objects message content.
///
/// Unlike `ContentObject`, the `type` key is required here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ContentBlock {
    /// Value of the block's mandatory `type` key.
    pub r#type: String,
    /// Text payload, if the block has a `text` key.
    #[serde(default)]
    pub text: Option<String>,
    /// Every other key of the block, kept verbatim via `flatten`.
    #[serde(flatten)]
    pub other: std::collections::HashMap<String, serde_json::Value>,
}
107
108impl Message {
109 pub fn content_as_string(&self) -> String {
111 match &self.content {
112 ContentValue::String(content) => content.clone(),
113 ContentValue::Array(content) => content.join(" "),
114 ContentValue::Object(obj) => {
115 if let Some(text) = &obj.text {
116 text.clone()
117 } else {
118 serde_json::to_string(obj).unwrap_or_default()
120 }
121 }
122 ContentValue::ContentBlocks(blocks) => blocks
123 .iter()
124 .filter_map(|block| block.text.as_ref())
125 .cloned()
126 .collect::<Vec<_>>()
127 .join(" "),
128 }
129 }
130}
131
/// A chat message whose content is always a plain string.
///
/// Convertible into `Message` through the `From` implementation in this file.
#[derive(Debug, Deserialize)]
pub struct MessageInner {
    /// Role of the sender.
    role: String,
    /// Plain-text content of the message.
    content: String,
}
137
138impl From<MessageInner> for Message {
139 fn from(inner: MessageInner) -> Self {
140 Self {
141 role: inner.role,
142 content: ContentValue::String(inner.content),
143 }
144 }
145}
146
147mod content_string_or_array {
152 use serde::{self, Deserialize, Deserializer, Serializer};
153
154 pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
155 where
156 S: Serializer,
157 {
158 serializer.serialize_str(content)
159 }
160
161 pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
162 where
163 D: Deserializer<'de>,
164 {
165 #[derive(Deserialize)]
166 #[serde(untagged)]
167 enum StringOrArray {
168 String(()),
169 Array(()),
170 }
171
172 let value = StringOrArray::deserialize(deserializer)?;
173 match value {
174 StringOrArray::String(_) => Ok(String::new()),
175 StringOrArray::Array(_) => Ok(String::new()),
176 }
177 }
178}
179
/// Request body accepted by the chat-completion endpoints.
///
/// Only the fields below are declared; no `deny_unknown_fields` is set, so
/// other keys in the incoming JSON are ignored by serde.
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionRequest {
    /// Model name requested by the client; translated by `get_model`.
    model: String,
    /// Conversation so far; the last entry is treated as the new query.
    messages: Vec<Message>,
}
186
/// One completion choice inside an `OpenAIResponse`.
#[derive(Serialize, Debug)]
pub struct Choice {
    /// Position of this choice in the `choices` array.
    index: i32,
    /// The generated assistant message.
    message: Message,
    /// Log-probability information; the handler always sets this to `None`.
    logprobs: Option<serde_json::Value>,
    /// Why generation stopped; the handler always reports `"stop"`.
    finish_reason: String,
}
195
196fn get_model(model_name: &str) -> &'static str {
198 match model_name {
199 "gpt-3.5-turbo" => crate::config::HAIKU_MODEL,
200 "gpt-4" | "gpt-4-turbo" => crate::config::SONNET_MODEL,
201 "gpt-4o" | "gpt-4.1" => crate::config::OPUS_MODEL,
202 _ => crate::config::SONNET_MODEL,
203 }
204}
205
/// Maps an OpenAI-style role name onto the role label used when building a
/// prompt.
///
/// `"assistant"` stays `"assistant"`, `"system"` and `"developer"` become
/// `"system"`, and every other role (including `"user"`) becomes `"human"`.
fn _normalize_role(role: &str) -> String {
    let normalized = match role {
        "assistant" => "assistant",
        "system" | "developer" => "system",
        // "user" and anything unrecognised are treated as the human turn.
        _ => "human",
    };
    normalized.to_owned()
}
215
216fn _messages_to_prompt(messages: &[Message]) -> String {
218 let mut prompt = String::new();
219
220 for msg in messages {
221 let role = _normalize_role(&msg.role);
222 let content = msg.content_as_string();
223 match role.as_str() {
224 "system" => prompt.push_str(&format!("System: {}\n\n", content)),
225 "human" => prompt.push_str(&format!("Human: {}\n\n", content)),
226 "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
227 _ => prompt.push_str(&format!("{}: {}\n\n", role, content)),
228 }
229 }
230
231 prompt
232}
233
234fn conversation_key(messages: &[Message]) -> String {
237 messages
239 .iter()
240 .find(|m| m.role == "user")
241 .map(|m| m.content_as_string())
242 .unwrap_or_default()
243}
244
245async fn chat_completion_handler(
247 State(state): State<AppState>,
248 Json(request): Json<ChatCompletionRequest>,
249) -> impl IntoResponse {
250 for (i, msg) in request.messages.iter().enumerate() {
256 match &msg.content {
257 ContentValue::Array(_) => info!("Message at index {} has array content", i),
258 ContentValue::Object(_) => info!("Message at index {} has object content", i),
259 ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
260 _ => {} }
262 }
263
264 let model_id = get_model(&request.model);
265
266 if request.messages.is_empty() {
267 return (
268 StatusCode::BAD_REQUEST,
269 Json(serde_json::json!({
270 "error": "No messages provided in the request"
271 })),
272 );
273 }
274
275 let messages = request.messages;
276 let conv_key = conversation_key(&messages);
277
278 let user_query = messages.last().unwrap().clone();
280
281 let claude = state.claude;
282 let conversations = state.conversations;
283
284 let mut conversation_id: Option<String> = None;
286 let mut message_history: Vec<Message> = Vec::new();
287
288 info!("Looking for conversation key: {}", conv_key);
289 info!("Available conversation keys:");
290 for item in conversations.iter() {
291 info!(" Key: '{}', {} chars", item.key(), item.key().len());
292 }
293
294 if let Some(existing) = conversations.get(&conv_key) {
295 conversation_id = Some(existing.0.clone());
296 message_history = existing.1.clone();
297 info!("Found conversation for key: '{}'", conv_key);
298 } else {
299 info!("No existing conversation found for key: '{}'", conv_key);
300 }
301
302 if conversation_id.is_none() {
304 match claude.create_chat().await {
305 Ok(new_id) => {
306 conversation_id = Some(new_id);
307 message_history = messages[..messages.len() - 1].to_vec();
308
309 conversations.insert(
311 conv_key.clone(),
312 (conversation_id.clone().unwrap(), message_history.clone()),
313 );
314
315 if let Some(system_msg) = messages
317 .iter()
318 .find(|m| m.role == "system" || m.role == "developer")
319 {
320 if let Err(e) = claude
321 .send_message(
322 &conversation_id.clone().unwrap(),
323 &system_msg.content_as_string(),
324 &[],
325 )
326 .await
327 {
328 warn!("Failed to send system prompt: {}", e);
329 }
330 }
331 }
332 Err(e) => {
333 let error_response = serde_json::json!({
334 "error": format!("Failed to create conversation: {}", e)
335 });
336 dbg!(&error_response);
337 return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
338 }
339 }
340 }
341
342 let completion = match claude
344 .send_message(
345 &conversation_id.clone().unwrap(),
346 &user_query.content_as_string(),
347 &[],
348 )
349 .await
350 {
351 Ok(response) => response,
352 Err(e) => {
353 let error_response = serde_json::json!({
354 "error": format!("Failed to get completion: {}", e)
355 });
356 dbg!(&error_response);
357 return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
358 }
359 };
360
361 let new_message = Message {
363 role: "assistant".to_string(),
364 content: ContentValue::String(completion.clone()),
365 };
366
367 let mut updated_history = message_history.clone();
369 updated_history.push(user_query.clone());
370 updated_history.push(new_message);
371
372 if let Some(_) = conversations.remove(&conv_key) {
374 info!("Removed old conversation entry with key: '{}'", conv_key);
375 }
376
377 let new_conv_key = conversation_key(&updated_history);
379 let conv_id = conversation_id.unwrap();
380
381 conversations.insert(
383 new_conv_key.clone(),
384 (conv_id.clone(), updated_history.clone()),
385 );
386 info!(
388 "Stored conversation with ID: {} under new key '{}', {} chars",
389 conv_id,
390 new_conv_key,
391 new_conv_key.len()
392 );
393
394 let response = OpenAIResponse {
396 id: Uuid::new_v4().to_string(),
397 object: "chat.completion",
398 created: chrono::Utc::now().timestamp() as u64,
399 model: model_id.to_string(),
400 system_fingerprint: "fp_toast".to_string(),
401 choices: vec![Choice {
402 index: 0,
403 message: Message {
404 role: "assistant".to_string(),
405 content: ContentValue::String(completion.clone()),
406 },
407 logprobs: None,
408 finish_reason: "stop".to_string(),
409 }],
410 usage: Usage {
411 prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, completion_tokens: completion.chars().count() as u32 / 4, total_tokens: (user_query.content_as_string().chars().count()
414 + completion.chars().count()) as u32
415 / 4,
416 },
417 };
418
419 (
423 StatusCode::OK,
424 Json(serde_json::to_value(response).unwrap()),
425 )
426}
427
428async fn health_check() -> impl IntoResponse {
430 (StatusCode::OK, "OK")
431}
432
433pub async fn run() -> anyhow::Result<()> {
435 run_with_args(Args::parse()).await
437}
438
439pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
441 if args.debug {
445 tracing_subscriber::fmt::init();
447 } else {
448 tracing_subscriber::fmt()
450 .with_max_level(tracing::Level::WARN)
451 .init();
452 }
453
454 let config_dir = dirs::config_dir()
456 .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
457 .join("toast");
458
459 let cookie_path = config_dir.join("cookie");
460 let org_id_path = config_dir.join("org_id");
461
462 if !config_dir.exists() {
464 return Err(anyhow::anyhow!(
465 "Configuration directory does not exist at {:?}",
466 config_dir
467 ));
468 }
469
470 let cookie = if cookie_path.exists() {
472 std::fs::read_to_string(&cookie_path)?.trim().to_string()
473 } else {
474 return Err(anyhow::anyhow!(
475 "Cookie file not found at {:?}",
476 cookie_path
477 ));
478 };
479
480 let org_id = if org_id_path.exists() {
482 std::fs::read_to_string(&org_id_path)?.trim().to_string()
483 } else {
484 if let Some(extracted_org_id) = crate::cli::extract_org_id_from_cookie(&cookie) {
486 std::fs::write(&org_id_path, &extracted_org_id)?;
488 info!(
489 "Extracted organization ID from cookie and saved to {:?}",
490 org_id_path
491 );
492 extracted_org_id
493 } else {
494 return Err(anyhow::anyhow!(
495 "Organization ID file not found at {:?} and couldn't extract it from cookie",
496 org_id_path
497 ));
498 }
499 };
500
501 let user_agent =
502 "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
503 .to_string();
504
505 let session = Session {
506 cookie,
507 user_agent,
508 organization_id: org_id,
509 };
510
511 let model_str = Box::leak(args.model.into_boxed_str());
513
514 let claude = Arc::new(Claude::new(session.clone(), model_str)?);
515
516 let app_state = AppState {
518 claude,
519 conversations: Arc::new(DashMap::new()),
520 };
521
522 let cors = CorsLayer::new()
524 .allow_origin(Any)
525 .allow_methods(Any)
526 .allow_headers(Any);
527
528 let app = Router::new()
530 .route("/health", get(health_check))
531 .route("/v1/chat/completions", post(chat_completion_handler))
532 .route("/", post(chat_completion_handler))
533 .route("/chat/completions", post(chat_completion_handler))
534 .layer(cors)
535 .with_state(app_state);
536
537 let host_parts: Vec<u8> = args
539 .host
540 .split('.')
541 .map(|s| s.parse::<u8>().unwrap_or(0))
542 .collect();
543
544 let host_addr = if host_parts.len() == 4 {
545 [host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
546 } else {
547 [0, 0, 0, 0]
548 };
549
550 let addr = SocketAddr::from((host_addr, args.port));
552 info!("Server listening on {}", addr);
553
554 info!("Example request:");
556 info!(
557 "curl -X POST http://localhost:{}{} \\",
558 args.port, "/v1/chat/completions # Also available at / and /chat/completions"
559 );
560 info!(" -H \"Content-Type: application/json\" \\");
561 info!(" -d '{{");
562 info!(" \"model\": \"gpt-4.1\",");
563 info!(" \"messages\": [");
564 info!(" {{");
565 info!(" \"role\": \"developer\",");
566 info!(" \"content\": \"You are a helpful assistant.\"");
567 info!(" }},");
568 info!(" {{");
569 info!(" \"role\": \"user\",");
570 info!(" \"content\": \"Hello! (Simple string content)\"");
571 info!(" }},");
572 info!(" {{");
573 info!(" \"role\": \"user\",");
574 info!(" \"content\": [\"Array of\", \"string content\"]");
575 info!(" }},");
576 info!(" {{");
577 info!(" \"role\": \"user\",");
578 info!(" \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
579 info!(" }},");
580 info!(" {{");
581 info!(" \"role\": \"user\",");
582 info!(" \"content\": [");
583 info!(" {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
584 info!(" {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
585 info!(" ]");
586 info!(" }}");
587 info!(" ]");
588 info!(" }}'");
589
590 let listener = tokio::net::TcpListener::bind(addr).await?;
591 axum::serve(listener, app).await?;
592
593 Ok(())
594}