// strike48_connector/client.rs
1use crate::error::{ConnectorError, Result};
2use crate::transport::{Transport, TransportOptions, TransportType, create_transport};
3use crate::types::*;
4use crate::url_parser::parse_url;
5use crate::utils::{generate_id, sanitize_identifier};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::sync::{RwLock, mpsc};
11use tokio::time::Duration;
12use tracing::debug;
13
/// Keepalive interval: send a HeartbeatRequest every 30 seconds to prevent
/// the server's session reaper from killing the session (timeout = 90s).
/// One third of the reaper timeout, so a couple of consecutive heartbeats
/// can be delayed or lost before the session would be reaped.
const KEEPALIVE_INTERVAL_SECS: u64 = 30;
17
/// Client configuration options.
///
/// Transport is auto-detected from URL scheme when using `url`:
/// - `grpc://` or `grpcs://` → gRPC transport
/// - `ws://`, `wss://`, `http://`, `https://` → WebSocket transport
///
/// All fields are optional; when both `url` and `host` are unset, the client
/// falls back to `localhost:50061` with plain (non-TLS) gRPC.
///
/// # Examples
///
/// ```rust,ignore
/// // Auto-detect transport from URL (recommended)
/// ClientOptions {
///     url: Some("grpcs://connectors.example.com:443".to_string()),
///     ..Default::default()
/// }
///
/// // WebSocket transport (auto-detected from wss://)
/// ClientOptions {
///     url: Some("wss://connectors.example.com:443".to_string()),
///     ..Default::default()
/// }
///
/// // Legacy: explicit host
/// ClientOptions {
///     host: Some("connectors.example.com:443".to_string()),
///     use_tls: Some(true),
///     ..Default::default()
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct ClientOptions {
    /// Strike48 server URL with scheme (preferred, auto-detects transport).
    ///
    /// Examples:
    /// - `grpcs://matrix.example.com:443` - gRPC with TLS
    /// - `wss://matrix.example.com:443` - WebSocket with TLS
    pub url: Option<String>,
    /// Strike48 server host (legacy, use `url` instead).
    #[deprecated(note = "Use `url` with scheme for auto-detection")]
    pub host: Option<String>,
    /// Use TLS for connection (auto-detected from URL scheme when using `url`).
    /// Explicit `Some(..)` overrides what the URL scheme implies.
    pub use_tls: Option<bool>,
    /// Transport type (auto-detected from URL scheme when using `url`).
    #[deprecated(note = "Use `url` with scheme for auto-detection")]
    pub transport: Option<TransportType>,
    /// Default timeout for operations in milliseconds (default: 30000).
    pub default_timeout_ms: Option<u64>,
}
65
66pub use strike48_proto::proto;
67
68use proto::{StreamMessage as ProtoStreamMessage, stream_message};
69
/// Pending invoke request awaiting response
struct PendingInvoke {
    /// Oneshot sender resolved by `handle_invoke_response`. Dropping the
    /// entry (e.g. when `disconnect` clears the map) closes the channel,
    /// which fails the waiting caller fast.
    resolve: tokio::sync::oneshot::Sender<crate::types::InvokeCapabilityResponse>,
    /// Absolute deadline for this request. Currently never read
    /// (hence the `dead_code` allow) -- timeouts are enforced by the waiter.
    #[allow(dead_code)]
    deadline: tokio::time::Instant,
}
76
/// Result of `start_invoke` -- contains everything needed to wait for the
/// response outside any parent lock scope.
pub(crate) struct StartedInvoke {
    /// Receiver for the routed response; `None` for fire-and-forget invokes.
    pub receiver: Option<tokio::sync::oneshot::Receiver<crate::types::InvokeCapabilityResponse>>,
    /// The generated `invoke-...` request ID.
    pub request_id: String,
    /// Effective timeout for this request, in milliseconds.
    pub timeout_ms: u64,
    /// Shared map of in-flight invokes; used by `cancel` to deregister.
    pending_invokes: Arc<RwLock<HashMap<String, PendingInvoke>>>,
}
85
86impl StartedInvoke {
87    /// Remove the pending invoke entry (cleanup on timeout or channel close).
88    pub(crate) async fn cancel(&self) {
89        self.pending_invokes.write().await.remove(&self.request_id);
90    }
91}
92
/// gRPC client for connector communication
///
/// Supports two transport modes:
/// - `TransportType::Grpc` (default): Native gRPC over HTTP/2
/// - `TransportType::WebSocket`: WebSocket over HTTP/1.1 (for corporate proxy compatibility)
///
/// # Example
///
/// ```rust,ignore
/// use strike48_connector::{ConnectorClient, client::{ClientOptions, TransportType}};
///
/// // Using native gRPC (default)
/// let client = ConnectorClient::new("connectors-poc-us.strike48.com:443".to_string(), true);
///
/// // Using WebSocket for corporate proxy compatibility.
/// // All `ClientOptions` fields are `Option`s; unset fields come from `Default`.
/// let client = ConnectorClient::with_options(ClientOptions {
///     host: Some("connectors-poc-us.strike48.com:443".to_string()),
///     use_tls: Some(true),
///     transport: Some(TransportType::WebSocket),
///     default_timeout_ms: Some(30000),
///     ..Default::default()
/// });
/// ```
pub struct ConnectorClient {
    host: String,
    use_tls: bool,
    transport_type: TransportType,
    /// Transport layer abstraction (gRPC or WebSocket)
    transport: Option<Box<dyn Transport>>,
    /// Atomic flag for connection state (lock-free)
    connected: Arc<AtomicBool>,
    /// Atomic flag for registration state (lock-free)
    registered: Arc<AtomicBool>,
    /// Session token issued by the server; reused on reconnect to resume the session.
    session_token: Arc<RwLock<Option<String>>>,
    #[allow(dead_code)] // Connector address after registration
    connector_address: Arc<RwLock<Option<String>>>,
    /// Sender half of the bidirectional stream, set by `start_stream_with_registration`.
    request_tx: Arc<RwLock<Option<mpsc::UnboundedSender<ProtoStreamMessage>>>>,
    /// In-flight invoke requests awaiting responses, keyed by request ID.
    pending_invokes: Arc<RwLock<HashMap<String, PendingInvoke>>>,
    default_timeout_ms: u64,
    /// Epoch nanos when the last HeartbeatRequest was sent (for RTT calculation).
    heartbeat_sent_at_nanos: Arc<AtomicU64>,
}
134
135impl ConnectorClient {
136    /// Create a new ConnectorClient with default transport (gRPC).
137    ///
138    /// For URL-based transport auto-detection, use `with_options` with `url`.
139    #[allow(dead_code)]
140    pub fn new(host: String, use_tls: bool) -> Self {
141        #[allow(deprecated)]
142        Self::with_options(ClientOptions {
143            url: None,
144            host: Some(host),
145            use_tls: Some(use_tls),
146            transport: Some(TransportType::default()),
147            default_timeout_ms: Some(30000),
148        })
149    }
150
151    /// Create a new ConnectorClient with full configuration.
152    ///
153    /// Transport is auto-detected from URL scheme when using `url` option:
154    /// - `grpc://` or `grpcs://` → gRPC transport
155    /// - `ws://`, `wss://`, `http://`, `https://` → WebSocket transport
156    ///
157    /// # Examples
158    ///
159    /// ```rust,ignore
160    /// // Auto-detect transport from URL (recommended)
161    /// let client = ConnectorClient::with_options(ClientOptions {
162    ///     url: Some("grpcs://connectors.example.com:443".to_string()),
163    ///     ..Default::default()
164    /// });
165    ///
166    /// // WebSocket transport (auto-detected from wss://)
167    /// let client = ConnectorClient::with_options(ClientOptions {
168    ///     url: Some("wss://connectors.example.com:443".to_string()),
169    ///     ..Default::default()
170    /// });
171    /// ```
172    #[allow(deprecated)]
173    pub fn with_options(opts: ClientOptions) -> Self {
174        // Parse URL to auto-detect transport, TLS, and normalize host
175        let (host, use_tls, transport) = if let Some(url) = &opts.url {
176            // URL-based configuration (auto-detect transport from scheme)
177            match parse_url(url) {
178                Ok(parsed) => {
179                    let host = parsed.host_port();
180                    let tls = opts.use_tls.unwrap_or(parsed.use_tls);
181                    let trans = opts.transport.unwrap_or(parsed.transport);
182                    (host, tls, trans)
183                }
184                Err(_) => {
185                    // Fallback to legacy behavior
186                    let host = url.clone();
187                    let tls = opts.use_tls.unwrap_or(false);
188                    let trans = opts.transport.unwrap_or(TransportType::Grpc);
189                    (host, tls, trans)
190                }
191            }
192        } else if let Some(host) = &opts.host {
193            // Try to parse host as URL for auto-detection
194            match parse_url(host) {
195                Ok(parsed) => {
196                    let host_port = parsed.host_port();
197                    let tls = opts.use_tls.unwrap_or(parsed.use_tls);
198                    let trans = opts.transport.unwrap_or(parsed.transport);
199                    (host_port, tls, trans)
200                }
201                Err(_) => {
202                    // Simple host:port format
203                    let tls = opts.use_tls.unwrap_or(false);
204                    let trans = opts.transport.unwrap_or(TransportType::Grpc);
205                    (host.clone(), tls, trans)
206                }
207            }
208        } else {
209            // No URL or host provided, use defaults
210            ("localhost:50061".to_string(), false, TransportType::Grpc)
211        };
212
213        if transport == TransportType::WebSocket {
214            debug!(
215                "WebSocket transport selected (detected from URL scheme). \
216                This transport works through corporate proxies that block HTTP/2."
217            );
218        }
219
220        debug!(
221            "ConnectorClient initialized: {} (transport: {:?}, TLS: {})",
222            host, transport, use_tls
223        );
224
225        Self {
226            host,
227            use_tls,
228            transport_type: transport,
229            transport: None, // Created on connect
230            connected: Arc::new(AtomicBool::new(false)),
231            registered: Arc::new(AtomicBool::new(false)),
232            session_token: Arc::new(RwLock::new(None)),
233            connector_address: Arc::new(RwLock::new(None)),
234            request_tx: Arc::new(RwLock::new(None)),
235            pending_invokes: Arc::new(RwLock::new(HashMap::new())),
236            default_timeout_ms: opts.default_timeout_ms.unwrap_or(30000),
237            heartbeat_sent_at_nanos: Arc::new(AtomicU64::new(0)),
238        }
239    }
240
241    /// Connect to Strike48 server using the configured transport.
242    ///
243    /// Uses the configured transport:
244    /// - `TransportType::Grpc`: Native gRPC over HTTP/2 (default)
245    /// - `TransportType::WebSocket`: WebSocket over HTTP/1.1 (for corporate proxy compatibility)
246    pub async fn connect_channel(&mut self) -> Result<()> {
247        debug!(
248            "Connecting to Strike48 server: {} (transport: {:?})",
249            self.host, self.transport_type
250        );
251
252        // Create transport options
253        let options = TransportOptions {
254            host: self.host.clone(),
255            use_tls: self.use_tls,
256            connect_timeout_ms: Some(10000),
257            default_timeout_ms: Some(self.default_timeout_ms),
258            channel_capacity: Some(1024), // Bounded for backpressure
259        };
260
261        // Create and connect the appropriate transport
262        let mut transport = create_transport(self.transport_type, options);
263        transport.connect().await?;
264
265        self.connected.store(true, Ordering::SeqCst);
266        self.transport = Some(transport);
267
268        debug!("Connected to Strike48 server");
269        Ok(())
270    }
271
272    /// Send registration request on the stream
273    /// Note: Registration happens through the bidirectional stream, not a separate RPC
274    #[allow(dead_code)]
275    pub async fn send_register_request(
276        &self,
277        tenant_id: &str,
278        connector_type: &str,
279        instance_id: &str,
280        capabilities: &ConnectorCapabilities,
281        auth_token: &str,
282    ) -> Result<()> {
283        // Convert capabilities to protobuf
284        let capabilities_proto = proto::ConnectorCapabilities {
285            connector_type: capabilities.connector_type.clone(),
286            version: capabilities.version.clone(),
287            supported_encodings: capabilities
288                .supported_encodings
289                .iter()
290                .map(|e| *e as i32)
291                .collect(),
292            behaviors: capabilities.behaviors.iter().map(|b| *b as i32).collect(),
293            metadata: capabilities.metadata.clone(),
294            task_types: capabilities
295                .task_types
296                .as_ref()
297                .map(|tts| {
298                    tts.iter()
299                        .map(|tt| proto::TaskTypeSchema {
300                            task_type_id: tt.task_type_id.clone(),
301                            name: tt.name.clone(),
302                            description: tt.description.clone(),
303                            category: tt.category.clone(),
304                            icon: tt.icon.clone(),
305                            input_schema_json: tt.input_schema_json.clone(),
306                            output_schema_json: tt.output_schema_json.clone(),
307                        })
308                        .collect()
309                })
310                .unwrap_or_default(),
311        };
312
313        // Sanitize identifiers to prevent address parsing issues
314        // Dots and colons are used as separators in addresses
315        let sanitized_instance_id = sanitize_identifier(instance_id);
316
317        // Default instance metadata (display_name = instance_id)
318        let instance_metadata = Some(proto::InstanceMetadata {
319            display_name: sanitized_instance_id.clone(),
320            tags: Vec::new(),
321            metadata: std::collections::HashMap::new(),
322        });
323
324        let mut request = proto::RegisterConnectorRequest {
325            tenant_id: sanitize_identifier(tenant_id),
326            connector_type: sanitize_identifier(connector_type),
327            instance_id: sanitized_instance_id,
328            capabilities: Some(capabilities_proto),
329            jwt_token: if auth_token.is_empty() {
330                String::new()
331            } else {
332                auth_token.to_string()
333            },
334            session_token: String::new(),
335            scope: 0, // Default scope
336            instance_metadata,
337        };
338
339        // Use session token if available
340        if let Some(session_token) = self.session_token.read().await.as_ref() {
341            request.session_token = session_token.clone();
342            debug!("Using session token for reconnection");
343        }
344
345        // Send registration request on the stream
346        let message = ProtoStreamMessage {
347            message: Some(proto::stream_message::Message::RegisterRequest(request)),
348        };
349
350        self.send_message(message).await
351    }
352
353    /// Start bidirectional streaming using the transport abstraction.
354    ///
355    /// This method works identically for both gRPC and WebSocket transports.
356    /// The transport handles protocol-specific details internally.
357    ///
358    /// A background keepalive task is automatically spawned that sends
359    /// `HeartbeatRequest` messages every 30 seconds. This prevents the
360    /// server's session reaper (90s timeout) from killing the session,
361    /// regardless of whether the caller sends their own metrics/heartbeats.
362    pub async fn start_stream_with_registration(
363        &mut self,
364        initial_message: ProtoStreamMessage,
365    ) -> Result<(
366        mpsc::UnboundedSender<ProtoStreamMessage>,
367        mpsc::UnboundedReceiver<ProtoStreamMessage>,
368    )> {
369        debug!("start_stream: getting transport reference");
370        let transport = self
371            .transport
372            .as_mut()
373            .ok_or(ConnectorError::NotConnected)?;
374
375        debug!("start_stream: starting transport stream with initial message");
376
377        // Start the bidirectional stream via transport abstraction
378        // Pass the initial message so it's sent immediately (prevents deadlock in gRPC)
379        let (tx, rx) = transport.start_stream(Some(initial_message)).await?;
380
381        debug!("start_stream: transport stream started successfully");
382
383        // Store the tx for later use
384        *self.request_tx.write().await = Some(tx.clone());
385
386        // Spawn keepalive task: sends HeartbeatRequest every 30s to prevent
387        // the server from reaping the session. Stops automatically when the
388        // stream closes (tx dropped) or when the client disconnects.
389        Self::spawn_keepalive(
390            tx.clone(),
391            self.connected.clone(),
392            self.heartbeat_sent_at_nanos.clone(),
393        );
394
395        Ok((tx, rx))
396    }
397
398    /// Spawn a background task that sends periodic `HeartbeatRequest` messages
399    /// on the stream to prevent the server's session reaper from timing out.
400    ///
401    /// Records `SystemTime` epoch nanos into `sent_at_nanos` right before each
402    /// send so the receiver can compute round-trip time.
403    fn spawn_keepalive(
404        tx: mpsc::UnboundedSender<ProtoStreamMessage>,
405        connected: Arc<AtomicBool>,
406        sent_at_nanos: Arc<AtomicU64>,
407    ) {
408        tokio::spawn(async move {
409            let mut interval = tokio::time::interval(Duration::from_secs(KEEPALIVE_INTERVAL_SECS));
410            // First tick fires immediately; skip it so the first heartbeat
411            // is sent after one full interval.
412            interval.tick().await;
413
414            loop {
415                interval.tick().await;
416
417                if !connected.load(Ordering::SeqCst) {
418                    debug!("keepalive: client disconnected, stopping");
419                    break;
420                }
421
422                let now = SystemTime::now()
423                    .duration_since(UNIX_EPOCH)
424                    .unwrap_or_default();
425                let now_ms = now.as_millis() as i64;
426
427                sent_at_nanos.store(now.as_nanos() as u64, Ordering::Release);
428
429                let heartbeat = ProtoStreamMessage {
430                    message: Some(proto::stream_message::Message::HeartbeatRequest(
431                        proto::HeartbeatRequest {
432                            gateway_id: String::new(),
433                            timestamp_ms: now_ms,
434                        },
435                    )),
436                };
437
438                if tx.send(heartbeat).is_err() {
439                    debug!("keepalive: stream closed, stopping");
440                    break;
441                }
442            }
443        });
444    }
445
    /// Returns the epoch nanos when the last `HeartbeatRequest` was sent.
    /// Used by `ConnectorRunner` to compute RTT on `HeartbeatResponse`.
    ///
    /// The keepalive task stores this value with `Ordering::Release`;
    /// readers should pair it with an `Acquire` load.
    pub(crate) fn heartbeat_sent_at_nanos(&self) -> &Arc<AtomicU64> {
        &self.heartbeat_sent_at_nanos
    }
451
452    /// Send a message on the stream
453    pub async fn send_message(&self, message: ProtoStreamMessage) -> Result<()> {
454        if let Some(tx) = self.request_tx.read().await.as_ref() {
455            tx.send(message)
456                .map_err(|e| ConnectorError::StreamError(format!("Failed to send message: {e}")))?;
457            Ok(())
458        } else {
459            Err(ConnectorError::StreamError(
460                "Stream not started".to_string(),
461            ))
462        }
463    }
464
465    /// Clone the stream sender for use outside a parent lock scope.
466    ///
467    /// Returns the cloned `UnboundedSender` so callers can send messages
468    /// without holding any outer lock across the send.
469    pub(crate) async fn clone_message_tx(
470        &self,
471    ) -> Result<mpsc::UnboundedSender<ProtoStreamMessage>> {
472        self.request_tx
473            .read()
474            .await
475            .as_ref()
476            .cloned()
477            .ok_or_else(|| ConnectorError::StreamError("Stream not started".to_string()))
478    }
479
480    /// Prepare an invoke request: validate state, register the response channel,
481    /// and send the request message. Returns the oneshot receiver to wait on
482    /// (or `None` for fire-and-forget), the request ID, and the timeout.
483    ///
484    /// Callers should drop any parent locks before awaiting the receiver so that
485    /// reconnection is not blocked during the (potentially long) wait.
486    pub(crate) async fn start_invoke(
487        &self,
488        target_address: &str,
489        payload: Vec<u8>,
490        options: InvokeOptions,
491    ) -> Result<StartedInvoke> {
492        use tokio::sync::oneshot;
493
494        if !self.registered.load(Ordering::SeqCst) {
495            return Err(ConnectorError::NotRegistered);
496        }
497
498        let request_id = format!("invoke-{}", generate_id());
499        let timeout_ms = options.timeout_ms.unwrap_or(self.default_timeout_ms);
500        let fire_and_forget = options.fire_and_forget.unwrap_or(false);
501
502        let proto_request = proto::InvokeCapabilityRequest {
503            request_id: request_id.clone(),
504            target_address: target_address.to_string(),
505            capability_id: options.capability_id.unwrap_or_default(),
506            payload,
507            payload_encoding: options.payload_encoding.unwrap_or(PayloadEncoding::Json) as i32,
508            context: options.context.unwrap_or_default(),
509            timeout_ms: timeout_ms as i32,
510            fire_and_forget,
511        };
512
513        let message = ProtoStreamMessage {
514            message: Some(stream_message::Message::InvokeRequest(proto_request)),
515        };
516
517        if fire_and_forget {
518            self.send_message(message).await?;
519            return Ok(StartedInvoke {
520                receiver: None,
521                request_id,
522                timeout_ms,
523                pending_invokes: self.pending_invokes.clone(),
524            });
525        }
526
527        let (tx, rx) = oneshot::channel();
528        let deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
529
530        {
531            let mut pending = self.pending_invokes.write().await;
532            pending.insert(
533                request_id.clone(),
534                PendingInvoke {
535                    resolve: tx,
536                    deadline,
537                },
538            );
539        }
540
541        self.send_message(message).await?;
542
543        Ok(StartedInvoke {
544            receiver: Some(rx),
545            request_id,
546            timeout_ms,
547            pending_invokes: self.pending_invokes.clone(),
548        })
549    }
550
551    /// Set session token
552    pub async fn set_session_token(&self, token: String) {
553        *self.session_token.write().await = Some(token);
554    }
555
556    /// Send execute response
557    #[allow(dead_code)]
558    pub async fn send_response(&self, response: ExecuteResponse) -> Result<()> {
559        let message = ProtoStreamMessage {
560            message: Some(stream_message::Message::ExecuteResponse(
561                proto::ExecuteResponse {
562                    request_id: response.request_id,
563                    success: response.success,
564                    payload: response.payload,
565                    payload_encoding: response.payload_encoding as i32,
566                    error: response.error,
567                    duration_ms: response.duration_ms as i64,
568                },
569            )),
570        };
571
572        self.send_message(message).await
573    }
574
575    /// Disconnect from Strike48 server
576    pub async fn disconnect(&mut self) {
577        // Disconnect transport
578        if let Some(transport) = self.transport.as_mut() {
579            let _ = transport.disconnect().await;
580        }
581
582        self.connected.store(false, Ordering::SeqCst);
583        self.registered.store(false, Ordering::SeqCst);
584        self.transport = None;
585        *self.request_tx.write().await = None;
586
587        // Cancel all in-flight invoke requests. Without this, callers waiting
588        // on the oneshot receiver will block until their individual timeouts
589        // fire (up to 30s each). Draining eagerly on disconnect lets them fail
590        // fast and allows the connector to reconnect cleanly.
591        let mut pending = self.pending_invokes.write().await;
592        let count = pending.len();
593        pending.clear(); // drops all PendingInvoke entries, closing the oneshot senders
594        if count > 0 {
595            debug!(
596                "Cancelled {} in-flight invoke request(s) on disconnect",
597                count
598            );
599        }
600
601        debug!("Disconnected from Strike48 server");
602    }
603
    /// Check if connected (lock-free atomic check).
    /// Set by `connect_channel`, cleared by `disconnect`.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::SeqCst)
    }
