1use crate::AgentProvider;
2use crate::models::*;
3use async_trait::async_trait;
4use eventsource_stream::Eventsource;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use reqwest::header::HeaderMap;
8use reqwest::{Client as ReqwestClient, Error as ReqwestError, Response, header};
9use rmcp::model::Content;
10use rmcp::model::JsonRpcResponse;
11use serde::Deserialize;
12use serde_json::json;
13use stakpak_shared::models::integrations::openai::{
14 AgentModel, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
15 ChatMessage, Tool,
16};
17use stakpak_shared::tls_client::TlsClientConfig;
18use stakpak_shared::tls_client::create_tls_client;
19use uuid::Uuid;
20
21#[derive(Clone, Debug)]
22pub struct RemoteClient {
23 client: ReqwestClient,
24 base_url: String,
25}
26
/// Settings needed to construct a [`RemoteClient`].
///
/// `Debug` is implemented by hand so the API key is never written to logs or
/// panic messages; only whether a key is present is shown.
#[derive(Clone)]
pub struct ClientConfig {
    /// Bearer token sent with every request; `RemoteClient::new` rejects `None`.
    pub api_key: Option<String>,
    /// API root URL; `RemoteClient::new` appends `/v1` to it.
    pub api_endpoint: String,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            // Redact the secret but keep Some/None visible for diagnostics.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
32
33#[derive(Deserialize)]
34struct ApiError {
35 error: ApiErrorDetail,
36}
37
38#[derive(Deserialize)]
39struct ApiErrorDetail {
40 key: String,
41 message: String,
42}
43
44impl RemoteClient {
45 async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
46 if response.status().is_success() {
47 Ok(response)
48 } else {
49 let error_body = response
50 .text()
51 .await
52 .unwrap_or_else(|_| "Failed to read error body".to_string());
53
54 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
55 if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
56 if api_error.error.key == "EXCEEDED_API_LIMIT" {
57 return Err(format!(
58 "{}.\n\nPlease top up your account at https://stakpak.dev/settings/billing to keep Stakpaking.",
59 api_error.error.message
60 ));
61 } else {
62 return Err(api_error.error.message);
63 }
64 }
65
66 if let Some(error_obj) = json.get("error") {
67 let error_message =
68 if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
69 message.to_string()
70 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
71 format!("API error: {}", code)
72 } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
73 format!("API error: {}", key)
74 } else {
75 serde_json::to_string(error_obj)
76 .unwrap_or_else(|_| "Unknown API error".to_string())
77 };
78 return Err(error_message);
79 }
80 }
81
82 Err(error_body)
83 }
84 }
85
86 async fn call_mcp_tool(&self, input: &ToolsCallParams) -> Result<Vec<Content>, String> {
87 let url = format!("{}/mcp", self.base_url);
88
89 let payload = json!({
90 "jsonrpc": "2.0",
91 "method": "tools/call",
92 "params": {
93 "name": input.name,
94 "arguments": input.arguments,
95 },
96 "id": Uuid::new_v4().to_string(),
97 });
98
99 let response = self
100 .client
101 .post(&url)
102 .json(&payload)
103 .send()
104 .await
105 .map_err(|e: ReqwestError| e.to_string())?;
106
107 let response = self.handle_response_error(response).await?;
108
109 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
110
111 match serde_json::from_value::<JsonRpcResponse<ToolsCallResponse>>(value.clone()) {
112 Ok(response) => Ok(response.result.content),
113 Err(e) => {
114 eprintln!("Failed to deserialize response: {}", e);
115 eprintln!("Raw response: {}", value);
116 Err("Failed to deserialize response:".into())
117 }
118 }
119 }
120
121 pub fn new(config: &ClientConfig) -> Result<Self, String> {
122 if config.api_key.is_none() {
123 return Err("API Key not found, please login".into());
124 }
125
126 let mut headers = header::HeaderMap::new();
127 headers.insert(
128 header::AUTHORIZATION,
129 header::HeaderValue::from_str(&format!("Bearer {}", config.api_key.clone().unwrap()))
130 .expect("Invalid API key format"),
131 );
132 headers.insert(
133 header::USER_AGENT,
134 header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
135 .expect("Invalid user agent format"),
136 );
137
138 let client = create_tls_client(
139 TlsClientConfig::default()
140 .with_headers(headers)
141 .with_timeout(std::time::Duration::from_secs(300)),
142 )?;
143
144 Ok(Self {
145 client,
146 base_url: config.api_endpoint.clone() + "/v1",
147 })
148 }
149}
150
151#[async_trait]
152impl AgentProvider for RemoteClient {
153 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
154 let url = format!("{}/account", self.base_url);
155
156 let response = self
157 .client
158 .get(&url)
159 .send()
160 .await
161 .map_err(|e: ReqwestError| e.to_string())?;
162
163 let response = self.handle_response_error(response).await?;
164
165 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
166 match serde_json::from_value::<GetMyAccountResponse>(value.clone()) {
167 Ok(response) => Ok(response),
168 Err(e) => {
169 eprintln!("Failed to deserialize response: {}", e);
170 eprintln!("Raw response: {}", value);
171 Err("Failed to deserialize response:".into())
172 }
173 }
174 }
175
176 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
177 let url = format!("{}/rules", self.base_url);
178
179 let response = self
180 .client
181 .get(&url)
182 .send()
183 .await
184 .map_err(|e: ReqwestError| e.to_string())?;
185
186 let response = self.handle_response_error(response).await?;
187
188 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
189 match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
190 Ok(response) => Ok(response.results),
191 Err(e) => {
192 eprintln!("Failed to deserialize response: {}", e);
193 eprintln!("Raw response: {}", value);
194 Err("Failed to deserialize response:".into())
195 }
196 }
197 }
198
199 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
200 let encoded_uri = urlencoding::encode(uri);
202 let url = format!("{}/rules/{}", self.base_url, encoded_uri);
203
204 let response = self
205 .client
206 .get(&url)
207 .send()
208 .await
209 .map_err(|e: ReqwestError| e.to_string())?;
210
211 let response = self.handle_response_error(response).await?;
212
213 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
214 match serde_json::from_value::<RuleBook>(value.clone()) {
215 Ok(response) => Ok(response),
216 Err(e) => {
217 eprintln!("Failed to deserialize response: {}", e);
218 eprintln!("Raw response: {}", value);
219 Err("Failed to deserialize response:".into())
220 }
221 }
222 }
223
224 async fn create_rulebook(
225 &self,
226 uri: &str,
227 description: &str,
228 content: &str,
229 tags: Vec<String>,
230 visibility: Option<RuleBookVisibility>,
231 ) -> Result<CreateRuleBookResponse, String> {
232 let url = format!("{}/rules", self.base_url);
233
234 let input = CreateRuleBookInput {
235 uri: uri.to_string(),
236 description: description.to_string(),
237 content: content.to_string(),
238 tags,
239 visibility,
240 };
241
242 let response = self
243 .client
244 .post(&url)
245 .json(&input)
246 .send()
247 .await
248 .map_err(|e: ReqwestError| e.to_string())?;
249
250 if !response.status().is_success() {
252 let status = response.status();
253 let error_text = response
254 .text()
255 .await
256 .unwrap_or_else(|_| "Unknown error".to_string());
257 return Err(format!("API error ({}): {}", status, error_text));
258 }
259
260 let response_text = response.text().await.map_err(|e| e.to_string())?;
262
263 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&response_text) {
265 match serde_json::from_value::<CreateRuleBookResponse>(value.clone()) {
266 Ok(response) => return Ok(response),
267 Err(e) => {
268 eprintln!("Failed to deserialize JSON response: {}", e);
269 eprintln!("Raw response: {}", value);
270 }
271 }
272 }
273
274 if response_text.starts_with("id: ") {
276 let id = response_text.trim_start_matches("id: ").trim().to_string();
277 return Ok(CreateRuleBookResponse { id });
278 }
279
280 Err(format!("Unexpected response format: {}", response_text))
281 }
282
283 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
284 let encoded_uri = urlencoding::encode(uri);
285 let url = format!("{}/rules/{}", self.base_url, encoded_uri);
286
287 let response = self
288 .client
289 .delete(&url)
290 .send()
291 .await
292 .map_err(|e: ReqwestError| e.to_string())?;
293
294 let _response = self.handle_response_error(response).await?;
295
296 Ok(())
297 }
298
299 async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
300 let url = format!("{}/agents/sessions", self.base_url);
301
302 let response = self
303 .client
304 .get(&url)
305 .send()
306 .await
307 .map_err(|e: ReqwestError| e.to_string())?;
308
309 let response = self.handle_response_error(response).await?;
310
311 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
312 match serde_json::from_value::<Vec<AgentSession>>(value.clone()) {
313 Ok(response) => Ok(response),
314 Err(e) => {
315 eprintln!("Failed to deserialize response: {}", e);
316 eprintln!("Raw response: {}", value);
317 Err("Failed to deserialize response:".into())
318 }
319 }
320 }
321
322 async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
323 let url = format!("{}/agents/sessions/{}", self.base_url, session_id);
324
325 let response = self
326 .client
327 .get(&url)
328 .send()
329 .await
330 .map_err(|e: ReqwestError| e.to_string())?;
331
332 let response = self.handle_response_error(response).await?;
333
334 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
335
336 match serde_json::from_value::<AgentSession>(value.clone()) {
337 Ok(response) => Ok(response),
338 Err(e) => {
339 eprintln!("Failed to deserialize response: {}", e);
340 eprintln!("Raw response: {}", value);
341 Err("Failed to deserialize response:".into())
342 }
343 }
344 }
345
346 async fn get_agent_session_stats(&self, session_id: Uuid) -> Result<AgentSessionStats, String> {
347 let url = format!("{}/agents/sessions/{}/stats", self.base_url, session_id);
348
349 let response = self
350 .client
351 .get(&url)
352 .send()
353 .await
354 .map_err(|e: ReqwestError| e.to_string())?;
355
356 let response = self.handle_response_error(response).await?;
357
358 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
359
360 match serde_json::from_value::<AgentSessionStats>(value.clone()) {
361 Ok(response) => Ok(response),
362 Err(e) => {
363 eprintln!("Failed to deserialize response: {}", e);
364 eprintln!("Raw response: {}", value);
365 Err("Failed to deserialize response:".into())
366 }
367 }
368 }
369
370 async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
371 let url = format!("{}/agents/checkpoints/{}", self.base_url, checkpoint_id);
372
373 let response = self
374 .client
375 .get(&url)
376 .send()
377 .await
378 .map_err(|e: ReqwestError| e.to_string())?;
379
380 let response = self.handle_response_error(response).await?;
381
382 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
383 match serde_json::from_value::<RunAgentOutput>(value.clone()) {
384 Ok(response) => Ok(response),
385 Err(e) => {
386 eprintln!("Failed to deserialize response: {}", e);
387 eprintln!("Raw response: {}", value);
388 Err("Failed to deserialize response:".into())
389 }
390 }
391 }
392
393 async fn get_agent_session_latest_checkpoint(
394 &self,
395 session_id: Uuid,
396 ) -> Result<RunAgentOutput, String> {
397 let url = format!(
398 "{}/agents/sessions/{}/checkpoints/latest",
399 self.base_url, session_id
400 );
401
402 let response = self
403 .client
404 .get(&url)
405 .send()
406 .await
407 .map_err(|e: ReqwestError| e.to_string())?;
408
409 let response = self.handle_response_error(response).await?;
410
411 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
412 match serde_json::from_value::<RunAgentOutput>(value.clone()) {
413 Ok(response) => Ok(response),
414 Err(e) => {
415 eprintln!("Failed to deserialize response: {}", e);
416 eprintln!("Raw response: {}", value);
417 Err("Failed to deserialize response:".into())
418 }
419 }
420 }
421
422 async fn chat_completion(
423 &self,
424 model: AgentModel,
425 messages: Vec<ChatMessage>,
426 tools: Option<Vec<Tool>>,
427 ) -> Result<ChatCompletionResponse, String> {
428 let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
429
430 let model_string = model.to_string();
431 let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, None);
432
433 let response = self
434 .client
435 .post(&url)
436 .json(&input)
437 .send()
438 .await
439 .map_err(|e: ReqwestError| e.to_string())?;
440
441 let response = self.handle_response_error(response).await?;
442
443 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
444
445 if let Some(error_obj) = value.get("error") {
446 let error_message = if let Some(message) =
447 error_obj.get("message").and_then(|m| m.as_str())
448 {
449 message.to_string()
450 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
451 format!("API error: {}", code)
452 } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
453 format!("API error: {}", key)
454 } else {
455 serde_json::to_string(error_obj).unwrap_or_else(|_| "Unknown API error".to_string())
456 };
457 return Err(error_message);
458 }
459
460 match serde_json::from_value::<ChatCompletionResponse>(value.clone()) {
461 Ok(response) => Ok(response),
462 Err(e) => {
463 eprintln!("Failed to deserialize response: {}", e);
464 eprintln!("Raw response: {}", value);
465 Err("Failed to deserialize response:".into())
466 }
467 }
468 }
469
470 async fn chat_completion_stream(
471 &self,
472 model: AgentModel,
473 messages: Vec<ChatMessage>,
474 tools: Option<Vec<Tool>>,
475 headers: Option<HeaderMap>,
476 ) -> Result<
477 (
478 std::pin::Pin<
479 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
480 >,
481 Option<String>,
482 ),
483 String,
484 > {
485 let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
486
487 let model_string = model.to_string();
488 let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, Some(true));
489
490 let response = self
491 .client
492 .post(&url)
493 .headers(headers.unwrap_or_default())
494 .json(&input)
495 .send()
496 .await
497 .map_err(|e: ReqwestError| e.to_string())?;
498
499 let content_type = response
501 .headers()
502 .get("content-type")
503 .and_then(|v| v.to_str().ok())
504 .unwrap_or("unknown");
505
506 let request_id = response
508 .headers()
509 .get("x-request-id")
510 .and_then(|v| v.to_str().ok())
511 .map(|s| s.to_string());
512
513 if !content_type.contains("event-stream") && !content_type.contains("text/event-stream") {
515 let status = response.status();
516 let error_body = response
517 .text()
518 .await
519 .unwrap_or_else(|_| "Failed to read error body".to_string());
520
521 let error_message =
522 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
523 if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
525 api_error.error.message
526 } else if let Some(error_obj) = json.get("error") {
527 if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
529 message.to_string()
530 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
531 format!("API error: {}", code)
532 } else {
533 error_body
534 }
535 } else {
536 error_body
537 }
538 } else {
539 error_body
540 };
541
542 return Err(format!(
543 "Server returned non-stream response ({}): {}",
544 status, error_message
545 ));
546 }
547
548 let response = self.handle_response_error(response).await?;
549 let stream = response.bytes_stream().eventsource().map(move |event| {
550 event
551 .map_err(|_| ApiStreamError::Unknown("Failed to read response".to_string()))
552 .and_then(|event| match event.event.as_str() {
553 "error" => Err(ApiStreamError::from(event.data)),
554 _ => serde_json::from_str::<ChatCompletionStreamResponse>(&event.data).map_err(
555 |_| {
556 ApiStreamError::Unknown(
557 "Failed to parse JSON from Anthropic response".to_string(),
558 )
559 },
560 ),
561 })
562 });
563
564 Ok((Box::pin(stream), request_id))
565 }
566
567 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
568 let url = format!("{}/agents/requests/{}/cancel", self.base_url, request_id);
569 self.client
570 .post(&url)
571 .send()
572 .await
573 .map_err(|e: ReqwestError| e.to_string())?;
574
575 Ok(())
576 }
577
578 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
606 self.call_mcp_tool(&ToolsCallParams {
607 name: "search_docs".to_string(),
608 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
609 })
610 .await
611 }
612
613 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
614 self.call_mcp_tool(&ToolsCallParams {
615 name: "search_memory".to_string(),
616 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
617 })
618 .await
619 }
620
621 async fn slack_read_messages(
622 &self,
623 input: &SlackReadMessagesRequest,
624 ) -> Result<Vec<Content>, String> {
625 self.call_mcp_tool(&ToolsCallParams {
626 name: "slack_read_messages".to_string(),
627 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
628 })
629 .await
630 }
631
632 async fn slack_read_replies(
633 &self,
634 input: &SlackReadRepliesRequest,
635 ) -> Result<Vec<Content>, String> {
636 self.call_mcp_tool(&ToolsCallParams {
637 name: "slack_read_replies".to_string(),
638 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
639 })
640 .await
641 }
642
643 async fn slack_send_message(
644 &self,
645 input: &SlackSendMessageRequest,
646 ) -> Result<Vec<Content>, String> {
647 let arguments = json!({
663 "channel": input.channel,
664 "markdown_text": input.mrkdwn_text,
665 "thread_ts": input.thread_ts,
666 });
667
668 self.call_mcp_tool(&ToolsCallParams {
669 name: "slack_send_message".to_string(),
670 arguments,
671 })
672 .await
673 }
674
675 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
676 let url = format!(
677 "{}/agents/sessions/checkpoints/{}/extract-memory",
678 self.base_url, checkpoint_id
679 );
680
681 let response = self
682 .client
683 .post(&url)
684 .send()
685 .await
686 .map_err(|e: ReqwestError| e.to_string())?;
687
688 let _ = self.handle_response_error(response).await?;
689 Ok(())
690 }
691}