1use crate::error::{ConnectorError, Result};
2use crate::transport::{Transport, TransportOptions, TransportType, create_transport};
3use crate::types::*;
4use crate::url_parser::parse_url;
5use crate::utils::{generate_id, sanitize_identifier};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::sync::{RwLock, mpsc};
11use tokio::time::Duration;
12use tracing::debug;
13
/// Seconds between keepalive heartbeats sent by `ConnectorClient::spawn_keepalive`.
const KEEPALIVE_INTERVAL_SECS: u64 = 30;
17
/// Configuration accepted by `ConnectorClient::with_options`.
///
/// Every field is optional; `with_options` fills in a default for anything left `None`.
#[derive(Debug, Clone, Default)]
pub struct ClientOptions {
    /// Server endpoint. If `parse_url` accepts it, host/port, TLS and transport are derived
    /// from it; otherwise the string is used verbatim as the host. Takes precedence over
    /// `host`.
    pub url: Option<String>,
    /// Legacy server endpoint, consulted only when `url` is `None`.
    #[deprecated(note = "Use `url` with scheme for auto-detection")]
    pub host: Option<String>,
    /// Explicit TLS setting; when set it overrides the value derived from `url`/`host`.
    pub use_tls: Option<bool>,
    /// Legacy explicit transport; when set it overrides the value derived from `url`/`host`.
    #[deprecated(note = "Use `url` with scheme for auto-detection")]
    pub transport: Option<TransportType>,
    /// Default invoke timeout in milliseconds (30000 when `None`).
    pub default_timeout_ms: Option<u64>,
}
65
66pub use strike48_proto::proto;
67
68use proto::{StreamMessage as ProtoStreamMessage, stream_message};
69
/// Bookkeeping for an invoke request that is still waiting for its response.
struct PendingInvoke {
    /// Completes the caller's oneshot receiver when the matching response arrives.
    resolve: tokio::sync::oneshot::Sender<crate::types::InvokeCapabilityResponse>,
    /// Instant at which the request's timeout elapses.
    // NOTE(review): written but never read in this file, hence the allow; timeouts are
    // enforced elsewhere (e.g. `invoke_capability` wraps the receiver in `timeout`).
    #[allow(dead_code)]
    deadline: tokio::time::Instant,
}
76
/// Handle returned by `ConnectorClient::start_invoke` once a request has been sent.
pub(crate) struct StartedInvoke {
    /// Resolves with the response; `None` for fire-and-forget requests.
    pub receiver: Option<tokio::sync::oneshot::Receiver<crate::types::InvokeCapabilityResponse>>,
    /// Correlation id sent with the request (of the form `invoke-<id>`).
    pub request_id: String,
    /// Effective timeout in milliseconds (per-call option or client default). This handle
    /// does not enforce it; the awaiting caller must.
    pub timeout_ms: u64,
    /// Shared pending-request table, kept so `cancel` can deregister the request.
    pending_invokes: Arc<RwLock<HashMap<String, PendingInvoke>>>,
}
85
86impl StartedInvoke {
87 pub(crate) async fn cancel(&self) {
89 self.pending_invokes.write().await.remove(&self.request_id);
90 }
91}
92
/// Client for a Strike48 server: owns the transport, the outbound stream sender,
/// connection/registration flags, and the table of in-flight invoke requests.
pub struct ConnectorClient {
    /// `host:port` parsed from the endpoint, or the raw endpoint string if parsing failed.
    host: String,
    /// Whether the transport is created with TLS.
    use_tls: bool,
    /// Transport kind handed to `create_transport`.
    transport_type: TransportType,
    /// Live transport; `None` until `connect_channel` succeeds and again after `disconnect`.
    transport: Option<Box<dyn Transport>>,
    /// Set by `connect_channel`, cleared by `disconnect`; also polled by the keepalive task.
    connected: Arc<AtomicBool>,
    /// Set by `mark_registered`, cleared by `disconnect`; invoke requests require it.
    registered: Arc<AtomicBool>,
    /// Session token attached to subsequent register requests (not cleared on disconnect).
    session_token: Arc<RwLock<Option<String>>>,
    // NOTE(review): initialised in `with_options` but never read or written afterwards in
    // this file, hence the allow.
    #[allow(dead_code)]
    connector_address: Arc<RwLock<Option<String>>>,
    /// Sender half of the active stream; `None` until a stream is started.
    request_tx: Arc<RwLock<Option<mpsc::UnboundedSender<ProtoStreamMessage>>>>,
    /// In-flight invoke requests keyed by request id.
    pending_invokes: Arc<RwLock<HashMap<String, PendingInvoke>>>,
    /// Invoke timeout in milliseconds used when a call does not specify one.
    default_timeout_ms: u64,
    /// Send time (nanoseconds since the Unix epoch) of the latest keepalive heartbeat.
    heartbeat_sent_at_nanos: Arc<AtomicU64>,
}
134
135impl ConnectorClient {
136 #[allow(dead_code)]
140 pub fn new(host: String, use_tls: bool) -> Self {
141 #[allow(deprecated)]
142 Self::with_options(ClientOptions {
143 url: None,
144 host: Some(host),
145 use_tls: Some(use_tls),
146 transport: Some(TransportType::default()),
147 default_timeout_ms: Some(30000),
148 })
149 }
150
151 #[allow(deprecated)]
173 pub fn with_options(opts: ClientOptions) -> Self {
174 let (host, use_tls, transport) = if let Some(url) = &opts.url {
176 match parse_url(url) {
178 Ok(parsed) => {
179 let host = parsed.host_port();
180 let tls = opts.use_tls.unwrap_or(parsed.use_tls);
181 let trans = opts.transport.unwrap_or(parsed.transport);
182 (host, tls, trans)
183 }
184 Err(_) => {
185 let host = url.clone();
187 let tls = opts.use_tls.unwrap_or(false);
188 let trans = opts.transport.unwrap_or(TransportType::Grpc);
189 (host, tls, trans)
190 }
191 }
192 } else if let Some(host) = &opts.host {
193 match parse_url(host) {
195 Ok(parsed) => {
196 let host_port = parsed.host_port();
197 let tls = opts.use_tls.unwrap_or(parsed.use_tls);
198 let trans = opts.transport.unwrap_or(parsed.transport);
199 (host_port, tls, trans)
200 }
201 Err(_) => {
202 let tls = opts.use_tls.unwrap_or(false);
204 let trans = opts.transport.unwrap_or(TransportType::Grpc);
205 (host.clone(), tls, trans)
206 }
207 }
208 } else {
209 ("localhost:50061".to_string(), false, TransportType::Grpc)
211 };
212
213 if transport == TransportType::WebSocket {
214 debug!(
215 "WebSocket transport selected (detected from URL scheme). \
216 This transport works through corporate proxies that block HTTP/2."
217 );
218 }
219
220 debug!(
221 "ConnectorClient initialized: {} (transport: {:?}, TLS: {})",
222 host, transport, use_tls
223 );
224
225 Self {
226 host,
227 use_tls,
228 transport_type: transport,
229 transport: None, connected: Arc::new(AtomicBool::new(false)),
231 registered: Arc::new(AtomicBool::new(false)),
232 session_token: Arc::new(RwLock::new(None)),
233 connector_address: Arc::new(RwLock::new(None)),
234 request_tx: Arc::new(RwLock::new(None)),
235 pending_invokes: Arc::new(RwLock::new(HashMap::new())),
236 default_timeout_ms: opts.default_timeout_ms.unwrap_or(30000),
237 heartbeat_sent_at_nanos: Arc::new(AtomicU64::new(0)),
238 }
239 }
240
241 pub async fn connect_channel(&mut self) -> Result<()> {
247 debug!(
248 "Connecting to Strike48 server: {} (transport: {:?})",
249 self.host, self.transport_type
250 );
251
252 let options = TransportOptions {
254 host: self.host.clone(),
255 use_tls: self.use_tls,
256 connect_timeout_ms: Some(10000),
257 default_timeout_ms: Some(self.default_timeout_ms),
258 channel_capacity: Some(1024), };
260
261 let mut transport = create_transport(self.transport_type, options);
263 transport.connect().await?;
264
265 self.connected.store(true, Ordering::SeqCst);
266 self.transport = Some(transport);
267
268 debug!("Connected to Strike48 server");
269 Ok(())
270 }
271
272 #[allow(dead_code)]
275 pub async fn send_register_request(
276 &self,
277 tenant_id: &str,
278 connector_type: &str,
279 instance_id: &str,
280 capabilities: &ConnectorCapabilities,
281 auth_token: &str,
282 ) -> Result<()> {
283 let capabilities_proto = proto::ConnectorCapabilities {
285 connector_type: capabilities.connector_type.clone(),
286 version: capabilities.version.clone(),
287 supported_encodings: capabilities
288 .supported_encodings
289 .iter()
290 .map(|e| *e as i32)
291 .collect(),
292 behaviors: capabilities.behaviors.iter().map(|b| *b as i32).collect(),
293 metadata: capabilities.metadata.clone(),
294 task_types: capabilities
295 .task_types
296 .as_ref()
297 .map(|tts| {
298 tts.iter()
299 .map(|tt| proto::TaskTypeSchema {
300 task_type_id: tt.task_type_id.clone(),
301 name: tt.name.clone(),
302 description: tt.description.clone(),
303 category: tt.category.clone(),
304 icon: tt.icon.clone(),
305 input_schema_json: tt.input_schema_json.clone(),
306 output_schema_json: tt.output_schema_json.clone(),
307 })
308 .collect()
309 })
310 .unwrap_or_default(),
311 };
312
313 let sanitized_instance_id = sanitize_identifier(instance_id);
316
317 let instance_metadata = Some(proto::InstanceMetadata {
319 display_name: sanitized_instance_id.clone(),
320 tags: Vec::new(),
321 metadata: std::collections::HashMap::new(),
322 });
323
324 let mut request = proto::RegisterConnectorRequest {
325 tenant_id: sanitize_identifier(tenant_id),
326 connector_type: sanitize_identifier(connector_type),
327 instance_id: sanitized_instance_id,
328 capabilities: Some(capabilities_proto),
329 jwt_token: if auth_token.is_empty() {
330 String::new()
331 } else {
332 auth_token.to_string()
333 },
334 session_token: String::new(),
335 scope: 0, instance_metadata,
337 };
338
339 if let Some(session_token) = self.session_token.read().await.as_ref() {
341 request.session_token = session_token.clone();
342 debug!("Using session token for reconnection");
343 }
344
345 let message = ProtoStreamMessage {
347 message: Some(proto::stream_message::Message::RegisterRequest(request)),
348 };
349
350 self.send_message(message).await
351 }
352
353 pub async fn start_stream_with_registration(
363 &mut self,
364 initial_message: ProtoStreamMessage,
365 ) -> Result<(
366 mpsc::UnboundedSender<ProtoStreamMessage>,
367 mpsc::UnboundedReceiver<ProtoStreamMessage>,
368 )> {
369 debug!("start_stream: getting transport reference");
370 let transport = self
371 .transport
372 .as_mut()
373 .ok_or(ConnectorError::NotConnected)?;
374
375 debug!("start_stream: starting transport stream with initial message");
376
377 let (tx, rx) = transport.start_stream(Some(initial_message)).await?;
380
381 debug!("start_stream: transport stream started successfully");
382
383 *self.request_tx.write().await = Some(tx.clone());
385
386 Self::spawn_keepalive(
390 tx.clone(),
391 self.connected.clone(),
392 self.heartbeat_sent_at_nanos.clone(),
393 );
394
395 Ok((tx, rx))
396 }
397
398 fn spawn_keepalive(
404 tx: mpsc::UnboundedSender<ProtoStreamMessage>,
405 connected: Arc<AtomicBool>,
406 sent_at_nanos: Arc<AtomicU64>,
407 ) {
408 tokio::spawn(async move {
409 let mut interval = tokio::time::interval(Duration::from_secs(KEEPALIVE_INTERVAL_SECS));
410 interval.tick().await;
413
414 loop {
415 interval.tick().await;
416
417 if !connected.load(Ordering::SeqCst) {
418 debug!("keepalive: client disconnected, stopping");
419 break;
420 }
421
422 let now = SystemTime::now()
423 .duration_since(UNIX_EPOCH)
424 .unwrap_or_default();
425 let now_ms = now.as_millis() as i64;
426
427 sent_at_nanos.store(now.as_nanos() as u64, Ordering::Release);
428
429 let heartbeat = ProtoStreamMessage {
430 message: Some(proto::stream_message::Message::HeartbeatRequest(
431 proto::HeartbeatRequest {
432 gateway_id: String::new(),
433 timestamp_ms: now_ms,
434 },
435 )),
436 };
437
438 if tx.send(heartbeat).is_err() {
439 debug!("keepalive: stream closed, stopping");
440 break;
441 }
442 }
443 });
444 }
445
446 pub(crate) fn heartbeat_sent_at_nanos(&self) -> &Arc<AtomicU64> {
449 &self.heartbeat_sent_at_nanos
450 }
451
452 pub async fn send_message(&self, message: ProtoStreamMessage) -> Result<()> {
454 if let Some(tx) = self.request_tx.read().await.as_ref() {
455 tx.send(message)
456 .map_err(|e| ConnectorError::StreamError(format!("Failed to send message: {e}")))?;
457 Ok(())
458 } else {
459 Err(ConnectorError::StreamError(
460 "Stream not started".to_string(),
461 ))
462 }
463 }
464
465 pub(crate) async fn clone_message_tx(
470 &self,
471 ) -> Result<mpsc::UnboundedSender<ProtoStreamMessage>> {
472 self.request_tx
473 .read()
474 .await
475 .as_ref()
476 .cloned()
477 .ok_or_else(|| ConnectorError::StreamError("Stream not started".to_string()))
478 }
479
480 pub(crate) async fn start_invoke(
487 &self,
488 target_address: &str,
489 payload: Vec<u8>,
490 options: InvokeOptions,
491 ) -> Result<StartedInvoke> {
492 use tokio::sync::oneshot;
493
494 if !self.registered.load(Ordering::SeqCst) {
495 return Err(ConnectorError::NotRegistered);
496 }
497
498 let request_id = format!("invoke-{}", generate_id());
499 let timeout_ms = options.timeout_ms.unwrap_or(self.default_timeout_ms);
500 let fire_and_forget = options.fire_and_forget.unwrap_or(false);
501
502 let proto_request = proto::InvokeCapabilityRequest {
503 request_id: request_id.clone(),
504 target_address: target_address.to_string(),
505 capability_id: options.capability_id.unwrap_or_default(),
506 payload,
507 payload_encoding: options.payload_encoding.unwrap_or(PayloadEncoding::Json) as i32,
508 context: options.context.unwrap_or_default(),
509 timeout_ms: timeout_ms as i32,
510 fire_and_forget,
511 };
512
513 let message = ProtoStreamMessage {
514 message: Some(stream_message::Message::InvokeRequest(proto_request)),
515 };
516
517 if fire_and_forget {
518 self.send_message(message).await?;
519 return Ok(StartedInvoke {
520 receiver: None,
521 request_id,
522 timeout_ms,
523 pending_invokes: self.pending_invokes.clone(),
524 });
525 }
526
527 let (tx, rx) = oneshot::channel();
528 let deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
529
530 {
531 let mut pending = self.pending_invokes.write().await;
532 pending.insert(
533 request_id.clone(),
534 PendingInvoke {
535 resolve: tx,
536 deadline,
537 },
538 );
539 }
540
541 self.send_message(message).await?;
542
543 Ok(StartedInvoke {
544 receiver: Some(rx),
545 request_id,
546 timeout_ms,
547 pending_invokes: self.pending_invokes.clone(),
548 })
549 }
550
551 pub async fn set_session_token(&self, token: String) {
553 *self.session_token.write().await = Some(token);
554 }
555
556 #[allow(dead_code)]
558 pub async fn send_response(&self, response: ExecuteResponse) -> Result<()> {
559 let message = ProtoStreamMessage {
560 message: Some(stream_message::Message::ExecuteResponse(
561 proto::ExecuteResponse {
562 request_id: response.request_id,
563 success: response.success,
564 payload: response.payload,
565 payload_encoding: response.payload_encoding as i32,
566 error: response.error,
567 duration_ms: response.duration_ms as i64,
568 },
569 )),
570 };
571
572 self.send_message(message).await
573 }
574
575 pub async fn disconnect(&mut self) {
577 if let Some(transport) = self.transport.as_mut() {
579 let _ = transport.disconnect().await;
580 }
581
582 self.connected.store(false, Ordering::SeqCst);
583 self.registered.store(false, Ordering::SeqCst);
584 self.transport = None;
585 *self.request_tx.write().await = None;
586
587 let mut pending = self.pending_invokes.write().await;
592 let count = pending.len();
593 pending.clear(); if count > 0 {
595 debug!(
596 "Cancelled {} in-flight invoke request(s) on disconnect",
597 count
598 );
599 }
600
601 debug!("Disconnected from Strike48 server");
602 }
603
604 pub fn is_connected(&self) -> bool {
606 self.connected.load(Ordering::SeqCst)
607 }
608
609 #[allow(dead_code)]
611 pub fn is_registered(&self) -> bool {
612 self.registered.load(Ordering::SeqCst)
613 }
614
615 pub fn mark_registered(&self) {
617 self.registered.store(true, Ordering::SeqCst);
618 }
619
620 #[allow(dead_code)]
622 pub async fn invoke_capability(
623 &self,
624 target_address: &str,
625 payload: Vec<u8>,
626 options: InvokeOptions,
627 ) -> Result<Option<InvokeCapabilityResponse>> {
628 use tokio::sync::oneshot;
629 use tokio::time::{Duration, timeout};
630
631 if !self.registered.load(Ordering::SeqCst) {
632 return Err(ConnectorError::NotRegistered);
633 }
634
635 let request_id = format!("invoke-{}", generate_id());
636 let timeout_ms = options.timeout_ms.unwrap_or(self.default_timeout_ms);
637 let fire_and_forget = options.fire_and_forget.unwrap_or(false);
638
639 let proto_request = proto::InvokeCapabilityRequest {
641 request_id: request_id.clone(),
642 target_address: target_address.to_string(),
643 capability_id: options.capability_id.unwrap_or_default(),
644 payload,
645 payload_encoding: options.payload_encoding.unwrap_or(PayloadEncoding::Json) as i32,
646 context: options.context.unwrap_or_default(),
647 timeout_ms: timeout_ms as i32,
648 fire_and_forget,
649 };
650
651 let message = ProtoStreamMessage {
652 message: Some(stream_message::Message::InvokeRequest(proto_request)),
653 };
654
655 if fire_and_forget {
657 self.send_message(message).await?;
658 return Ok(None);
659 }
660
661 let (tx, rx) = oneshot::channel();
663 let deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
664
665 {
667 let mut pending = self.pending_invokes.write().await;
668 pending.insert(
669 request_id.clone(),
670 PendingInvoke {
671 resolve: tx,
672 deadline,
673 },
674 );
675 }
676
677 self.send_message(message).await?;
678
679 match timeout(Duration::from_millis(timeout_ms), rx).await {
681 Ok(Ok(response)) => Ok(Some(response)),
682 Ok(Err(_)) => {
683 self.pending_invokes.write().await.remove(&request_id);
685 Err(ConnectorError::StreamError(
686 "Response channel closed".to_string(),
687 ))
688 }
689 Err(_) => {
690 self.pending_invokes.write().await.remove(&request_id);
692 Err(ConnectorError::Timeout(format!(
693 "Invoke request {request_id} timed out after {timeout_ms}ms"
694 )))
695 }
696 }
697 }
698
699 pub(crate) async fn handle_invoke_response(
701 &self,
702 response: proto::InvokeCapabilityResponse,
703 ) -> bool {
704 let request_id = response.request_id.clone();
705 let mut pending = self.pending_invokes.write().await;
706
707 if let Some(pending_invoke) = pending.remove(&request_id) {
708 let invoke_response = InvokeCapabilityResponse {
709 request_id: response.request_id,
710 success: response.success,
711 payload: response.payload,
712 payload_encoding: PayloadEncoding::from(response.payload_encoding),
713 error: response.error,
714 duration_ms: response.duration_ms as u64,
715 context: if response.context.is_empty() {
716 None
717 } else {
718 Some(response.context)
719 },
720 error_details: if response.error_details.is_empty() {
721 None
722 } else {
723 Some(response.error_details)
724 },
725 };
726
727 let _ = pending_invoke.resolve.send(invoke_response);
728 true
729 } else {
730 false
731 }
732 }
733
734 #[allow(dead_code)]
736 pub fn get_default_timeout(&self) -> Option<u64> {
737 Some(self.default_timeout_ms)
738 }
739}
740
/// Per-call options for `ConnectorClient::invoke_capability` / `start_invoke`.
///
/// Fields left `None` fall back to the defaults noted on each field.
#[derive(Debug, Clone, Default)]
pub struct InvokeOptions {
    /// Encoding of the payload; defaults to `PayloadEncoding::Json`.
    pub payload_encoding: Option<PayloadEncoding>,
    /// Sent as the request's `capability_id`; defaults to an empty string.
    pub capability_id: Option<String>,
    /// Response timeout in milliseconds; defaults to the client's default timeout.
    pub timeout_ms: Option<u64>,
    /// When `true`, the request is sent without registering for a response; defaults to
    /// `false`.
    pub fire_and_forget: Option<bool>,
    /// Key/value context forwarded with the request; defaults to an empty map.
    pub context: Option<HashMap<String, String>>,
}
750
// Unit tests for the keepalive task. They use real time with a 30 s keepalive interval, so
// none of them can wait for an actual keepalive tick.
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): despite its name, this test never observes a heartbeat. The first
    // (immediate) interval tick is consumed up front, and the next one is 30 s away. What it
    // checks is that nothing is sent during the first ~150 ms.
    #[tokio::test]
    async fn test_keepalive_sends_heartbeats() {
        let (tx, mut rx) = mpsc::unbounded_channel::<ProtoStreamMessage>();
        let connected = Arc::new(AtomicBool::new(true));
        let sent_at = Arc::new(AtomicU64::new(0));

        ConnectorClient::spawn_keepalive(tx, connected.clone(), sent_at);

        tokio::time::sleep(Duration::from_millis(50)).await;

        // Ask the task to stop; it only notices at its next tick.
        connected.store(false, Ordering::SeqCst);

        tokio::time::sleep(Duration::from_millis(100)).await;

        // No heartbeat should have been queued before the first full interval.
        assert!(rx.try_recv().is_err());
    }

    // Drops the receiver while the keepalive task is alive. The task would only notice the
    // closed channel at its next tick (30 s), which this test does not wait for. The final
    // assert just confirms `connected` is untouched; `spawn_keepalive` never writes it.
    #[tokio::test]
    async fn test_keepalive_stops_on_channel_close() {
        let (tx, rx) = mpsc::unbounded_channel::<ProtoStreamMessage>();
        let connected = Arc::new(AtomicBool::new(true));
        let sent_at = Arc::new(AtomicU64::new(0));

        ConnectorClient::spawn_keepalive(tx, connected.clone(), sent_at);

        drop(rx);

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(connected.load(Ordering::SeqCst));
    }

    // Checks the shape of a heartbeat message.
    // NOTE(review): this re-implements spawn_keepalive's message construction with a 50 ms
    // interval rather than calling spawn_keepalive, so it must be kept in sync by hand.
    #[tokio::test]
    async fn test_keepalive_heartbeat_format() {
        let (tx, mut rx) = mpsc::unbounded_channel::<ProtoStreamMessage>();
        let connected = Arc::new(AtomicBool::new(true));

        let keepalive_tx = tx;
        let keepalive_connected = connected.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_millis(50));
            // First tick is immediate; the second waits one 50 ms period.
            interval.tick().await;
            interval.tick().await;
            if !keepalive_connected.load(Ordering::SeqCst) {
                return;
            }

            let now_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as i64)
                .unwrap_or(0);

            let heartbeat = ProtoStreamMessage {
                message: Some(proto::stream_message::Message::HeartbeatRequest(
                    proto::HeartbeatRequest {
                        gateway_id: String::new(),
                        timestamp_ms: now_ms,
                    },
                )),
            };
            let _ = keepalive_tx.send(heartbeat);
        });

        // Long enough for the single 50 ms heartbeat above to be queued.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let msg = rx.try_recv().expect("should have received a heartbeat");
        match msg.message {
            Some(proto::stream_message::Message::HeartbeatRequest(hb)) => {
                assert!(
                    hb.gateway_id.is_empty(),
                    "gateway_id should be empty (server fills it)"
                );
                assert!(hb.timestamp_ms > 0, "timestamp should be set");
            }
            other => panic!("expected HeartbeatRequest, got {:?}", other),
        }

        connected.store(false, Ordering::SeqCst);
    }
}