608
    /// Check if registered (lock-free atomic check).
    /// Set by `mark_registered`, cleared by `disconnect`.
    #[allow(dead_code)]
    pub fn is_registered(&self) -> bool {
        self.registered.load(Ordering::SeqCst)
    }
614
    /// Mark this client as successfully registered with the server.
    /// Invoke requests are rejected with `NotRegistered` until this is called.
    pub fn mark_registered(&self) {
        self.registered.store(true, Ordering::SeqCst);
    }
619
620    /// Invoke a capability on another connector through Strike48 routing
621    #[allow(dead_code)]
622    pub async fn invoke_capability(
623        &self,
624        target_address: &str,
625        payload: Vec<u8>,
626        options: InvokeOptions,
627    ) -> Result<Option<InvokeCapabilityResponse>> {
628        use tokio::sync::oneshot;
629        use tokio::time::{Duration, timeout};
630
631        if !self.registered.load(Ordering::SeqCst) {
632            return Err(ConnectorError::NotRegistered);
633        }
634
635        let request_id = format!("invoke-{}", generate_id());
636        let timeout_ms = options.timeout_ms.unwrap_or(self.default_timeout_ms);
637        let fire_and_forget = options.fire_and_forget.unwrap_or(false);
638
639        // Convert to protobuf
640        let proto_request = proto::InvokeCapabilityRequest {
641            request_id: request_id.clone(),
642            target_address: target_address.to_string(),
643            capability_id: options.capability_id.unwrap_or_default(),
644            payload,
645            payload_encoding: options.payload_encoding.unwrap_or(PayloadEncoding::Json) as i32,
646            context: options.context.unwrap_or_default(),
647            timeout_ms: timeout_ms as i32,
648            fire_and_forget,
649        };
650
651        let message = ProtoStreamMessage {
652            message: Some(stream_message::Message::InvokeRequest(proto_request)),
653        };
654
655        // For fire-and-forget, just send and return None
656        if fire_and_forget {
657            self.send_message(message).await?;
658            return Ok(None);
659        }
660
661        // For synchronous invoke, wait for response
662        let (tx, rx) = oneshot::channel();
663        let deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
664
665        // Store pending request
666        {
667            let mut pending = self.pending_invokes.write().await;
668            pending.insert(
669                request_id.clone(),
670                PendingInvoke {
671                    resolve: tx,
672                    deadline,
673                },
674            );
675        }
676
677        self.send_message(message).await?;
678
679        // Wait for response with timeout
680        match timeout(Duration::from_millis(timeout_ms), rx).await {
681            Ok(Ok(response)) => Ok(Some(response)),
682            Ok(Err(_)) => {
683                // Clean up pending request
684                self.pending_invokes.write().await.remove(&request_id);
685                Err(ConnectorError::StreamError(
686                    "Response channel closed".to_string(),
687                ))
688            }
689            Err(_) => {
690                // Clean up pending request
691                self.pending_invokes.write().await.remove(&request_id);
692                Err(ConnectorError::Timeout(format!(
693                    "Invoke request {request_id} timed out after {timeout_ms}ms"
694                )))
695            }
696        }
697    }
698
699    /// Handle incoming invoke response
700    pub(crate) async fn handle_invoke_response(
701        &self,
702        response: proto::InvokeCapabilityResponse,
703    ) -> bool {
704        let request_id = response.request_id.clone();
705        let mut pending = self.pending_invokes.write().await;
706
707        if let Some(pending_invoke) = pending.remove(&request_id) {
708            let invoke_response = InvokeCapabilityResponse {
709                request_id: response.request_id,
710                success: response.success,
711                payload: response.payload,
712                payload_encoding: PayloadEncoding::from(response.payload_encoding),
713                error: response.error,
714                duration_ms: response.duration_ms as u64,
715                context: if response.context.is_empty() {
716                    None
717                } else {
718                    Some(response.context)
719                },
720                error_details: if response.error_details.is_empty() {
721                    None
722                } else {
723                    Some(response.error_details)
724                },
725            };
726
727            let _ = pending_invoke.resolve.send(invoke_response);
728            true
729        } else {
730            false
731        }
732    }
733
    /// Get default timeout in milliseconds.
    ///
    /// Always returns `Some`; the `Option` wrapper is kept for interface
    /// compatibility with existing callers.
    #[allow(dead_code)]
    pub fn get_default_timeout(&self) -> Option<u64> {
        Some(self.default_timeout_ms)
    }
