//! Protocol version and capability negotiation for `smith_protocol`
//! (source file: `smith_protocol/negotiation.rs`).

1use crate::{capabilities, Event};
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use uuid::Uuid;
6
/// Protocol version 0, the current stable protocol.
pub const PROTOCOL_V0: u32 = 0;
/// Protocol version 1, preferred during negotiation when both peers support it.
pub const PROTOCOL_V1: u32 = 1;
/// The protocol version treated as current stable (v0).
pub const CURRENT_VERSION: u32 = PROTOCOL_V0;

/// Every protocol version this service can negotiate, ordered from most to least preferred
/// (newest first). Negotiation picks the first entry that the client also supports.
pub const SUPPORTED_VERSIONS: &[u32] = &[PROTOCOL_V1, PROTOCOL_V0];

15/// Version negotiation result
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct NegotiationResult {
18    pub version: u32,
19    pub capabilities: Vec<String>,
20    pub fallback_reason: Option<String>,
21    pub service_info: HashMap<String, String>,
22}
23
24/// Client information for negotiation
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ClientInfo {
27    pub supported_versions: Vec<u32>,
28    pub requested_capabilities: Vec<String>,
29    pub client_metadata: HashMap<String, String>,
30}
31
32/// Version and capability negotiator
33pub struct ProtocolNegotiator {
34    service_id: Uuid,
35    available_capabilities: Vec<String>,
36    service_metadata: HashMap<String, String>,
37}
38
39impl ProtocolNegotiator {
40    pub fn new(service_id: Uuid) -> Self {
41        let mut service_metadata = HashMap::new();
42        service_metadata.insert(
43            "service_name".to_string(),
44            "claude-code-rs-core".to_string(),
45        );
46        service_metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string());
47        service_metadata.insert(
48            "build_timestamp".to_string(),
49            std::env::var("BUILD_TIMESTAMP").unwrap_or_else(|_| "unknown".to_string()),
50        );
51        service_metadata.insert(
52            "git_commit".to_string(),
53            std::env::var("GIT_COMMIT").unwrap_or_else(|_| "unknown".to_string()),
54        );
55
56        // Determine available capabilities based on compile-time features
57        let available_capabilities = vec![
58            capabilities::SHELL_EXEC.to_string(),
59            capabilities::REPLAY.to_string(),
60            capabilities::TRACING.to_string(),
61            #[cfg(feature = "hooks-quickjs")]
62            capabilities::HOOKS_JS.to_string(),
63            #[cfg(feature = "hooks-rust")]
64            capabilities::HOOKS_RUST.to_string(),
65            #[cfg(feature = "nats")]
66            capabilities::NATS.to_string(),
67        ];
68
69        #[cfg(feature = "protobuf")]
70        service_metadata.insert("protobuf_support".to_string(), "true".to_string());
71
72        Self {
73            service_id,
74            available_capabilities,
75            service_metadata,
76        }
77    }
78
79    /// Negotiate protocol version and capabilities with a client
80    pub fn negotiate(&self, client_info: ClientInfo) -> Result<NegotiationResult> {
81        // Find the highest mutually supported version
82        let selected_version = self.select_version(&client_info.supported_versions)?;
83
84        // Filter capabilities to those we actually support
85        let granted_capabilities = self.filter_capabilities(&client_info.requested_capabilities);
86
87        // Determine if fallback is needed
88        let fallback_reason =
89            self.check_fallback_conditions(selected_version, &granted_capabilities);
90
91        // Add negotiation metadata
92        let mut service_info = self.service_metadata.clone();
93        service_info.insert(
94            "negotiated_version".to_string(),
95            selected_version.to_string(),
96        );
97        service_info.insert(
98            "granted_capabilities".to_string(),
99            granted_capabilities.len().to_string(),
100        );
101
102        Ok(NegotiationResult {
103            version: selected_version,
104            capabilities: granted_capabilities,
105            fallback_reason,
106            service_info,
107        })
108    }
109
110    /// Select the best mutually supported protocol version
111    fn select_version(&self, client_versions: &[u32]) -> Result<u32> {
112        // Find the highest version that both client and server support
113        for &server_version in SUPPORTED_VERSIONS {
114            if client_versions.contains(&server_version) {
115                return Ok(server_version);
116            }
117        }
118
119        Err(anyhow!(
120            "No compatible protocol version found. Server supports: {:?}, Client supports: {:?}",
121            SUPPORTED_VERSIONS,
122            client_versions
123        ))
124    }
125
126    /// Filter requested capabilities to those actually available
127    fn filter_capabilities(&self, requested: &[String]) -> Vec<String> {
128        requested
129            .iter()
130            .filter(|cap| self.available_capabilities.contains(cap))
131            .cloned()
132            .collect()
133    }
134
135    /// Check if fallback to v0 is needed despite v1 being negotiated
136    fn check_fallback_conditions(&self, version: u32, _capabilities: &[String]) -> Option<String> {
137        match version {
138            PROTOCOL_V1 => {
139                // Check if protobuf feature is actually enabled
140                #[cfg(not(feature = "protobuf"))]
141                {
142                    Some("Protobuf support not compiled in, falling back to JSONL".to_string())
143                }
144
145                #[cfg(feature = "protobuf")]
146                {
147                    // Could add other runtime checks here
148                    None
149                }
150            }
151            PROTOCOL_V0 => None, // v0 always works
152            _ => Some(format!("Unsupported version {}, using v0", version)),
153        }
154    }
155
156    /// Create a Ready event after successful negotiation
157    pub fn create_ready_event(&self, result: &NegotiationResult) -> Event {
158        Event::Ready {
159            version: result.version,
160            capabilities: result.capabilities.clone(),
161            service_id: self.service_id,
162        }
163    }
164
165    /// Get available capabilities
166    pub fn get_available_capabilities(&self) -> &[String] {
167        &self.available_capabilities
168    }
169
170    /// Check if a specific capability is supported
171    pub fn supports_capability(&self, capability: &str) -> bool {
172        self.available_capabilities
173            .contains(&capability.to_string())
174    }
175
176    /// Get service metadata
177    pub fn get_service_metadata(&self) -> &HashMap<String, String> {
178        &self.service_metadata
179    }
180}
181
182/// Capability compatibility checker
183pub struct CapabilityChecker;
184
185impl CapabilityChecker {
186    /// Check if requested capabilities are compatible with each other
187    pub fn check_compatibility(capabilities: &[String]) -> Result<Vec<String>> {
188        let mut warnings = Vec::new();
189
190        // Check for conflicting hook systems
191        let has_js_hooks = capabilities.contains(&capabilities::HOOKS_JS.to_string());
192        let has_rust_hooks = capabilities.contains(&capabilities::HOOKS_RUST.to_string());
193
194        if has_js_hooks && has_rust_hooks {
195            warnings
196                .push("Both JS and Rust hooks enabled - performance may be impacted".to_string());
197        }
198
199        // Check for NATS without tracing (suboptimal observability)
200        let has_nats = capabilities.contains(&capabilities::NATS.to_string());
201        let has_tracing = capabilities.contains(&capabilities::TRACING.to_string());
202
203        if has_nats && !has_tracing {
204            warnings.push("NATS enabled without tracing - reduced observability".to_string());
205        }
206
207        // Warn about replay overhead
208        let has_replay = capabilities.contains(&capabilities::REPLAY.to_string());
209        if has_replay {
210            warnings.push(
211                "Replay enabled - performance overhead for recording all operations".to_string(),
212            );
213        }
214
215        Ok(warnings)
216    }
217
218    /// Get recommended capabilities for a given use case
219    pub fn recommend_capabilities(use_case: &str) -> Vec<String> {
220        match use_case {
221            "development" => vec![
222                capabilities::SHELL_EXEC.to_string(),
223                capabilities::HOOKS_JS.to_string(),
224                capabilities::REPLAY.to_string(),
225                capabilities::TRACING.to_string(),
226            ],
227            "production" => vec![
228                capabilities::SHELL_EXEC.to_string(),
229                capabilities::HOOKS_RUST.to_string(),
230                capabilities::TRACING.to_string(),
231                capabilities::NATS.to_string(),
232            ],
233            "testing" => vec![
234                capabilities::SHELL_EXEC.to_string(),
235                capabilities::REPLAY.to_string(),
236                capabilities::TRACING.to_string(),
237            ],
238            "minimal" => vec![capabilities::SHELL_EXEC.to_string()],
239            _ => vec![
240                capabilities::SHELL_EXEC.to_string(),
241                capabilities::TRACING.to_string(),
242            ],
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ClientInfo` with empty metadata from borrowed version and capability lists.
    fn client(versions: &[u32], caps: &[&str]) -> ClientInfo {
        ClientInfo {
            supported_versions: versions.to_vec(),
            requested_capabilities: caps.iter().map(|cap| cap.to_string()).collect(),
            client_metadata: HashMap::new(),
        }
    }

    /// A negotiator with a fresh random service id.
    fn fresh_negotiator() -> ProtocolNegotiator {
        ProtocolNegotiator::new(Uuid::new_v4())
    }

    /// True when `list` contains exactly `cap`.
    fn has(list: &[String], cap: &str) -> bool {
        list.iter().any(|entry| entry == cap)
    }

    #[test]
    fn test_version_negotiation() {
        let outcome = fresh_negotiator()
            .negotiate(client(
                &[0, 1],
                &[capabilities::SHELL_EXEC, capabilities::TRACING],
            ))
            .unwrap();

        // The chosen version must be one of those the client offered (0 or 1).
        assert!(outcome.version <= 1);
        assert!(has(&outcome.capabilities, capabilities::SHELL_EXEC));
        assert!(has(&outcome.capabilities, capabilities::TRACING));
    }

    #[test]
    fn test_incompatible_versions() {
        // 999 is not a version this service knows about.
        let outcome = fresh_negotiator().negotiate(client(&[999], &[]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_capability_filtering() {
        let granted = fresh_negotiator()
            .negotiate(client(
                &[0],
                &[capabilities::SHELL_EXEC, "non_existent_capability"],
            ))
            .unwrap()
            .capabilities;

        // Only capabilities the service actually provides survive filtering.
        assert!(has(&granted, capabilities::SHELL_EXEC));
        assert!(!has(&granted, "non_existent_capability"));
    }

    #[test]
    fn test_capability_compatibility() {
        let warnings = CapabilityChecker::check_compatibility(&[
            capabilities::HOOKS_JS.to_string(),
            capabilities::HOOKS_RUST.to_string(),
        ])
        .unwrap();

        // Requesting both hook systems must surface the conflict warning first.
        let first = warnings
            .first()
            .expect("expected a warning about conflicting hook systems");
        assert!(first.contains("JS and Rust hooks"));
    }

    #[test]
    fn test_use_case_recommendations() {
        let dev = CapabilityChecker::recommend_capabilities("development");
        let prod = CapabilityChecker::recommend_capabilities("production");

        // Development favours debuggability: replay and JS hooks.
        assert!(has(&dev, capabilities::REPLAY));
        assert!(has(&dev, capabilities::HOOKS_JS));

        // Production favours performance: no replay, Rust hooks, NATS.
        assert!(!has(&prod, capabilities::REPLAY));
        assert!(has(&prod, capabilities::HOOKS_RUST));
        assert!(has(&prod, capabilities::NATS));
    }

    #[test]
    fn test_protocol_constants() {
        assert_eq!(PROTOCOL_V0, 0);
        assert_eq!(PROTOCOL_V1, 1);
        assert_eq!(CURRENT_VERSION, PROTOCOL_V0);

        for version in &[PROTOCOL_V0, PROTOCOL_V1] {
            assert!(SUPPORTED_VERSIONS.contains(version));
        }
        assert!(SUPPORTED_VERSIONS.len() >= 2);
    }

    #[test]
    fn test_negotiator_methods() {
        let negotiator = ProtocolNegotiator::new(Uuid::new_v4());

        // Capability lookups
        assert!(negotiator.supports_capability(capabilities::SHELL_EXEC));
        assert!(negotiator.supports_capability(capabilities::TRACING));
        assert!(!negotiator.supports_capability("non_existent_capability"));

        // Advertised capability list
        let available = negotiator.get_available_capabilities();
        assert!(!available.is_empty());
        assert!(has(available, capabilities::SHELL_EXEC));

        // Service metadata
        let metadata = negotiator.get_service_metadata();
        assert!(metadata.contains_key("service_name"));
        assert!(metadata.contains_key("version"));
        assert_eq!(
            metadata.get("service_name").map(String::as_str),
            Some("claude-code-rs-core")
        );
    }

    #[test]
    fn test_ready_event_creation() {
        let service_id = Uuid::new_v4();
        let negotiator = ProtocolNegotiator::new(service_id);
        let outcome = NegotiationResult {
            version: PROTOCOL_V0,
            capabilities: vec![capabilities::SHELL_EXEC.to_string()],
            fallback_reason: None,
            service_info: HashMap::new(),
        };

        if let Event::Ready {
            version,
            capabilities: advertised,
            service_id: reported_id,
        } = negotiator.create_ready_event(&outcome)
        {
            assert_eq!(version, PROTOCOL_V0);
            assert_eq!(advertised, vec![capabilities::SHELL_EXEC.to_string()]);
            assert_eq!(reported_id, service_id);
        } else {
            panic!("Expected Ready event");
        }
    }

    #[test]
    fn test_fallback_conditions_v0() {
        // v0 is the baseline and never needs a fallback.
        assert!(fresh_negotiator()
            .check_fallback_conditions(PROTOCOL_V0, &[])
            .is_none());
    }

    #[test]
    fn test_fallback_conditions_unsupported_version() {
        let reason = fresh_negotiator()
            .check_fallback_conditions(999, &[])
            .expect("an unknown version should report a fallback reason");
        assert!(reason.contains("Unsupported version 999"));
    }

    #[test]
    fn test_capability_compatibility_nats_without_tracing() {
        let warnings =
            CapabilityChecker::check_compatibility(&[capabilities::NATS.to_string()]).unwrap();

        assert!(!warnings.is_empty());
        let mentions_tracing = warnings
            .iter()
            .any(|w| w.contains("NATS enabled without tracing"));
        assert!(mentions_tracing);
    }

    #[test]
    fn test_capability_compatibility_with_replay() {
        let warnings =
            CapabilityChecker::check_compatibility(&[capabilities::REPLAY.to_string()]).unwrap();

        assert!(!warnings.is_empty());
        let mentions_replay = warnings.iter().any(|w| w.contains("Replay enabled"));
        assert!(mentions_replay);
    }

    #[test]
    fn test_capability_compatibility_good_config() {
        let requested = [
            capabilities::SHELL_EXEC.to_string(),
            capabilities::TRACING.to_string(),
        ];

        // Shell execution with tracing triggers none of the warning rules.
        assert!(CapabilityChecker::check_compatibility(&requested)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_use_case_recommendations_all_variants() {
        let expectations: [(&str, &[&str]); 5] = [
            (
                "development",
                &[
                    capabilities::SHELL_EXEC,
                    capabilities::HOOKS_JS,
                    capabilities::REPLAY,
                    capabilities::TRACING,
                ],
            ),
            (
                "production",
                &[
                    capabilities::SHELL_EXEC,
                    capabilities::HOOKS_RUST,
                    capabilities::TRACING,
                    capabilities::NATS,
                ],
            ),
            (
                "testing",
                &[
                    capabilities::SHELL_EXEC,
                    capabilities::REPLAY,
                    capabilities::TRACING,
                ],
            ),
            ("minimal", &[capabilities::SHELL_EXEC]),
            (
                "unknown_use_case",
                &[capabilities::SHELL_EXEC, capabilities::TRACING],
            ),
        ];

        for (use_case, expected) in expectations.iter() {
            let recommended = CapabilityChecker::recommend_capabilities(use_case);
            for cap in expected.iter() {
                assert!(
                    has(&recommended, cap),
                    "Use case '{}' should include capability '{}'",
                    use_case,
                    cap
                );
            }
        }
    }

    #[test]
    fn test_client_info_and_negotiation_result_serialization() {
        // ClientInfo round-trip
        let info = ClientInfo {
            supported_versions: vec![0, 1],
            requested_capabilities: vec![capabilities::SHELL_EXEC.to_string()],
            client_metadata: [("client_version".to_string(), "1.0.0".to_string())]
                .iter()
                .cloned()
                .collect(),
        };
        let info_round_trip: ClientInfo =
            serde_json::from_str(&serde_json::to_string(&info).unwrap()).unwrap();
        assert_eq!(info_round_trip.supported_versions, vec![0, 1]);
        assert_eq!(
            info_round_trip.requested_capabilities,
            vec![capabilities::SHELL_EXEC.to_string()]
        );

        // NegotiationResult round-trip
        let result = NegotiationResult {
            version: PROTOCOL_V0,
            capabilities: vec![capabilities::SHELL_EXEC.to_string()],
            fallback_reason: Some("test fallback".to_string()),
            service_info: [("key".to_string(), "value".to_string())]
                .iter()
                .cloned()
                .collect(),
        };
        let result_round_trip: NegotiationResult =
            serde_json::from_str(&serde_json::to_string(&result).unwrap()).unwrap();
        assert_eq!(result_round_trip.version, PROTOCOL_V0);
        assert_eq!(
            result_round_trip.fallback_reason.as_deref(),
            Some("test fallback")
        );
    }

    #[test]
    fn test_empty_client_versions_negotiation() {
        let error = fresh_negotiator()
            .negotiate(client(&[], &[]))
            .unwrap_err();
        assert!(error
            .to_string()
            .contains("No compatible protocol version"));
    }

    #[test]
    fn test_empty_requested_capabilities() {
        let outcome = fresh_negotiator()
            .negotiate(client(&[PROTOCOL_V0], &[]))
            .unwrap();

        assert_eq!(outcome.version, PROTOCOL_V0);
        // Nothing requested, nothing granted.
        assert!(outcome.capabilities.is_empty());
    }

    #[test]
    fn test_version_selection_priority() {
        let outcome = fresh_negotiator()
            .negotiate(client(&[PROTOCOL_V0, PROTOCOL_V1], &[]))
            .unwrap();

        // v1 is preferred in SUPPORTED_VERSIONS; whichever version is chosen must be one
        // the service supports.
        assert!(SUPPORTED_VERSIONS.contains(&outcome.version));
    }
}