// stakpak_api/remote/mod.rs

1use crate::AgentProvider;
2use crate::models::*;
3use async_trait::async_trait;
4use eventsource_stream::Eventsource;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use reqwest::header::HeaderMap;
8use reqwest::{Client as ReqwestClient, Error as ReqwestError, Response, header};
9use rmcp::model::Content;
10use rmcp::model::JsonRpcResponse;
11use serde::Deserialize;
12use serde_json::json;
13use stakpak_shared::models::integrations::openai::{
14    AgentModel, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
15    ChatMessage, Tool,
16};
17use stakpak_shared::tls_client::TlsClientConfig;
18use stakpak_shared::tls_client::create_tls_client;
19use uuid::Uuid;
20
21#[derive(Clone, Debug)]
22pub struct RemoteClient {
23    client: ReqwestClient,
24    base_url: String,
25}
26
/// Settings needed to construct a [`RemoteClient`].
#[derive(Clone)]
pub struct ClientConfig {
    /// Stakpak API key; `RemoteClient::new` refuses to build a client without one.
    pub api_key: Option<String>,
    /// Base URL of the API, without the `/v1` version segment.
    pub api_endpoint: String,
}

// Hand-written instead of derived so that `{:?}` (logs, panic messages, `dbg!`)
// never prints the secret API key; only its presence is shown.
impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
32
33#[derive(Deserialize)]
34struct ApiError {
35    error: ApiErrorDetail,
36}
37
38#[derive(Deserialize)]
39struct ApiErrorDetail {
40    key: String,
41    message: String,
42}
43
44impl RemoteClient {
45    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
46        if response.status().is_success() {
47            Ok(response)
48        } else {
49            let error_body = response
50                .text()
51                .await
52                .unwrap_or_else(|_| "Failed to read error body".to_string());
53
54            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
55                if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
56                    if api_error.error.key == "EXCEEDED_API_LIMIT" {
57                        return Err(format!(
58                            "{}.\n\nPlease top up your account at https://stakpak.dev/settings/billing to keep Stakpaking.",
59                            api_error.error.message
60                        ));
61                    } else {
62                        return Err(api_error.error.message);
63                    }
64                }
65
66                if let Some(error_obj) = json.get("error") {
67                    let error_message =
68                        if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
69                            message.to_string()
70                        } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
71                            format!("API error: {}", code)
72                        } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
73                            format!("API error: {}", key)
74                        } else {
75                            serde_json::to_string(error_obj)
76                                .unwrap_or_else(|_| "Unknown API error".to_string())
77                        };
78                    return Err(error_message);
79                }
80            }
81
82            Err(error_body)
83        }
84    }
85
86    async fn call_mcp_tool(&self, input: &ToolsCallParams) -> Result<Vec<Content>, String> {
87        let url = format!("{}/mcp", self.base_url);
88
89        let payload = json!({
90            "jsonrpc": "2.0",
91            "method": "tools/call",
92            "params": {
93                "name": input.name,
94                "arguments": input.arguments,
95            },
96            "id": Uuid::new_v4().to_string(),
97        });
98
99        let response = self
100            .client
101            .post(&url)
102            .json(&payload)
103            .send()
104            .await
105            .map_err(|e: ReqwestError| e.to_string())?;
106
107        let response = self.handle_response_error(response).await?;
108
109        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
110
111        match serde_json::from_value::<JsonRpcResponse<ToolsCallResponse>>(value.clone()) {
112            Ok(response) => Ok(response.result.content),
113            Err(_) => {
114                // eprintln!("Failed to deserialize response: {}", e);
115                // eprintln!("Raw response: {}", value);
116                Err("Failed to deserialize response:".into())
117            }
118        }
119    }
120
121    pub fn new(config: &ClientConfig) -> Result<Self, String> {
122        if config.api_key.is_none() {
123            return Err("API Key not found, please login".into());
124        }
125
126        let mut headers = header::HeaderMap::new();
127        headers.insert(
128            header::AUTHORIZATION,
129            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key.clone().unwrap()))
130                .expect("Invalid API key format"),
131        );
132        headers.insert(
133            header::USER_AGENT,
134            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
135                .expect("Invalid user agent format"),
136        );
137
138        let client = create_tls_client(
139            TlsClientConfig::default()
140                .with_headers(headers)
141                .with_timeout(std::time::Duration::from_secs(300)),
142        )?;
143
144        Ok(Self {
145            client,
146            base_url: config.api_endpoint.clone() + "/v1",
147        })
148    }
149}
150
151#[async_trait]
152impl AgentProvider for RemoteClient {
153    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
154        let url = format!("{}/account", self.base_url);
155
156        let response = self
157            .client
158            .get(&url)
159            .send()
160            .await
161            .map_err(|e: ReqwestError| e.to_string())?;
162
163        let response = self.handle_response_error(response).await?;
164
165        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
166        match serde_json::from_value::<GetMyAccountResponse>(value.clone()) {
167            Ok(response) => Ok(response),
168            Err(e) => {
169                eprintln!("Failed to deserialize response: {}", e);
170                eprintln!("Raw response: {}", value);
171                Err("Failed to deserialize response:".into())
172            }
173        }
174    }
175
176    async fn get_billing_info(
177        &self,
178        account_username: &str,
179    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
180        // Billing endpoint is v2 and requires account username in path
181
182        let base = self.base_url.trim_end_matches("/v1");
183        let url = format!("{}/v2/{}/billing", base, account_username);
184
185        let response = self
186            .client
187            .get(&url)
188            .send()
189            .await
190            .map_err(|e: ReqwestError| e.to_string())?;
191
192        let response = self.handle_response_error(response).await?;
193
194        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
195        match serde_json::from_value::<stakpak_shared::models::billing::BillingResponse>(
196            value.clone(),
197        ) {
198            Ok(response) => Ok(response),
199            Err(e) => {
200                let error_msg = format!("Failed to deserialize billing response: {}", e);
201                Err(error_msg)
202            }
203        }
204    }
205
206    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
207        let url = format!("{}/rules", self.base_url);
208
209        let response = self
210            .client
211            .get(&url)
212            .send()
213            .await
214            .map_err(|e: ReqwestError| e.to_string())?;
215
216        let response = self.handle_response_error(response).await?;
217
218        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
219        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
220            Ok(response) => Ok(response.results),
221            Err(e) => {
222                eprintln!("Failed to deserialize response: {}", e);
223                eprintln!("Raw response: {}", value);
224                Err("Failed to deserialize response:".into())
225            }
226        }
227    }
228
229    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
230        // URL encode the URI to handle special characters
231        let encoded_uri = urlencoding::encode(uri);
232        let url = format!("{}/rules/{}", self.base_url, encoded_uri);
233
234        let response = self
235            .client
236            .get(&url)
237            .send()
238            .await
239            .map_err(|e: ReqwestError| e.to_string())?;
240
241        let response = self.handle_response_error(response).await?;
242
243        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
244        match serde_json::from_value::<RuleBook>(value.clone()) {
245            Ok(response) => Ok(response),
246            Err(e) => {
247                eprintln!("Failed to deserialize response: {}", e);
248                eprintln!("Raw response: {}", value);
249                Err("Failed to deserialize response:".into())
250            }
251        }
252    }
253
254    async fn create_rulebook(
255        &self,
256        uri: &str,
257        description: &str,
258        content: &str,
259        tags: Vec<String>,
260        visibility: Option<RuleBookVisibility>,
261    ) -> Result<CreateRuleBookResponse, String> {
262        let url = format!("{}/rules", self.base_url);
263
264        let input = CreateRuleBookInput {
265            uri: uri.to_string(),
266            description: description.to_string(),
267            content: content.to_string(),
268            tags,
269            visibility,
270        };
271
272        let response = self
273            .client
274            .post(&url)
275            .json(&input)
276            .send()
277            .await
278            .map_err(|e: ReqwestError| e.to_string())?;
279
280        // Check status before consuming body
281        if !response.status().is_success() {
282            let status = response.status();
283            let error_text = response
284                .text()
285                .await
286                .unwrap_or_else(|_| "Unknown error".to_string());
287            return Err(format!("API error ({}): {}", status, error_text));
288        }
289
290        // Get response as text first to handle non-JSON responses
291        let response_text = response.text().await.map_err(|e| e.to_string())?;
292
293        // Try to parse as JSON first
294        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&response_text) {
295            match serde_json::from_value::<CreateRuleBookResponse>(value.clone()) {
296                Ok(response) => return Ok(response),
297                Err(e) => {
298                    eprintln!("Failed to deserialize JSON response: {}", e);
299                    eprintln!("Raw response: {}", value);
300                }
301            }
302        }
303
304        // If JSON parsing failed, try to parse as plain text "id: <uuid>"
305        if response_text.starts_with("id: ") {
306            let id = response_text.trim_start_matches("id: ").trim().to_string();
307            return Ok(CreateRuleBookResponse { id });
308        }
309
310        Err(format!("Unexpected response format: {}", response_text))
311    }
312
313    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
314        let encoded_uri = urlencoding::encode(uri);
315        let url = format!("{}/rules/{}", self.base_url, encoded_uri);
316
317        let response = self
318            .client
319            .delete(&url)
320            .send()
321            .await
322            .map_err(|e: ReqwestError| e.to_string())?;
323
324        let _response = self.handle_response_error(response).await?;
325
326        Ok(())
327    }
328
329    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
330        let url = format!("{}/agents/sessions", self.base_url);
331
332        let response = self
333            .client
334            .get(&url)
335            .send()
336            .await
337            .map_err(|e: ReqwestError| e.to_string())?;
338
339        let response = self.handle_response_error(response).await?;
340
341        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
342        match serde_json::from_value::<Vec<AgentSession>>(value.clone()) {
343            Ok(response) => Ok(response),
344            Err(e) => {
345                eprintln!("Failed to deserialize response: {}", e);
346                eprintln!("Raw response: {}", value);
347                Err("Failed to deserialize response:".into())
348            }
349        }
350    }
351
352    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
353        let url = format!("{}/agents/sessions/{}", self.base_url, session_id);
354
355        let response = self
356            .client
357            .get(&url)
358            .send()
359            .await
360            .map_err(|e: ReqwestError| e.to_string())?;
361
362        let response = self.handle_response_error(response).await?;
363
364        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
365
366        match serde_json::from_value::<AgentSession>(value.clone()) {
367            Ok(response) => Ok(response),
368            Err(e) => {
369                eprintln!("Failed to deserialize response: {}", e);
370                eprintln!("Raw response: {}", value);
371                Err("Failed to deserialize response:".into())
372            }
373        }
374    }
375
376    async fn get_agent_session_stats(&self, session_id: Uuid) -> Result<AgentSessionStats, String> {
377        let url = format!("{}/agents/sessions/{}/stats", self.base_url, session_id);
378
379        let response = self
380            .client
381            .get(&url)
382            .send()
383            .await
384            .map_err(|e: ReqwestError| e.to_string())?;
385
386        let response = self.handle_response_error(response).await?;
387
388        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
389
390        match serde_json::from_value::<AgentSessionStats>(value.clone()) {
391            Ok(response) => Ok(response),
392            Err(e) => {
393                eprintln!("Failed to deserialize response: {}", e);
394                eprintln!("Raw response: {}", value);
395                Err("Failed to deserialize response:".into())
396            }
397        }
398    }
399
400    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
401        let url = format!("{}/agents/checkpoints/{}", self.base_url, checkpoint_id);
402
403        let response = self
404            .client
405            .get(&url)
406            .send()
407            .await
408            .map_err(|e: ReqwestError| e.to_string())?;
409
410        let response = self.handle_response_error(response).await?;
411
412        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
413        match serde_json::from_value::<RunAgentOutput>(value.clone()) {
414            Ok(response) => Ok(response),
415            Err(e) => {
416                eprintln!("Failed to deserialize response: {}", e);
417                eprintln!("Raw response: {}", value);
418                Err("Failed to deserialize response:".into())
419            }
420        }
421    }
422
423    async fn get_agent_session_latest_checkpoint(
424        &self,
425        session_id: Uuid,
426    ) -> Result<RunAgentOutput, String> {
427        let url = format!(
428            "{}/agents/sessions/{}/checkpoints/latest",
429            self.base_url, session_id
430        );
431
432        let response = self
433            .client
434            .get(&url)
435            .send()
436            .await
437            .map_err(|e: ReqwestError| e.to_string())?;
438
439        let response = self.handle_response_error(response).await?;
440
441        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
442        match serde_json::from_value::<RunAgentOutput>(value.clone()) {
443            Ok(response) => Ok(response),
444            Err(e) => {
445                eprintln!("Failed to deserialize response: {}", e);
446                eprintln!("Raw response: {}", value);
447                Err("Failed to deserialize response:".into())
448            }
449        }
450    }
451
452    async fn chat_completion(
453        &self,
454        model: AgentModel,
455        messages: Vec<ChatMessage>,
456        tools: Option<Vec<Tool>>,
457    ) -> Result<ChatCompletionResponse, String> {
458        let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
459
460        let model_string = model.to_string();
461        let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, None);
462
463        let response = self
464            .client
465            .post(&url)
466            .json(&input)
467            .send()
468            .await
469            .map_err(|e: ReqwestError| e.to_string())?;
470
471        let response = self.handle_response_error(response).await?;
472
473        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
474
475        if let Some(error_obj) = value.get("error") {
476            let error_message = if let Some(message) =
477                error_obj.get("message").and_then(|m| m.as_str())
478            {
479                message.to_string()
480            } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
481                format!("API error: {}", code)
482            } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
483                format!("API error: {}", key)
484            } else {
485                serde_json::to_string(error_obj).unwrap_or_else(|_| "Unknown API error".to_string())
486            };
487            return Err(error_message);
488        }
489
490        match serde_json::from_value::<ChatCompletionResponse>(value.clone()) {
491            Ok(response) => Ok(response),
492            Err(e) => {
493                eprintln!("Failed to deserialize response: {}", e);
494                eprintln!("Raw response: {}", value);
495                Err("Failed to deserialize response:".into())
496            }
497        }
498    }
499
500    async fn chat_completion_stream(
501        &self,
502        model: AgentModel,
503        messages: Vec<ChatMessage>,
504        tools: Option<Vec<Tool>>,
505        headers: Option<HeaderMap>,
506    ) -> Result<
507        (
508            std::pin::Pin<
509                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
510            >,
511            Option<String>,
512        ),
513        String,
514    > {
515        let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
516
517        let model_string = model.to_string();
518        let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, Some(true));
519
520        let response = self
521            .client
522            .post(&url)
523            .headers(headers.unwrap_or_default())
524            .json(&input)
525            .send()
526            .await
527            .map_err(|e: ReqwestError| e.to_string())?;
528
529        // Check content-type before processing
530        let content_type = response
531            .headers()
532            .get("content-type")
533            .and_then(|v| v.to_str().ok())
534            .unwrap_or("unknown");
535
536        // Extract x-request-id from headers
537        let request_id = response
538            .headers()
539            .get("x-request-id")
540            .and_then(|v| v.to_str().ok())
541            .map(|s| s.to_string());
542
543        // If content-type is not event-stream, it's likely an error message
544        if !content_type.contains("event-stream") && !content_type.contains("text/event-stream") {
545            let status = response.status();
546            let error_body = response
547                .text()
548                .await
549                .unwrap_or_else(|_| "Failed to read error body".to_string());
550
551            let error_message =
552                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
553                    // Try ApiError format first (Stakpak API format)
554                    if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
555                        api_error.error.message
556                    } else if let Some(error_obj) = json.get("error") {
557                        // Generic error format
558                        if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
559                            message.to_string()
560                        } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
561                            format!("API error: {}", code)
562                        } else {
563                            error_body
564                        }
565                    } else {
566                        error_body
567                    }
568                } else {
569                    error_body
570                };
571
572            return Err(format!(
573                "Server returned non-stream response ({}): {}",
574                status, error_message
575            ));
576        }
577
578        let response = self.handle_response_error(response).await?;
579        let stream = response.bytes_stream().eventsource().map(move |event| {
580            event
581                .map_err(|_| ApiStreamError::Unknown("Failed to read response".to_string()))
582                .and_then(|event| match event.event.as_str() {
583                    "error" => Err(ApiStreamError::from(event.data)),
584                    _ => serde_json::from_str::<ChatCompletionStreamResponse>(&event.data).map_err(
585                        |_| {
586                            ApiStreamError::Unknown(
587                                "Failed to parse JSON from Anthropic response".to_string(),
588                            )
589                        },
590                    ),
591                })
592        });
593
594        Ok((Box::pin(stream), request_id))
595    }
596
597    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
598        let url = format!("{}/agents/requests/{}/cancel", self.base_url, request_id);
599        self.client
600            .post(&url)
601            .send()
602            .await
603            .map_err(|e: ReqwestError| e.to_string())?;
604
605        Ok(())
606    }
607
608    // async fn build_code_index(
609    //     &self,
610    //     input: &BuildCodeIndexInput,
611    // ) -> Result<BuildCodeIndexOutput, String> {
612    //     let url = format!("{}/commands/build_code_index", self.base_url,);
613
614    //     let response = self
615    //         .client
616    //         .post(&url)
617    //         .json(&input)
618    //         .send()
619    //         .await
620    //         .map_err(|e: ReqwestError| e.to_string())?;
621
622    //     let response = self.handle_response_error(response).await?;
623
624    //     let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
625    //     match serde_json::from_value::<BuildCodeIndexOutput>(value.clone()) {
626    //         Ok(response) => Ok(response),
627    //         Err(e) => {
628    //             eprintln!("Failed to deserialize response: {}", e);
629    //             eprintln!("Raw response: {}", value);
630    //             Err("Failed to deserialize response:".into())
631    //         }
632    //     }
633    // }
634
635    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
636        self.call_mcp_tool(&ToolsCallParams {
637            name: "search_docs".to_string(),
638            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
639        })
640        .await
641    }
642
643    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
644        self.call_mcp_tool(&ToolsCallParams {
645            name: "search_memory".to_string(),
646            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
647        })
648        .await
649    }
650
651    async fn slack_read_messages(
652        &self,
653        input: &SlackReadMessagesRequest,
654    ) -> Result<Vec<Content>, String> {
655        self.call_mcp_tool(&ToolsCallParams {
656            name: "slack_read_messages".to_string(),
657            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
658        })
659        .await
660    }
661
662    async fn slack_read_replies(
663        &self,
664        input: &SlackReadRepliesRequest,
665    ) -> Result<Vec<Content>, String> {
666        self.call_mcp_tool(&ToolsCallParams {
667            name: "slack_read_replies".to_string(),
668            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
669        })
670        .await
671    }
672
673    async fn slack_send_message(
674        &self,
675        input: &SlackSendMessageRequest,
676    ) -> Result<Vec<Content>, String> {
677        // Note: The remote tool expects "markdown_text" but the struct has "mrkdwn_text".
678        // We need to map this correctly. The struct in models.rs has mrkdwn_text.
679        // The remote tool likely expects what was previously passed.
680        // In slack.rs, it was mapping "mrkdwn_text" to "markdown_text".
681        // So we should construct the arguments manually or use a custom serializer if we want to match exactly.
682        // However, since we are sending `input` which is `SlackSendMessageRequest`, let's check its definition.
683        // It has `mrkdwn_text`.
684        // The previous implementation in slack.rs did:
685        // arguments: json!({
686        //     "channel": channel,
687        //     "markdown_text": mrkdwn_text,
688        //     "thread_ts": thread_ts,
689        // }),
690        // So we need to replicate this mapping.
691
692        let arguments = json!({
693            "channel": input.channel,
694            "markdown_text": input.mrkdwn_text,
695            "thread_ts": input.thread_ts,
696        });
697
698        self.call_mcp_tool(&ToolsCallParams {
699            name: "slack_send_message".to_string(),
700            arguments,
701        })
702        .await
703    }
704
705    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
706        let url = format!(
707            "{}/agents/sessions/checkpoints/{}/extract-memory",
708            self.base_url, checkpoint_id
709        );
710
711        let response = self
712            .client
713            .post(&url)
714            .send()
715            .await
716            .map_err(|e: ReqwestError| e.to_string())?;
717
718        let _ = self.handle_response_error(response).await?;
719        Ok(())
720    }
721}