739}
740
/// Options for invoke capability
#[derive(Debug, Clone, Default)]
pub struct InvokeOptions {
    /// Payload encoding (defaults to `PayloadEncoding::Json` when unset).
    pub payload_encoding: Option<PayloadEncoding>,
    /// Target capability identifier (defaults to empty when unset).
    pub capability_id: Option<String>,
    /// Per-request timeout; falls back to the client's default timeout when unset.
    pub timeout_ms: Option<u64>,
    /// When `Some(true)`, send without waiting for a response (defaults to false).
    pub fire_and_forget: Option<bool>,
    /// Optional key/value context forwarded with the request.
    pub context: Option<HashMap<String, String>>,
}
750
#[cfg(test)]
mod tests {
    // Unit tests for the keepalive machinery. These are timing-based
    // (sleep + flag checks) rather than using tokio's paused clock.
    use super::*;

    // The keepalive task skips the immediate first tick, so within the
    // test's short window no heartbeat should be observed.
    #[tokio::test]
    async fn test_keepalive_sends_heartbeats() {
        let (tx, mut rx) = mpsc::unbounded_channel::<ProtoStreamMessage>();
        let connected = Arc::new(AtomicBool::new(true));
        let sent_at = Arc::new(AtomicU64::new(0));

        ConnectorClient::spawn_keepalive(tx, connected.clone(), sent_at);

        // Wait for at least one heartbeat (interval is 30s, but we can't wait
        // that long in a test -- override by using a shorter sleep).
        // Instead, test the mechanism: spawn_keepalive skips the first tick,
        // so we verify the task is alive by disconnecting after a moment.
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Signal disconnect to stop the keepalive
        connected.store(false, Ordering::SeqCst);

        // Give the task time to notice the flag and exit
        tokio::time::sleep(Duration::from_millis(100)).await;

        // The channel should still be valid (not panicked)
        // No heartbeat expected yet since interval is 30s
        assert!(rx.try_recv().is_err());
    }

