toast_api/
server.rs

1use axum::{
2    extract::State,
3    http::StatusCode,
4    response::IntoResponse,
5    routing::{get, post},
6    Json, Router,
7};
8use clap::Parser;
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15use uuid::Uuid;
16
17use crate::api::{Claude, Session};
18
19/// Command line arguments for server configuration
20#[derive(Parser, Debug, Clone)]
21#[clap(author, version, about = "OpenAI-compatible Claude API server")]
22pub struct Args {
23    /// Port to listen on - public field
24    #[clap(long, short, default_value = "3000")]
25    pub port: u16,
26
27    /// Enable debug logging
28    #[clap(long)]
29    pub debug: bool,
30
31    /// Default Claude model to use
32    #[clap(long, default_value = "claude-3-7-sonnet-latest")]
33    pub model: String,
34
35    /// Host address to bind to
36    #[clap(long, default_value = "0.0.0.0")]
37    pub host: String,
38}
39
40/// Server state shared across all requests
41#[derive(Clone)]
42pub struct AppState {
43    pub claude: Arc<Claude>,
44    // Map to track conversations: Vec<Message> -> (conversation_id, Vec<Message>)
45    pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
46}
47
48/// Base response structure for OpenAI-compatible API
49#[derive(Serialize, Debug)]
50pub struct OpenAIResponse {
51    id: String,
52    object: &'static str,
53    created: u64,
54    model: String,
55    system_fingerprint: String,
56    choices: Vec<Choice>,
57    usage: Usage,
58}
59
60/// Token usage information
61#[derive(Serialize, Debug)]
62pub struct Usage {
63    prompt_tokens: u32,
64    completion_tokens: u32,
65    total_tokens: u32,
66}
67
68/// Chat message structure (OpenAI compatible)
69#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
70pub struct Message {
71    role: String,
72    content: ContentValue,
73}
74
75/// Content value can be string, array of strings, or structured content
76#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
77#[serde(untagged)]
78pub enum ContentValue {
79    String(String),
80    Array(Vec<String>),
81    Object(ContentObject),
82    ContentBlocks(Vec<ContentBlock>),
83}
84
85/// Content object with type and other fields
86#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
87pub struct ContentObject {
88    #[serde(default)]
89    pub text: Option<String>,
90    #[serde(default)]
91    pub r#type: Option<String>,
92    // Allow for other fields
93    #[serde(flatten)]
94    pub other: std::collections::HashMap<String, serde_json::Value>,
95}
96
97/// Structured content block
98#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
99pub struct ContentBlock {
100    pub r#type: String,
101    #[serde(default)]
102    pub text: Option<String>,
103    // Allow for other fields
104    #[serde(flatten)]
105    pub other: std::collections::HashMap<String, serde_json::Value>,
106}
107
108impl Message {
109    // Helper to get content as a string regardless of format
110    pub fn content_as_string(&self) -> String {
111        match &self.content {
112            ContentValue::String(content) => content.clone(),
113            ContentValue::Array(content) => content.join(" "),
114            ContentValue::Object(obj) => {
115                if let Some(text) = &obj.text {
116                    text.clone()
117                } else {
118                    // Try to convert to JSON string if no text field
119                    serde_json::to_string(obj).unwrap_or_default()
120                }
121            }
122            ContentValue::ContentBlocks(blocks) => blocks
123                .iter()
124                .filter_map(|block| block.text.as_ref())
125                .cloned()
126                .collect::<Vec<_>>()
127                .join(" "),
128        }
129    }
130}
131
132#[derive(Debug, Deserialize)]
133pub struct MessageInner {
134    role: String,
135    content: String,
136}
137
138impl From<MessageInner> for Message {
139    fn from(inner: MessageInner) -> Self {
140        Self {
141            role: inner.role,
142            content: ContentValue::String(inner.content),
143        }
144    }
145}
146
147/// Deserializer for message content that supports both strings and arrays
148
149/// Custom serialization/deserialization for message content
150/// Supports either a single string or an array of strings
151mod content_string_or_array {
152    use serde::{self, Deserialize, Deserializer, Serializer};
153
154    pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
155    where
156        S: Serializer,
157    {
158        serializer.serialize_str(content)
159    }
160
161    pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
162    where
163        D: Deserializer<'de>,
164    {
165        #[derive(Deserialize)]
166        #[serde(untagged)]
167        enum StringOrArray {
168            String(()),
169            Array(()),
170        }
171
172        let value = StringOrArray::deserialize(deserializer)?;
173        match value {
174            StringOrArray::String(_) => Ok(String::new()),
175            StringOrArray::Array(_) => Ok(String::new()),
176        }
177    }
178}
179
180/// Chat completion request (OpenAI compatible)
181#[derive(Deserialize, Serialize, Debug)]
182pub struct ChatCompletionRequest {
183    model: String,
184    messages: Vec<Message>,
185}
186
187/// Chat completion response choices (OpenAI compatible)
188#[derive(Serialize, Debug)]
189pub struct Choice {
190    index: i32,
191    message: Message,
192    logprobs: Option<serde_json::Value>,
193    finish_reason: String,
194}
195
196/// Helper function to get model from request or default to Sonnet
197fn get_model(model_name: &str) -> &'static str {
198    match model_name {
199        "gpt-3.5-turbo" => crate::config::HAIKU_MODEL,
200        "gpt-4" | "gpt-4-turbo" => crate::config::SONNET_MODEL,
201        "gpt-4o" | "gpt-4.1" => crate::config::OPUS_MODEL,
202        _ => crate::config::SONNET_MODEL,
203    }
204}
205
/// Translate an OpenAI role name into the role vocabulary used for Claude prompts.
///
/// `assistant` is kept, `system` and `developer` both become `system`, and `user` becomes
/// `human`. Unknown roles are also treated as `human`.
fn _normalize_role(role: &str) -> String {
    let normalized = match role {
        "assistant" => "assistant",
        "system" | "developer" => "system",
        // "user" and anything unrecognised count as the human turn.
        _ => "human",
    };
    normalized.to_owned()
}
215
216/// Convert messages to a single Claude prompt
217fn _messages_to_prompt(messages: &[Message]) -> String {
218    let mut prompt = String::new();
219
220    for msg in messages {
221        let role = _normalize_role(&msg.role);
222        let content = msg.content_as_string();
223        match role.as_str() {
224            "system" => prompt.push_str(&format!("System: {}\n\n", content)),
225            "human" => prompt.push_str(&format!("Human: {}\n\n", content)),
226            "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
227            _ => prompt.push_str(&format!("{}: {}\n\n", role, content)),
228        }
229    }
230
231    prompt
232}
233
234/// Generate a hash for a sequence of messages (excluding the last one)
235/// Get a stable thread key for logical conversation grouping (by root user message)
236fn conversation_key(messages: &[Message]) -> String {
237    // This uses the first user message content as the conversation identity
238    messages
239        .iter()
240        .find(|m| m.role == "user")
241        .map(|m| m.content_as_string())
242        .unwrap_or_default()
243}
244
245/// Chat completion API handler
246async fn chat_completion_handler(
247    State(state): State<AppState>,
248    Json(request): Json<ChatCompletionRequest>,
249) -> impl IntoResponse {
250    // Debug the raw client request
251    // dbg!(&request);
252
253    // Log when array content is used
254    // Log message content types
255    for (i, msg) in request.messages.iter().enumerate() {
256        match &msg.content {
257            ContentValue::Array(_) => info!("Message at index {} has array content", i),
258            ContentValue::Object(_) => info!("Message at index {} has object content", i),
259            ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
260            _ => {} // String content is the default, no need to log
261        }
262    }
263
264    let model_id = get_model(&request.model);
265
266    if request.messages.is_empty() {
267        return (
268            StatusCode::BAD_REQUEST,
269            Json(serde_json::json!({
270                "error": "No messages provided in the request"
271            })),
272        );
273    }
274
275    let messages = request.messages;
276    let conv_key = conversation_key(&messages);
277
278    // The actual user query is the last message
279    let user_query = messages.last().unwrap().clone();
280
281    let claude = state.claude;
282    let conversations = state.conversations;
283
284    // Try to find existing conversation by canonical thread key
285    let mut conversation_id: Option<String> = None;
286    let mut message_history: Vec<Message> = Vec::new();
287
288    info!("Looking for conversation key: {}", conv_key);
289    info!("Available conversation keys:");
290    for item in conversations.iter() {
291        info!("  Key: '{}', {} chars", item.key(), item.key().len());
292    }
293
294    if let Some(existing) = conversations.get(&conv_key) {
295        conversation_id = Some(existing.0.clone());
296        message_history = existing.1.clone();
297        info!("Found conversation for key: '{}'", conv_key);
298    } else {
299        info!("No existing conversation found for key: '{}'", conv_key);
300    }
301
302    // If no matching conversation, create a new one
303    if conversation_id.is_none() {
304        match claude.create_chat().await {
305            Ok(new_id) => {
306                conversation_id = Some(new_id);
307                message_history = messages[..messages.len() - 1].to_vec();
308
309                // Store the new conversation
310                conversations.insert(
311                    conv_key.clone(),
312                    (conversation_id.clone().unwrap(), message_history.clone()),
313                );
314
315                // Send system prompt if present
316                if let Some(system_msg) = messages
317                    .iter()
318                    .find(|m| m.role == "system" || m.role == "developer")
319                {
320                    if let Err(e) = claude
321                        .send_message(
322                            &conversation_id.clone().unwrap(),
323                            &system_msg.content_as_string(),
324                            &[],
325                        )
326                        .await
327                    {
328                        warn!("Failed to send system prompt: {}", e);
329                    }
330                }
331            }
332            Err(e) => {
333                let error_response = serde_json::json!({
334                    "error": format!("Failed to create conversation: {}", e)
335                });
336                dbg!(&error_response);
337                return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
338            }
339        }
340    }
341
342    // Send the user message and get response
343    let completion = match claude
344        .send_message(
345            &conversation_id.clone().unwrap(),
346            &user_query.content_as_string(),
347            &[],
348        )
349        .await
350    {
351        Ok(response) => response,
352        Err(e) => {
353            let error_response = serde_json::json!({
354                "error": format!("Failed to get completion: {}", e)
355            });
356            dbg!(&error_response);
357            return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
358        }
359    };
360
361    // Update the message history with the new assistant response
362    let new_message = Message {
363        role: "assistant".to_string(),
364        content: ContentValue::String(completion.clone()),
365    };
366
367    // Update message history and store it
368    let mut updated_history = message_history.clone();
369    updated_history.push(user_query.clone());
370    updated_history.push(new_message);
371
372    // Remove the old conversation entry first
373    if let Some(_) = conversations.remove(&conv_key) {
374        info!("Removed old conversation entry with key: '{}'", conv_key);
375    }
376
377    // Use the canonical key again for updated history (assumes initial user prompt is stable through the thread)
378    let new_conv_key = conversation_key(&updated_history);
379    let conv_id = conversation_id.unwrap();
380
381    // Insert with the updated key
382    conversations.insert(
383        new_conv_key.clone(),
384        (conv_id.clone(), updated_history.clone()),
385    );
386    // Log the conversation storage
387    info!(
388        "Stored conversation with ID: {} under new key '{}', {} chars",
389        conv_id,
390        new_conv_key,
391        new_conv_key.len()
392    );
393
394    // Create OpenAI-compatible response
395    let response = OpenAIResponse {
396        id: Uuid::new_v4().to_string(),
397        object: "chat.completion",
398        created: chrono::Utc::now().timestamp() as u64,
399        model: model_id.to_string(),
400        system_fingerprint: "fp_toast".to_string(),
401        choices: vec![Choice {
402            index: 0,
403            message: Message {
404                role: "assistant".to_string(),
405                content: ContentValue::String(completion.clone()),
406            },
407            logprobs: None,
408            finish_reason: "stop".to_string(),
409        }],
410        usage: Usage {
411            prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, // Rough estimate
412            completion_tokens: completion.chars().count() as u32 / 4, // Rough estimate
413            total_tokens: (user_query.content_as_string().chars().count()
414                + completion.chars().count()) as u32
415                / 4,
416        },
417    };
418
419    // Debug the response being sent to the client
420    // dbg!(&response);
421
422    (
423        StatusCode::OK,
424        Json(serde_json::to_value(response).unwrap()),
425    )
426}
427
428/// Health check endpoint
429async fn health_check() -> impl IntoResponse {
430    (StatusCode::OK, "OK")
431}
432
433/// Start the Axum server
434pub async fn run() -> anyhow::Result<()> {
435    // For backward compatibility, parse arguments directly when run() is called
436    run_with_args(Args::parse()).await
437}
438
439/// Start the Axum server with provided arguments
440pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
441    // Use provided arguments instead of parsing them
442
443    // Initialize tracing
444    if args.debug {
445        // Full logging in debug mode
446        tracing_subscriber::fmt::init();
447    } else {
448        // Only show warnings and errors in normal mode
449        tracing_subscriber::fmt()
450            .with_max_level(tracing::Level::WARN)
451            .init();
452    }
453
454    // Load config
455    let config_dir = dirs::config_dir()
456        .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
457        .join("toast");
458
459    let cookie_path = config_dir.join("cookie");
460    let org_id_path = config_dir.join("org_id");
461
462    // Check if config directory exists
463    if !config_dir.exists() {
464        return Err(anyhow::anyhow!(
465            "Configuration directory does not exist at {:?}",
466            config_dir
467        ));
468    }
469
470    // Check and load cookie
471    let cookie = if cookie_path.exists() {
472        std::fs::read_to_string(&cookie_path)?.trim().to_string()
473    } else {
474        return Err(anyhow::anyhow!(
475            "Cookie file not found at {:?}",
476            cookie_path
477        ));
478    };
479
480    // Check and load org_id
481    let org_id = if org_id_path.exists() {
482        std::fs::read_to_string(&org_id_path)?.trim().to_string()
483    } else {
484        // Try to extract org_id from cookie
485        if let Some(extracted_org_id) = crate::cli::extract_org_id_from_cookie(&cookie) {
486            // Save the extracted org_id to the file for future use
487            std::fs::write(&org_id_path, &extracted_org_id)?;
488            info!(
489                "Extracted organization ID from cookie and saved to {:?}",
490                org_id_path
491            );
492            extracted_org_id
493        } else {
494            return Err(anyhow::anyhow!(
495                "Organization ID file not found at {:?} and couldn't extract it from cookie",
496                org_id_path
497            ));
498        }
499    };
500
501    let user_agent =
502        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
503            .to_string();
504
505    let session = Session {
506        cookie,
507        user_agent,
508        organization_id: org_id,
509    };
510
511    // Convert model string to static lifetime
512    let model_str = Box::leak(args.model.into_boxed_str());
513
514    let claude = Arc::new(Claude::new(session.clone(), model_str)?);
515
516    // Create app state
517    let app_state = AppState {
518        claude,
519        conversations: Arc::new(DashMap::new()),
520    };
521
522    // Set up CORS
523    let cors = CorsLayer::new()
524        .allow_origin(Any)
525        .allow_methods(Any)
526        .allow_headers(Any);
527
528    // Create router
529    let app = Router::new()
530        .route("/health", get(health_check))
531        .route("/v1/chat/completions", post(chat_completion_handler))
532        .route("/", post(chat_completion_handler))
533        .route("/chat/completions", post(chat_completion_handler))
534        .layer(cors)
535        .with_state(app_state);
536
537    // Parse host address
538    let host_parts: Vec<u8> = args
539        .host
540        .split('.')
541        .map(|s| s.parse::<u8>().unwrap_or(0))
542        .collect();
543
544    let host_addr = if host_parts.len() == 4 {
545        [host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
546    } else {
547        [0, 0, 0, 0]
548    };
549
550    // Run server
551    let addr = SocketAddr::from((host_addr, args.port));
552    info!("Server listening on {}", addr);
553
554    // Print usage example
555    info!("Example request:");
556    info!(
557        "curl -X POST http://localhost:{}{} \\",
558        args.port, "/v1/chat/completions  # Also available at / and /chat/completions"
559    );
560    info!("  -H \"Content-Type: application/json\" \\");
561    info!("  -d '{{");
562    info!("    \"model\": \"gpt-4.1\",");
563    info!("    \"messages\": [");
564    info!("      {{");
565    info!("        \"role\": \"developer\",");
566    info!("        \"content\": \"You are a helpful assistant.\"");
567    info!("      }},");
568    info!("      {{");
569    info!("        \"role\": \"user\",");
570    info!("        \"content\": \"Hello! (Simple string content)\"");
571    info!("      }},");
572    info!("      {{");
573    info!("        \"role\": \"user\",");
574    info!("        \"content\": [\"Array of\", \"string content\"]");
575    info!("      }},");
576    info!("      {{");
577    info!("        \"role\": \"user\",");
578    info!("        \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
579    info!("      }},");
580    info!("      {{");
581    info!("        \"role\": \"user\",");
582    info!("        \"content\": [");
583    info!("          {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
584    info!("          {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
585    info!("        ]");
586    info!("      }}");
587    info!("    ]");
588    info!("  }}'");
589
590    let listener = tokio::net::TcpListener::bind(addr).await?;
591    axum::serve(listener, app).await?;
592
593    Ok(())
594}