1use axum::{
2 extract::State,
3 http::StatusCode,
4 response::IntoResponse,
5 routing::{get, post},
6 Json, Router,
7};
8use clap::Parser;
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15use uuid::Uuid;
16
17use crate::api::{Claude, Session};
18
// Command-line options for the server, parsed with clap's derive API.
//
// Plain `//` comments are used here on purpose: clap's derive macro turns
// `///` doc comments on the struct and its fields into help text, so adding
// them would change the `--help` output.
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about = "OpenAI-compatible Claude API server")]
pub struct Args {
    // TCP port to listen on (`--port` / `-p`); defaults to 3000.
    #[clap(long, short, default_value = "3000")]
    pub port: u16,

    // `--debug` flag. `run_with_args` installs the default tracing subscriber
    // when it is set and caps logging at WARN when it is not.
    #[clap(long)]
    pub debug: bool,

    // `--model` option; defaults to "claude-3-7-sonnet-latest".
    // NOTE(review): nothing in this file reads this field; `run_with_args`
    // builds the client from `crate::config::SONNET_MODEL` instead.
    #[clap(long, default_value = "claude-3-7-sonnet-latest")]
    pub model: String,

    // `--host` option: address the server binds to; defaults to "0.0.0.0".
    #[clap(long, default_value = "0.0.0.0")]
    pub host: String,
}
39
/// State shared by all request handlers (installed via `Router::with_state`).
#[derive(Clone)]
pub struct AppState {
    /// Shared Claude client; the chat handler uses it to create chats and
    /// send messages.
    pub claude: Arc<Claude>,
    /// Known conversations, keyed by the context key produced by
    /// `generate_conversation_key`. Each value is
    /// `(Claude conversation id, message history recorded for it)`.
    pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
}
47
/// Body returned by the chat completion handler on success; serialized to
/// JSON in field declaration order.
#[derive(Serialize, Debug)]
pub struct OpenAIResponse {
    /// Response identifier; the handler fills it with a fresh UUID.
    id: String,
    /// Object type tag; the handler sets it to `"chat.completion"`.
    object: &'static str,
    /// Creation time taken from `chrono::Utc::now().timestamp()`.
    created: u64,
    /// Model name reported to the client (the result of `get_model`).
    model: String,
    /// Fingerprint string; the handler always sends `"fp_toast"`.
    system_fingerprint: String,
    /// Generated completions; the handler always returns exactly one.
    choices: Vec<Choice>,
    /// Token usage figures for the request.
    usage: Usage,
}
59
/// Token usage section of [`OpenAIResponse`].
///
/// The handler does not tokenize anything: it fills these fields with
/// estimates derived from character counts divided by four.
#[derive(Serialize, Debug)]
pub struct Usage {
    /// Estimated tokens in the prompt.
    prompt_tokens: u32,
    /// Estimated tokens in the completion.
    completion_tokens: u32,
    /// Estimated tokens for prompt and completion together.
    total_tokens: u32,
}
67
/// A single chat message, used both for incoming request messages and for
/// the assistant message sent back in a response.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Message {
    /// Speaker role as sent by the client (the code in this file checks for
    /// "system", "developer", "user" and "assistant").
    role: String,
    /// Message payload in any of the shapes accepted by [`ContentValue`].
    content: ContentValue,
}
74
/// The shapes a message's `content` field may take.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that deserializes; reordering the variants would
/// change how input is interpreted.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum ContentValue {
    /// A plain string.
    String(String),
    /// An array of plain strings.
    Array(Vec<String>),
    /// A single JSON object (see [`ContentObject`]).
    Object(ContentObject),
    /// An array of typed content blocks (see [`ContentBlock`]).
    ContentBlocks(Vec<ContentBlock>),
}
84
/// Content supplied as a single JSON object.
///
/// Both named fields are optional, and every other key is kept in `other`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ContentObject {
    /// The object's `text` value, if present.
    #[serde(default)]
    pub text: Option<String>,
    /// The object's `type` value, if present.
    #[serde(default)]
    pub r#type: Option<String>,
    /// All remaining keys of the object, preserved as raw JSON values.
    #[serde(flatten)]
    pub other: std::collections::HashMap<String, serde_json::Value>,
}
96
/// One element of an array-of-blocks message content.
///
/// Unlike [`ContentObject`], the `type` key is required here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ContentBlock {
    /// The block's `type` value (required).
    pub r#type: String,
    /// The block's `text` value, if present.
    #[serde(default)]
    pub text: Option<String>,
    /// All remaining keys of the block, preserved as raw JSON values.
    #[serde(flatten)]
    pub other: std::collections::HashMap<String, serde_json::Value>,
}
107
108impl Message {
109 pub fn content_as_string(&self) -> String {
111 match &self.content {
112 ContentValue::String(content) => content.clone(),
113 ContentValue::Array(content) => content.join(" "),
114 ContentValue::Object(obj) => {
115 if let Some(text) = &obj.text {
116 text.clone()
117 } else {
118 serde_json::to_string(obj).unwrap_or_default()
120 }
121 }
122 ContentValue::ContentBlocks(blocks) => blocks
123 .iter()
124 .filter_map(|block| block.text.as_ref())
125 .cloned()
126 .collect::<Vec<_>>()
127 .join(" "),
128 }
129 }
130}
131
/// A message whose content must be a plain string.
///
/// Converts into [`Message`] through the `From` impl below; no other use of
/// this type is visible in this file.
#[derive(Debug, Deserialize)]
pub struct MessageInner {
    /// Speaker role.
    role: String,
    /// Plain-text content.
    content: String,
}
137
138impl From<MessageInner> for Message {
139 fn from(inner: MessageInner) -> Self {
140 Self {
141 role: inner.role,
142 content: ContentValue::String(inner.content),
143 }
144 }
145}
146
147mod content_string_or_array {
151 use serde::{self, Deserialize, Deserializer, Serializer};
152
153 pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
154 where
155 S: Serializer,
156 {
157 serializer.serialize_str(content)
158 }
159
160 pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
161 where
162 D: Deserializer<'de>,
163 {
164 #[derive(Deserialize)]
165 #[serde(untagged)]
166 enum StringOrArray {
167 String(()),
168 Array(()),
169 }
170
171 let value = StringOrArray::deserialize(deserializer)?;
172 match value {
173 StringOrArray::String(_) => Ok(String::new()),
174 StringOrArray::Array(_) => Ok(String::new()),
175 }
176 }
177}
178
/// JSON body accepted by the chat completion endpoints.
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionRequest {
    /// Model name requested by the client; mapped through `get_model`.
    model: String,
    /// The conversation so far; the final entry is the message to answer.
    messages: Vec<Message>,
    /// Optional `user` value supplied by the client; `None` when the key is
    /// absent.
    #[serde(default)]
    user: Option<String>,
}
187
/// One completion inside [`OpenAIResponse::choices`].
#[derive(Serialize, Debug)]
pub struct Choice {
    /// Position of this choice in the list; the handler always sends 0.
    index: i32,
    /// The assistant message produced for the request.
    message: Message,
    /// Log-probability data; the handler always sends `None` (JSON `null`).
    logprobs: Option<serde_json::Value>,
    /// Why generation ended; the handler always sends `"stop"`.
    finish_reason: String,
}
196
197fn get_model(model_name: &str) -> &'static str {
199 match model_name {
200 "gpt-3.5-turbo" => crate::config::HAIKU_MODEL,
201 "gpt-4" | "gpt-4-turbo" => crate::config::SONNET_MODEL,
202 "gpt-4o" | "gpt-4.1" => crate::config::OPUS_MODEL,
203 _ => crate::config::SONNET_MODEL,
204 }
205}
206
/// Translates an OpenAI role name into the role label used when building a
/// prompt.
///
/// `"assistant"` stays `"assistant"`, `"system"` and `"developer"` become
/// `"system"`, and everything else (including `"user"`) becomes `"human"`.
fn _normalize_role(role: &str) -> String {
    let canonical = match role {
        "assistant" => "assistant",
        "system" | "developer" => "system",
        // "user" and any unrecognized role are treated as the human speaker.
        _ => "human",
    };
    canonical.to_owned()
}
216
217fn _messages_to_prompt(messages: &[Message]) -> String {
219 let mut prompt = String::new();
220
221 for msg in messages {
222 let role = _normalize_role(&msg.role);
223 let content = msg.content_as_string();
224 match role.as_str() {
225 "system" => prompt.push_str(&format!("System: {content}\n\n")),
226 "human" => prompt.push_str(&format!("Human: {content}\n\n")),
227 "assistant" => prompt.push_str(&format!("Assistant: {content}\n\n")),
228 _ => prompt.push_str(&format!("{role}: {content}\n\n")),
229 }
230 }
231
232 prompt
233}
234
235fn generate_conversation_key(messages: &[Message]) -> String {
237 use std::collections::hash_map::DefaultHasher;
238 use std::hash::{Hash, Hasher};
239
240 let context_messages = if messages.len() > 1 {
242 &messages[..messages.len() - 1]
243 } else {
244 &[]
245 };
246
247 let mut hasher = DefaultHasher::new();
248 for msg in context_messages {
249 msg.role.hash(&mut hasher);
250 msg.content_as_string().hash(&mut hasher);
251 }
252
253 format!("ctx:{:016x}", hasher.finish())
254}
255
256fn find_matching_conversation(conversations: &DashMap<String, (String, Vec<Message>)>, messages: &[Message]) -> Option<String> {
259 if messages.is_empty() {
260 return None;
261 }
262
263 let context_key = generate_conversation_key(messages);
265
266 if conversations.contains_key(&context_key) {
268 return Some(context_key);
269 }
270
271 let our_context = if messages.len() > 1 {
274 &messages[..messages.len() - 1]
275 } else {
276 &[]
277 };
278
279 for item in conversations.iter() {
280 let (_, stored_history) = item.value();
281
282 if stored_history.len() >= our_context.len() {
284 let prefix_matches = our_context.iter().zip(stored_history.iter())
285 .all(|(a, b)| a.role == b.role && a.content_as_string() == b.content_as_string());
286
287 if prefix_matches {
288 return Some(item.key().clone());
289 }
290 }
291 }
292
293 None
294}
295
296async fn chat_completion_handler(
298 State(state): State<AppState>,
299 Json(request): Json<ChatCompletionRequest>,
300) -> impl IntoResponse {
301 for (i, msg) in request.messages.iter().enumerate() {
307 match &msg.content {
308 ContentValue::Array(_) => info!("Message at index {} has array content", i),
309 ContentValue::Object(_) => info!("Message at index {} has object content", i),
310 ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
311 _ => {} }
313 }
314
315 let model_id = get_model(&request.model);
316
317 if request.messages.is_empty() {
318 return (
319 StatusCode::BAD_REQUEST,
320 Json(serde_json::json!({
321 "error": "No messages provided in the request"
322 })),
323 );
324 }
325
326 let messages = request.messages;
327
328 let _request_id = request.user.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
331
332 let user_query = messages.last().unwrap().clone();
334
335 let claude = state.claude;
336 let conversations = state.conversations;
337
338 let mut conversation_id: Option<String> = None;
340 let mut message_history: Vec<Message> = Vec::new();
341
342 info!("Looking for matching conversation by message history");
343 info!("Available conversations: {}", conversations.len());
344
345 let matching_key = find_matching_conversation(&conversations, &messages);
347
348 if let Some(ref key) = matching_key {
349 if let Some(existing) = conversations.get(key) {
350 conversation_id = Some(existing.0.clone());
351 message_history = existing.1.clone();
352 info!("Found conversation with matching history, key: '{}'", key);
353 }
354 } else {
355 info!("No existing conversation with matching history found");
356 }
357
358 if conversation_id.is_none() {
360 match claude.create_chat().await {
361 Ok(new_id) => {
362 conversation_id = Some(new_id.clone());
363 message_history = messages[..messages.len() - 1].to_vec();
364
365 if let Some(system_msg) = messages
367 .iter()
368 .find(|m| m.role == "system" || m.role == "developer")
369 {
370 if let Err(e) = claude
371 .send_message(&new_id, &system_msg.content_as_string(), &[])
372 .await
373 {
374 warn!("Failed to send system prompt: {}", e);
375 }
376 }
377
378 let context_key = generate_conversation_key(&messages);
380 conversations.insert(
381 context_key,
382 (new_id.clone(), message_history.clone()),
383 );
384 }
385 Err(e) => {
386 let error_response = serde_json::json!({
387 "error": format!("Failed to create conversation: {}", e)
388 });
389 dbg!(&error_response);
390 return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
391 }
392 }
393 }
394
395 let completion = match claude
397 .send_message(
398 &conversation_id.clone().unwrap(),
399 &user_query.content_as_string(),
400 &[],
401 )
402 .await
403 {
404 Ok(response) => response,
405 Err(e) => {
406 let error_response = serde_json::json!({
407 "error": format!("Failed to get completion: {}", e)
408 });
409 dbg!(&error_response);
410 return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
411 }
412 };
413
414 let new_message = Message {
416 role: "assistant".to_string(),
417 content: ContentValue::String(completion.clone()),
418 };
419
420 let mut updated_history = message_history.clone();
422 updated_history.push(user_query.clone());
423 updated_history.push(new_message);
424
425 let conv_id = conversation_id.unwrap();
426
427 let mut extended_messages = messages.clone();
429 extended_messages.push(Message {
430 role: "assistant".to_string(),
431 content: ContentValue::String(completion.clone()),
432 });
433 let updated_key = generate_conversation_key(&extended_messages);
434
435 if let Some(ref old_key) = matching_key {
437 if old_key != &updated_key
438 && conversations.remove(old_key).is_some() {
439 info!("Removed old conversation entry with key: '{}'", old_key);
440 }
441 }
442
443 conversations.insert(
445 updated_key.clone(),
446 (conv_id.clone(), updated_history.clone()),
447 );
448
449 info!(
451 "Stored conversation with ID: {} under context key '{}'",
452 conv_id,
453 updated_key
454 );
455
456 let response = OpenAIResponse {
458 id: Uuid::new_v4().to_string(),
459 object: "chat.completion",
460 created: chrono::Utc::now().timestamp() as u64,
461 model: model_id.to_string(),
462 system_fingerprint: "fp_toast".to_string(),
463 choices: vec![Choice {
464 index: 0,
465 message: Message {
466 role: "assistant".to_string(),
467 content: ContentValue::String(completion.clone()),
468 },
469 logprobs: None,
470 finish_reason: "stop".to_string(),
471 }],
472 usage: Usage {
473 prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, completion_tokens: completion.chars().count() as u32 / 4, total_tokens: (user_query.content_as_string().chars().count()
476 + completion.chars().count()) as u32
477 / 4,
478 },
479 };
480
481 (
485 StatusCode::OK,
486 Json(serde_json::to_value(response).unwrap()),
487 )
488}
489
490async fn health_check() -> impl IntoResponse {
492 (StatusCode::OK, "OK")
493}
494
495pub async fn run() -> anyhow::Result<()> {
497 run_with_args(Args::parse()).await
499}
500
501pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
503 if args.debug {
507 tracing_subscriber::fmt::init();
509 } else {
510 tracing_subscriber::fmt()
512 .with_max_level(tracing::Level::WARN)
513 .init();
514 }
515
516 let config_dir = dirs::config_dir()
518 .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
519 .join("toast");
520
521 let cookie_path = config_dir.join("cookie");
522 let org_id_path = config_dir.join("org_id");
523
524 if !config_dir.exists() {
526 return Err(anyhow::anyhow!(
527 "Configuration directory does not exist at {:?}",
528 config_dir
529 ));
530 }
531
532 let cookie = if cookie_path.exists() {
534 std::fs::read_to_string(&cookie_path)?.trim().to_string()
535 } else {
536 return Err(anyhow::anyhow!(
537 "Cookie file not found at {:?}",
538 cookie_path
539 ));
540 };
541
542 let org_id = if org_id_path.exists() {
544 std::fs::read_to_string(&org_id_path)?.trim().to_string()
545 } else {
546 if let Some(extracted_org_id) = crate::utils::extract_org_id_from_cookie(&cookie) {
548 std::fs::write(&org_id_path, &extracted_org_id)?;
550 info!(
551 "Extracted organization ID from cookie and saved to {:?}",
552 org_id_path
553 );
554 extracted_org_id
555 } else {
556 return Err(anyhow::anyhow!(
557 "Organization ID file not found at {:?} and couldn't extract it from cookie",
558 org_id_path
559 ));
560 }
561 };
562
563 let user_agent =
564 "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
565 .to_string();
566
567 let session = Session {
568 cookie,
569 user_agent,
570 organization_id: org_id,
571 };
572
573 let model_str = crate::config::SONNET_MODEL;
575
576 let claude = Arc::new(Claude::new(session.clone(), model_str)?);
577
578 let app_state = AppState {
580 claude,
581 conversations: Arc::new(DashMap::new()),
582 };
583
584 let cors = CorsLayer::new()
586 .allow_origin(Any)
587 .allow_methods(Any)
588 .allow_headers(Any);
589
590 let app = Router::new()
592 .route("/health", get(health_check))
593 .route("/v1/chat/completions", post(chat_completion_handler))
594 .route("/", post(chat_completion_handler))
595 .route("/chat/completions", post(chat_completion_handler))
596 .layer(cors)
597 .with_state(app_state);
598
599 let host_parts: Vec<u8> = args
601 .host
602 .split('.')
603 .map(|s| s.parse::<u8>().unwrap_or(0))
604 .collect();
605
606 let host_addr = if host_parts.len() == 4 {
607 [host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
608 } else {
609 [0, 0, 0, 0]
610 };
611
612 let addr = SocketAddr::from((host_addr, args.port));
614 info!("Server listening on {}", addr);
615
616 info!("Example request:");
618 info!(
619 "curl -X POST http://localhost:{}{} \\",
620 args.port, "/v1/chat/completions # Also available at / and /chat/completions"
621 );
622 info!(" -H \"Content-Type: application/json\" \\");
623 info!(" -d '{{");
624 info!(" \"model\": \"gpt-4.1\",");
625 info!(" \"messages\": [");
626 info!(" {{");
627 info!(" \"role\": \"developer\",");
628 info!(" \"content\": \"You are a helpful assistant.\"");
629 info!(" }},");
630 info!(" {{");
631 info!(" \"role\": \"user\",");
632 info!(" \"content\": \"Hello! (Simple string content)\"");
633 info!(" }},");
634 info!(" {{");
635 info!(" \"role\": \"user\",");
636 info!(" \"content\": [\"Array of\", \"string content\"]");
637 info!(" }},");
638 info!(" {{");
639 info!(" \"role\": \"user\",");
640 info!(" \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
641 info!(" }},");
642 info!(" {{");
643 info!(" \"role\": \"user\",");
644 info!(" \"content\": [");
645 info!(" {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
646 info!(" {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
647 info!(" ]");
648 info!(" }}");
649 info!(" ]");
650 info!(" }}'");
651
652 let listener = tokio::net::TcpListener::bind(addr).await?;
653 axum::serve(listener, app).await?;
654
655 Ok(())
656}