    // Dropping the receiver must make the keepalive task exit via its
    // send-error path rather than panic.
    #[tokio::test]
    async fn test_keepalive_stops_on_channel_close() {
        let (tx, rx) = mpsc::unbounded_channel::<ProtoStreamMessage>();
        let connected = Arc::new(AtomicBool::new(true));
        let sent_at = Arc::new(AtomicU64::new(0));

        ConnectorClient::spawn_keepalive(tx, connected.clone(), sent_at);

        // Drop the receiver -- the next send in the keepalive task will fail
        // and the task should exit cleanly.
        drop(rx);

        // Give the task time to attempt a send and exit
        tokio::time::sleep(Duration::from_millis(100)).await;

        // If we get here without a panic, the task handled the closed channel gracefully.
        assert!(connected.load(Ordering::SeqCst));
    }

    // Reproduces the keepalive's message construction with a short interval
    // and asserts the wire format of the heartbeat.
    #[tokio::test]
    async fn test_keepalive_heartbeat_format() {
        let (tx, mut rx) = mpsc::unbounded_channel::<ProtoStreamMessage>();
        let connected = Arc::new(AtomicBool::new(true));

        // Spawn a keepalive with a very short interval for testing
        let keepalive_tx = tx;
        let keepalive_connected = connected.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_millis(50));
            interval.tick().await; // skip first

            interval.tick().await;
            if !keepalive_connected.load(Ordering::SeqCst) {
                return;
            }

            let now_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as i64)
                .unwrap_or(0);

            let heartbeat = ProtoStreamMessage {
                message: Some(proto::stream_message::Message::HeartbeatRequest(
                    proto::HeartbeatRequest {
                        gateway_id: String::new(),
                        timestamp_ms: now_ms,
                    },
                )),
            };
            let _ = keepalive_tx.send(heartbeat);
        });

        // Wait for the heartbeat
        tokio::time::sleep(Duration::from_millis(200)).await;

        let msg = rx.try_recv().expect("should have received a heartbeat");
        match msg.message {
            Some(proto::stream_message::Message::HeartbeatRequest(hb)) => {
                assert!(
                    hb.gateway_id.is_empty(),
                    "gateway_id should be empty (server fills it)"
                );
                assert!(hb.timestamp_ms > 0, "timestamp should be set");
            }
            other => panic!("expected HeartbeatRequest, got {:?}", other),
        }

        connected.store(false, Ordering::SeqCst);
    }
}