1use axum::{
2 extract::State,
3 http::StatusCode,
4 response::IntoResponse,
5 routing::{get, post},
6 Json, Router,
7};
8use clap::Parser;
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15use uuid::Uuid;
16
17use crate::api::{Claude, Session};
18
19#[derive(Parser, Debug, Clone)]
21#[clap(author, version, about = "OpenAI-compatible Claude API server")]
22pub struct Args {
23 #[clap(long, short, default_value = "3000")]
25 pub port: u16,
26
27 #[clap(long)]
29 pub debug: bool,
30
31 #[clap(long, default_value = "claude-3-7-sonnet-latest")]
33 pub model: String,
34
35 #[clap(long, default_value = "0.0.0.0")]
37 pub host: String,
38}
39
40#[derive(Clone)]
42pub struct AppState {
43 pub claude: Arc<Claude>,
44 pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
46}
47
48#[derive(Serialize, Debug)]
50pub struct OpenAIResponse {
51 id: String,
52 object: &'static str,
53 created: u64,
54 model: String,
55 system_fingerprint: String,
56 choices: Vec<Choice>,
57 usage: Usage,
58}
59
60#[derive(Serialize, Debug)]
62pub struct Usage {
63 prompt_tokens: u32,
64 completion_tokens: u32,
65 total_tokens: u32,
66}
67
68#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
70pub struct Message {
71 role: String,
72 content: ContentValue,
73}
74
75#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
77#[serde(untagged)]
78pub enum ContentValue {
79 String(String),
80 Array(Vec<String>),
81 Object(ContentObject),
82 ContentBlocks(Vec<ContentBlock>),
83}
84
85#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
87pub struct ContentObject {
88 #[serde(default)]
89 pub text: Option<String>,
90 #[serde(default)]
91 pub r#type: Option<String>,
92 #[serde(flatten)]
94 pub other: std::collections::HashMap<String, serde_json::Value>,
95}
96
97#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
99pub struct ContentBlock {
100 pub r#type: String,
101 #[serde(default)]
102 pub text: Option<String>,
103 #[serde(flatten)]
105 pub other: std::collections::HashMap<String, serde_json::Value>,
106}
107
108impl Message {
109 pub fn content_as_string(&self) -> String {
111 match &self.content {
112 ContentValue::String(content) => content.clone(),
113 ContentValue::Array(content) => content.join(" "),
114 ContentValue::Object(obj) => {
115 if let Some(text) = &obj.text {
116 text.clone()
117 } else {
118 serde_json::to_string(obj).unwrap_or_default()
120 }
121 }
122 ContentValue::ContentBlocks(blocks) => blocks
123 .iter()
124 .filter_map(|block| block.text.as_ref())
125 .cloned()
126 .collect::<Vec<_>>()
127 .join(" "),
128 }
129 }
130}
131
132#[derive(Debug, Deserialize)]
133pub struct MessageInner {
134 role: String,
135 content: String,
136}
137
138impl From<MessageInner> for Message {
139 fn from(inner: MessageInner) -> Self {
140 Self {
141 role: inner.role,
142 content: ContentValue::String(inner.content),
143 }
144 }
145}
146
147mod content_string_or_array {
152 use serde::{self, Deserialize, Deserializer, Serializer};
153
154 pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
155 where
156 S: Serializer,
157 {
158 serializer.serialize_str(content)
159 }
160
161 pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
162 where
163 D: Deserializer<'de>,
164 {
165 #[derive(Deserialize)]
166 #[serde(untagged)]
167 enum StringOrArray {
168 String(()),
169 Array(()),
170 }
171
172 let value = StringOrArray::deserialize(deserializer)?;
173 match value {
174 StringOrArray::String(_) => Ok(String::new()),
175 StringOrArray::Array(_) => Ok(String::new()),
176 }
177 }
178}
179
180#[derive(Deserialize, Serialize, Debug)]
182pub struct ChatCompletionRequest {
183 model: String,
184 messages: Vec<Message>,
185 #[serde(default)]
186 user: Option<String>, }
188
189#[derive(Serialize, Debug)]
191pub struct Choice {
192 index: i32,
193 message: Message,
194 logprobs: Option<serde_json::Value>,
195 finish_reason: String,
196}
197
198fn get_model(model_name: &str) -> &'static str {
200 match model_name {
201 "gpt-3.5-turbo" => crate::config::HAIKU_MODEL,
202 "gpt-4" | "gpt-4-turbo" => crate::config::SONNET_MODEL,
203 "gpt-4o" | "gpt-4.1" => crate::config::OPUS_MODEL,
204 _ => crate::config::SONNET_MODEL,
205 }
206}
207
/// Translates a client role name into the role label used when building a
/// prompt.
///
/// `assistant` stays `assistant`; `system` and `developer` both become
/// `system`; `user` and every unrecognized role become `human`.
fn _normalize_role(role: &str) -> String {
    let label = match role {
        "assistant" => "assistant",
        "system" | "developer" => "system",
        // "user" and anything unknown are treated as the human speaker.
        _ => "human",
    };
    label.to_string()
}
217
218fn _messages_to_prompt(messages: &[Message]) -> String {
220 let mut prompt = String::new();
221
222 for msg in messages {
223 let role = _normalize_role(&msg.role);
224 let content = msg.content_as_string();
225 match role.as_str() {
226 "system" => prompt.push_str(&format!("System: {}\n\n", content)),
227 "human" => prompt.push_str(&format!("Human: {}\n\n", content)),
228 "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
229 _ => prompt.push_str(&format!("{}: {}\n\n", role, content)),
230 }
231 }
232
233 prompt
234}
235
236fn generate_conversation_key(messages: &[Message]) -> String {
238 use std::collections::hash_map::DefaultHasher;
239 use std::hash::{Hash, Hasher};
240
241 let context_messages = if messages.len() > 1 {
243 &messages[..messages.len() - 1]
244 } else {
245 &[]
246 };
247
248 let mut hasher = DefaultHasher::new();
249 for msg in context_messages {
250 msg.role.hash(&mut hasher);
251 msg.content_as_string().hash(&mut hasher);
252 }
253
254 format!("ctx:{:016x}", hasher.finish())
255}
256
257fn find_matching_conversation(conversations: &DashMap<String, (String, Vec<Message>)>, messages: &[Message]) -> Option<String> {
260 if messages.is_empty() {
261 return None;
262 }
263
264 let context_key = generate_conversation_key(messages);
266
267 if conversations.contains_key(&context_key) {
269 return Some(context_key);
270 }
271
272 let our_context = if messages.len() > 1 {
275 &messages[..messages.len() - 1]
276 } else {
277 &[]
278 };
279
280 for item in conversations.iter() {
281 let (_, stored_history) = &*item.value();
282
283 if stored_history.len() >= our_context.len() {
285 let prefix_matches = our_context.iter().zip(stored_history.iter())
286 .all(|(a, b)| a.role == b.role && a.content_as_string() == b.content_as_string());
287
288 if prefix_matches {
289 return Some(item.key().clone());
290 }
291 }
292 }
293
294 None
295}
296
297async fn chat_completion_handler(
299 State(state): State<AppState>,
300 Json(request): Json<ChatCompletionRequest>,
301) -> impl IntoResponse {
302 for (i, msg) in request.messages.iter().enumerate() {
308 match &msg.content {
309 ContentValue::Array(_) => info!("Message at index {} has array content", i),
310 ContentValue::Object(_) => info!("Message at index {} has object content", i),
311 ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
312 _ => {} }
314 }
315
316 let model_id = get_model(&request.model);
317
318 if request.messages.is_empty() {
319 return (
320 StatusCode::BAD_REQUEST,
321 Json(serde_json::json!({
322 "error": "No messages provided in the request"
323 })),
324 );
325 }
326
327 let messages = request.messages;
328
329 let _request_id = request.user.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
332
333 let user_query = messages.last().unwrap().clone();
335
336 let claude = state.claude;
337 let conversations = state.conversations;
338
339 let mut conversation_id: Option<String> = None;
341 let mut message_history: Vec<Message> = Vec::new();
342
343 info!("Looking for matching conversation by message history");
344 info!("Available conversations: {}", conversations.len());
345
346 let matching_key = find_matching_conversation(&conversations, &messages);
348
349 if let Some(ref key) = matching_key {
350 if let Some(existing) = conversations.get(key) {
351 conversation_id = Some(existing.0.clone());
352 message_history = existing.1.clone();
353 info!("Found conversation with matching history, key: '{}'", key);
354 }
355 } else {
356 info!("No existing conversation with matching history found");
357 }
358
359 if conversation_id.is_none() {
361 match claude.create_chat().await {
362 Ok(new_id) => {
363 conversation_id = Some(new_id.clone());
364 message_history = messages[..messages.len() - 1].to_vec();
365
366 if let Some(system_msg) = messages
368 .iter()
369 .find(|m| m.role == "system" || m.role == "developer")
370 {
371 if let Err(e) = claude
372 .send_message(&new_id, &system_msg.content_as_string(), &[])
373 .await
374 {
375 warn!("Failed to send system prompt: {}", e);
376 }
377 }
378
379 let context_key = generate_conversation_key(&messages);
381 conversations.insert(
382 context_key,
383 (new_id.clone(), message_history.clone()),
384 );
385 }
386 Err(e) => {
387 let error_response = serde_json::json!({
388 "error": format!("Failed to create conversation: {}", e)
389 });
390 dbg!(&error_response);
391 return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
392 }
393 }
394 }
395
396 let completion = match claude
398 .send_message(
399 &conversation_id.clone().unwrap(),
400 &user_query.content_as_string(),
401 &[],
402 )
403 .await
404 {
405 Ok(response) => response,
406 Err(e) => {
407 let error_response = serde_json::json!({
408 "error": format!("Failed to get completion: {}", e)
409 });
410 dbg!(&error_response);
411 return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
412 }
413 };
414
415 let new_message = Message {
417 role: "assistant".to_string(),
418 content: ContentValue::String(completion.clone()),
419 };
420
421 let mut updated_history = message_history.clone();
423 updated_history.push(user_query.clone());
424 updated_history.push(new_message);
425
426 let conv_id = conversation_id.unwrap();
427
428 let mut extended_messages = messages.clone();
430 extended_messages.push(Message {
431 role: "assistant".to_string(),
432 content: ContentValue::String(completion.clone()),
433 });
434 let updated_key = generate_conversation_key(&extended_messages);
435
436 if let Some(ref old_key) = matching_key {
438 if old_key != &updated_key {
439 if let Some(_) = conversations.remove(old_key) {
440 info!("Removed old conversation entry with key: '{}'", old_key);
441 }
442 }
443 }
444
445 conversations.insert(
447 updated_key.clone(),
448 (conv_id.clone(), updated_history.clone()),
449 );
450
451 info!(
453 "Stored conversation with ID: {} under context key '{}'",
454 conv_id,
455 updated_key
456 );
457
458 let response = OpenAIResponse {
460 id: Uuid::new_v4().to_string(),
461 object: "chat.completion",
462 created: chrono::Utc::now().timestamp() as u64,
463 model: model_id.to_string(),
464 system_fingerprint: "fp_toast".to_string(),
465 choices: vec![Choice {
466 index: 0,
467 message: Message {
468 role: "assistant".to_string(),
469 content: ContentValue::String(completion.clone()),
470 },
471 logprobs: None,
472 finish_reason: "stop".to_string(),
473 }],
474 usage: Usage {
475 prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, completion_tokens: completion.chars().count() as u32 / 4, total_tokens: (user_query.content_as_string().chars().count()
478 + completion.chars().count()) as u32
479 / 4,
480 },
481 };
482
483 (
487 StatusCode::OK,
488 Json(serde_json::to_value(response).unwrap()),
489 )
490}
491
492async fn health_check() -> impl IntoResponse {
494 (StatusCode::OK, "OK")
495}
496
497pub async fn run() -> anyhow::Result<()> {
499 run_with_args(Args::parse()).await
501}
502
503pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
505 if args.debug {
509 tracing_subscriber::fmt::init();
511 } else {
512 tracing_subscriber::fmt()
514 .with_max_level(tracing::Level::WARN)
515 .init();
516 }
517
518 let config_dir = dirs::config_dir()
520 .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
521 .join("toast");
522
523 let cookie_path = config_dir.join("cookie");
524 let org_id_path = config_dir.join("org_id");
525
526 if !config_dir.exists() {
528 return Err(anyhow::anyhow!(
529 "Configuration directory does not exist at {:?}",
530 config_dir
531 ));
532 }
533
534 let cookie = if cookie_path.exists() {
536 std::fs::read_to_string(&cookie_path)?.trim().to_string()
537 } else {
538 return Err(anyhow::anyhow!(
539 "Cookie file not found at {:?}",
540 cookie_path
541 ));
542 };
543
544 let org_id = if org_id_path.exists() {
546 std::fs::read_to_string(&org_id_path)?.trim().to_string()
547 } else {
548 if let Some(extracted_org_id) = crate::cli::extract_org_id_from_cookie(&cookie) {
550 std::fs::write(&org_id_path, &extracted_org_id)?;
552 info!(
553 "Extracted organization ID from cookie and saved to {:?}",
554 org_id_path
555 );
556 extracted_org_id
557 } else {
558 return Err(anyhow::anyhow!(
559 "Organization ID file not found at {:?} and couldn't extract it from cookie",
560 org_id_path
561 ));
562 }
563 };
564
565 let user_agent =
566 "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
567 .to_string();
568
569 let session = Session {
570 cookie,
571 user_agent,
572 organization_id: org_id,
573 };
574
575 let model_str = crate::config::SONNET_MODEL;
577
578 let claude = Arc::new(Claude::new(session.clone(), model_str)?);
579
580 let app_state = AppState {
582 claude,
583 conversations: Arc::new(DashMap::new()),
584 };
585
586 let cors = CorsLayer::new()
588 .allow_origin(Any)
589 .allow_methods(Any)
590 .allow_headers(Any);
591
592 let app = Router::new()
594 .route("/health", get(health_check))
595 .route("/v1/chat/completions", post(chat_completion_handler))
596 .route("/", post(chat_completion_handler))
597 .route("/chat/completions", post(chat_completion_handler))
598 .layer(cors)
599 .with_state(app_state);
600
601 let host_parts: Vec<u8> = args
603 .host
604 .split('.')
605 .map(|s| s.parse::<u8>().unwrap_or(0))
606 .collect();
607
608 let host_addr = if host_parts.len() == 4 {
609 [host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
610 } else {
611 [0, 0, 0, 0]
612 };
613
614 let addr = SocketAddr::from((host_addr, args.port));
616 info!("Server listening on {}", addr);
617
618 info!("Example request:");
620 info!(
621 "curl -X POST http://localhost:{}{} \\",
622 args.port, "/v1/chat/completions # Also available at / and /chat/completions"
623 );
624 info!(" -H \"Content-Type: application/json\" \\");
625 info!(" -d '{{");
626 info!(" \"model\": \"gpt-4.1\",");
627 info!(" \"messages\": [");
628 info!(" {{");
629 info!(" \"role\": \"developer\",");
630 info!(" \"content\": \"You are a helpful assistant.\"");
631 info!(" }},");
632 info!(" {{");
633 info!(" \"role\": \"user\",");
634 info!(" \"content\": \"Hello! (Simple string content)\"");
635 info!(" }},");
636 info!(" {{");
637 info!(" \"role\": \"user\",");
638 info!(" \"content\": [\"Array of\", \"string content\"]");
639 info!(" }},");
640 info!(" {{");
641 info!(" \"role\": \"user\",");
642 info!(" \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
643 info!(" }},");
644 info!(" {{");
645 info!(" \"role\": \"user\",");
646 info!(" \"content\": [");
647 info!(" {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
648 info!(" {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
649 info!(" ]");
650 info!(" }}");
651 info!(" ]");
652 info!(" }}'");
653
654 let listener = tokio::net::TcpListener::bind(addr).await?;
655 axum::serve(listener, app).await?;
656
657 Ok(())
658}