zinit_client/
client.rs

1use crate::connection::ConnectionManager;
2use crate::error::{Result, ZinitError};
3use crate::models::{
4    LogEntry, LogStream, Protocol, ServerCapabilities, ServiceState, ServiceStatus, ServiceTarget,
5};
6use crate::protocol::ProtocolHandler;
7use crate::retry::RetryStrategy;
8use chrono::Utc;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::io::{AsyncBufReadExt, BufReader};
15use tokio::sync::OnceCell;
16use tracing::{debug, trace};
17
/// Configuration for the Zinit client.
///
/// Start from [`ClientConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `ClientConfig { max_retries: 5, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Path to the Zinit Unix socket (default: `/var/run/zinit.sock`)
    pub socket_path: PathBuf,
    /// Timeout for connection attempts (default: 5 s)
    pub connection_timeout: Duration,
    /// Timeout for operations (default: 30 s)
    pub operation_timeout: Duration,
    /// Maximum number of retry attempts (default: 3)
    pub max_retries: usize,
    /// Base delay between retries (default: 100 ms)
    pub retry_delay: Duration,
    /// Maximum delay between retries (default: 5 s)
    pub max_retry_delay: Duration,
    /// Whether to add jitter to retry delays (default: `true`)
    pub retry_jitter: bool,
}

impl Default for ClientConfig {
    /// The stock configuration documented on each field of [`ClientConfig`].
    fn default() -> Self {
        Self {
            socket_path: PathBuf::from("/var/run/zinit.sock"),
            connection_timeout: Duration::from_secs(5),
            operation_timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_delay: Duration::from_millis(100),
            max_retry_delay: Duration::from_secs(5),
            retry_jitter: true,
        }
    }
}
50
/// Client for interacting with Zinit
///
/// The wire protocol (JSON-RPC or legacy raw commands) and the server
/// capabilities are not known at construction time; each is detected on first
/// use and then cached in a `OnceCell` for the lifetime of the client.
#[derive(Debug)]
pub struct ZinitClient {
    /// Connection manager used to send commands and open log streams
    connection_manager: ConnectionManager,
    /// Client configuration
    // Its values are consumed when building the connection manager in
    // `with_config`; nothing in this file reads the stored copy afterwards,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    config: ClientConfig,
    /// Detected protocol (lazy initialization)
    protocol: OnceCell<Protocol>,
    /// Server capabilities (lazy initialization)
    capabilities: OnceCell<ServerCapabilities>,
    /// Request ID counter for JSON-RPC (starts at 1, incremented per request)
    request_id: Arc<AtomicU64>,
}
66
67impl ZinitClient {
68    /// Create a new Zinit client with the default configuration
69    pub fn new(socket_path: impl AsRef<Path>) -> Self {
70        Self::with_config(ClientConfig {
71            socket_path: socket_path.as_ref().to_path_buf(),
72            ..Default::default()
73        })
74    }
75
76    /// Create a new Zinit client with a custom configuration
77    pub fn with_config(config: ClientConfig) -> Self {
78        let retry_strategy = RetryStrategy::new(
79            config.max_retries,
80            config.retry_delay,
81            config.max_retry_delay,
82            config.retry_jitter,
83        );
84
85        let connection_manager = ConnectionManager::new(
86            &config.socket_path,
87            config.connection_timeout,
88            config.operation_timeout,
89            retry_strategy,
90        );
91
92        Self {
93            connection_manager,
94            config,
95            protocol: OnceCell::new(),
96            capabilities: OnceCell::new(),
97            request_id: Arc::new(AtomicU64::new(1)),
98        }
99    }
100
101    /// Get the next request ID for JSON-RPC calls
102    fn next_request_id(&self) -> u64 {
103        self.request_id.fetch_add(1, Ordering::SeqCst)
104    }
105
106    /// Detect the protocol used by the server
107    async fn detect_protocol(&self) -> Result<Protocol> {
108        debug!("Detecting server protocol");
109
110        // Try JSON-RPC first (new servers)
111        let request_id = self.next_request_id();
112        let json_rpc_request = ProtocolHandler::format_json_rpc_request(
113            "service_list",
114            serde_json::Value::Array(vec![]),
115            request_id,
116        )?;
117
118        match self
119            .connection_manager
120            .send_command(&json_rpc_request)
121            .await
122        {
123            Ok(response) => {
124                // Check if response looks like JSON-RPC
125                if response.contains("\"jsonrpc\":\"2.0\"") {
126                    debug!("Detected JSON-RPC protocol (new server)");
127                    return Ok(Protocol::JsonRpc);
128                }
129            }
130            Err(_) => {
131                // JSON-RPC failed, continue to try raw commands
132            }
133        }
134
135        // Try raw commands (old servers)
136        let raw_command = ProtocolHandler::format_raw_command("list", &[]);
137        match self.connection_manager.send_command(&raw_command).await {
138            Ok(response) => {
139                // Check if response looks like old server format
140                if response.contains("\"state\":\"ok\"") || response.contains("\"state\":\"error\"")
141                {
142                    debug!("Detected raw command protocol (old server)");
143                    return Ok(Protocol::RawCommands);
144                }
145            }
146            Err(e) => {
147                return Err(ZinitError::ProtocolDetectionFailed(format!(
148                    "Failed to detect protocol: {e}"
149                )));
150            }
151        }
152
153        Err(ZinitError::ProtocolDetectionFailed(
154            "Unable to determine server protocol".to_string(),
155        ))
156    }
157
158    /// Detect server capabilities based on protocol
159    async fn detect_capabilities(&self) -> Result<ServerCapabilities> {
160        let protocol = self.get_protocol().await?;
161        debug!("Detecting server capabilities for protocol: {}", protocol);
162
163        let capabilities = match protocol {
164            Protocol::JsonRpc => {
165                // New servers support all features
166                ServerCapabilities::full()
167            }
168            Protocol::RawCommands => {
169                // Old servers have limited capabilities
170                ServerCapabilities::legacy()
171            }
172        };
173
174        debug!("Detected capabilities: {:?}", capabilities);
175        Ok(capabilities)
176    }
177
178    /// Get the detected protocol (with lazy initialization)
179    async fn get_protocol(&self) -> Result<Protocol> {
180        if let Some(protocol) = self.protocol.get() {
181            return Ok(*protocol);
182        }
183
184        let protocol = self.detect_protocol().await?;
185        let _ = self.protocol.set(protocol);
186        Ok(protocol)
187    }
188
189    /// Get the server capabilities (with lazy initialization)
190    async fn get_capabilities(&self) -> Result<&ServerCapabilities> {
191        if let Some(capabilities) = self.capabilities.get() {
192            return Ok(capabilities);
193        }
194
195        let capabilities = self.detect_capabilities().await?;
196        let _ = self.capabilities.set(capabilities);
197        Ok(self.capabilities.get().unwrap())
198    }
199
200    /// Execute a command using the appropriate protocol
201    async fn execute_command(
202        &self,
203        method: &str,
204        args: &[&str],
205        params: Option<serde_json::Value>,
206    ) -> Result<serde_json::Value> {
207        let protocol = self.get_protocol().await?;
208        let request_id = self.next_request_id();
209
210        let request = ProtocolHandler::format_request(protocol, method, args, params, request_id)?;
211        let response = self.connection_manager.send_command(&request).await?;
212        ProtocolHandler::parse_response_by_protocol(protocol, &response)
213    }
214
215    /// List all services and their states
216    pub async fn list(&self) -> Result<HashMap<String, ServiceState>> {
217        debug!("Listing all services");
218
219        let protocol = self.get_protocol().await?;
220        let response = match protocol {
221            Protocol::JsonRpc => self.execute_command("service_list", &[], None).await?,
222            Protocol::RawCommands => self.execute_command("list", &[], None).await?,
223        };
224
225        let map: HashMap<String, String> = serde_json::from_value(response)?;
226        let result = map
227            .into_iter()
228            .map(|(name, state_str)| {
229                let state = match state_str.as_str() {
230                    "Unknown" => ServiceState::Unknown,
231                    "Blocked" => ServiceState::Blocked,
232                    "Spawned" => ServiceState::Spawned,
233                    "Running" => ServiceState::Running,
234                    "Success" => ServiceState::Success,
235                    "Error" => ServiceState::Error,
236                    "TestFailure" => ServiceState::TestFailure,
237                    _ => ServiceState::Unknown,
238                };
239                (name, state)
240            })
241            .collect();
242
243        Ok(result)
244    }
245
246    /// Get the status of a service
247    pub async fn status(&self, service: impl AsRef<str>) -> Result<ServiceStatus> {
248        let service_name = service.as_ref();
249        debug!("Getting status for service: {}", service_name);
250
251        let protocol = self.get_protocol().await?;
252        let response = match protocol {
253            Protocol::JsonRpc => {
254                let params = serde_json::json!([service_name]);
255                self.execute_command("service_status", &[], Some(params))
256                    .await?
257            }
258            Protocol::RawCommands => {
259                self.execute_command("status", &[service_name], None)
260                    .await?
261            }
262        };
263
264        // Parse the response based on protocol
265        let status = self.parse_status_response(response, service_name).await?;
266        Ok(status)
267    }
268
269    /// Parse status response handling different formats between protocols
270    async fn parse_status_response(
271        &self,
272        response: serde_json::Value,
273        service_name: &str,
274    ) -> Result<ServiceStatus> {
275        let protocol = self.get_protocol().await?;
276
277        match protocol {
278            Protocol::JsonRpc => {
279                // New server JSON-RPC format
280                let name = response
281                    .get("name")
282                    .and_then(|v| v.as_str())
283                    .unwrap_or(service_name)
284                    .to_string();
285
286                let pid = response.get("pid").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
287
288                let state_str = response
289                    .get("state")
290                    .and_then(|v| v.as_str())
291                    .unwrap_or("Unknown");
292
293                let target_str = response
294                    .get("target")
295                    .and_then(|v| v.as_str())
296                    .unwrap_or("Down");
297
298                let after = response
299                    .get("after")
300                    .and_then(|v| v.as_object())
301                    .map(|obj| {
302                        obj.iter()
303                            .map(|(k, v)| (k.clone(), v.as_str().unwrap_or("Unknown").to_string()))
304                            .collect()
305                    })
306                    .unwrap_or_default();
307
308                Ok(ServiceStatus {
309                    name,
310                    pid,
311                    state: self.parse_service_state(state_str),
312                    target: self.parse_service_target(target_str),
313                    after,
314                })
315            }
316            Protocol::RawCommands => {
317                // Old server format - try direct deserialization first
318                match serde_json::from_value::<ServiceStatus>(response.clone()) {
319                    Ok(mut status) => {
320                        // Convert state and target strings to enums
321                        status.state = self.parse_service_state(&status.state.to_string());
322                        status.target = self.parse_service_target(&status.target.to_string());
323                        Ok(status)
324                    }
325                    Err(_) => {
326                        // Fallback parsing for old format
327                        let name = service_name.to_string();
328                        let pid = response.get("pid").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
329
330                        let state_str = response
331                            .get("state")
332                            .and_then(|v| v.as_str())
333                            .unwrap_or("Unknown");
334
335                        let target_str = response
336                            .get("target")
337                            .and_then(|v| v.as_str())
338                            .unwrap_or("Down");
339
340                        let after = response
341                            .get("after")
342                            .and_then(|v| v.as_object())
343                            .map(|obj| {
344                                obj.iter()
345                                    .map(|(k, v)| {
346                                        (k.clone(), v.as_str().unwrap_or("Unknown").to_string())
347                                    })
348                                    .collect()
349                            })
350                            .unwrap_or_default();
351
352                        Ok(ServiceStatus {
353                            name,
354                            pid,
355                            state: self.parse_service_state(state_str),
356                            target: self.parse_service_target(target_str),
357                            after,
358                        })
359                    }
360                }
361            }
362        }
363    }
364
365    /// Parse service state string to enum
366    fn parse_service_state(&self, state_str: &str) -> ServiceState {
367        match state_str {
368            "Unknown" => ServiceState::Unknown,
369            "Blocked" => ServiceState::Blocked,
370            "Spawned" => ServiceState::Spawned,
371            "Running" => ServiceState::Running,
372            "Success" => ServiceState::Success,
373            "Error" => ServiceState::Error,
374            "TestFailure" => ServiceState::TestFailure,
375            _ => ServiceState::Unknown,
376        }
377    }
378
379    /// Parse service target string to enum
380    fn parse_service_target(&self, target_str: &str) -> ServiceTarget {
381        match target_str {
382            "Up" => ServiceTarget::Up,
383            "Down" => ServiceTarget::Down,
384            _ => ServiceTarget::Down,
385        }
386    }
387
388    /// Start a service
389    pub async fn start(&self, service: impl AsRef<str>) -> Result<()> {
390        let service_name = service.as_ref();
391        debug!("Starting service: {}", service_name);
392
393        let protocol = self.get_protocol().await?;
394        match protocol {
395            Protocol::JsonRpc => {
396                let params = serde_json::json!([service_name]);
397                self.execute_command("service_start", &[], Some(params))
398                    .await?;
399            }
400            Protocol::RawCommands => {
401                self.execute_command("start", &[service_name], None).await?;
402            }
403        }
404
405        Ok(())
406    }
407
408    /// Stop a service
409    pub async fn stop(&self, service: impl AsRef<str>) -> Result<()> {
410        let service_name = service.as_ref();
411        debug!("Stopping service: {}", service_name);
412
413        let protocol = self.get_protocol().await?;
414        match protocol {
415            Protocol::JsonRpc => {
416                let params = serde_json::json!([service_name]);
417                self.execute_command("service_stop", &[], Some(params))
418                    .await?;
419            }
420            Protocol::RawCommands => {
421                self.execute_command("stop", &[service_name], None).await?;
422            }
423        }
424
425        Ok(())
426    }
427
428    /// Restart a service
429    pub async fn restart(&self, service: impl AsRef<str>) -> Result<()> {
430        let service_name = service.as_ref();
431        debug!("Restarting service: {}", service_name);
432
433        // First stop the service
434        self.stop(service_name).await?;
435
436        // Wait for the service to stop
437        let mut attempts = 0;
438        let max_attempts = 20;
439
440        while attempts < max_attempts {
441            let status = self.status(service_name).await?;
442            if status.pid == 0 && status.target == ServiceTarget::Down {
443                // Service is stopped, now start it
444                return self.start(service_name).await;
445            }
446
447            attempts += 1;
448            tokio::time::sleep(Duration::from_secs(1)).await;
449        }
450
451        // Service didn't stop gracefully, try to kill it
452        self.kill(service_name, "SIGKILL").await?;
453        self.start(service_name).await
454    }
455
456    /// Monitor a service
457    pub async fn monitor(&self, service: impl AsRef<str>) -> Result<()> {
458        let service_name = service.as_ref();
459        debug!("Monitoring service: {}", service_name);
460
461        let protocol = self.get_protocol().await?;
462        match protocol {
463            Protocol::JsonRpc => {
464                let params = serde_json::json!([service_name]);
465                self.execute_command("service_monitor", &[], Some(params))
466                    .await?;
467            }
468            Protocol::RawCommands => {
469                self.execute_command("monitor", &[service_name], None)
470                    .await?;
471            }
472        }
473
474        Ok(())
475    }
476
477    /// Forget a service
478    pub async fn forget(&self, service: impl AsRef<str>) -> Result<()> {
479        let service_name = service.as_ref();
480        debug!("Forgetting service: {}", service_name);
481
482        let protocol = self.get_protocol().await?;
483        match protocol {
484            Protocol::JsonRpc => {
485                let params = serde_json::json!([service_name]);
486                self.execute_command("service_forget", &[], Some(params))
487                    .await?;
488            }
489            Protocol::RawCommands => {
490                self.execute_command("forget", &[service_name], None)
491                    .await?;
492            }
493        }
494
495        Ok(())
496    }
497
498    /// Send a signal to a service
499    pub async fn kill(&self, service: impl AsRef<str>, signal: impl AsRef<str>) -> Result<()> {
500        let service_name = service.as_ref();
501        let signal_name = signal.as_ref();
502        debug!(
503            "Sending signal {} to service: {}",
504            signal_name, service_name
505        );
506
507        let protocol = self.get_protocol().await?;
508        match protocol {
509            Protocol::JsonRpc => {
510                let params = serde_json::json!([service_name, signal_name]);
511                self.execute_command("service_kill", &[], Some(params))
512                    .await?;
513            }
514            Protocol::RawCommands => {
515                self.execute_command("kill", &[service_name, signal_name], None)
516                    .await?;
517            }
518        }
519
520        Ok(())
521    }
522
523    /// Stream logs from services
524    pub async fn logs(&self, follow: bool, filter: Option<impl AsRef<str>>) -> Result<LogStream> {
525        let command = if follow {
526            "log".to_string()
527        } else {
528            "log snapshot".to_string()
529        };
530
531        debug!("Streaming logs with command: {}", command);
532        let stream = self.connection_manager.stream_logs(&command).await?;
533        let reader = BufReader::new(stream);
534        let mut lines = reader.lines();
535
536        // Create a stream of log entries
537        let filter_str = filter.as_ref().map(|f| f.as_ref().to_string());
538
539        let log_stream = async_stream::stream! {
540            while let Some(line_result) = lines.next_line().await.transpose() {
541                match line_result {
542                    Ok(line) => {
543                        trace!("Received log line: {}", line);
544
545                        // Parse the log line
546                        if let Some(entry) = parse_log_line(&line, &filter_str) {
547                            yield Ok(entry);
548                        }
549                    }
550                    Err(e) => {
551                        yield Err(ZinitError::ConnectionError(e));
552                        break;
553                    }
554                }
555            }
556        };
557
558        Ok(LogStream {
559            inner: Box::pin(log_stream),
560        })
561    }
562
563    /// Shutdown the system
564    pub async fn shutdown(&self) -> Result<()> {
565        debug!("Shutting down the system");
566        self.connection_manager.execute_command("shutdown").await?;
567        Ok(())
568    }
569
570    /// Reboot the system
571    pub async fn reboot(&self) -> Result<()> {
572        debug!("Rebooting the system");
573        self.connection_manager.execute_command("reboot").await?;
574        Ok(())
575    }
576
577    /// Get raw service information
578    pub async fn get_service(&self, service: impl AsRef<str>) -> Result<serde_json::Value> {
579        let service_name = service.as_ref();
580        debug!("Getting raw service info for: {}", service_name);
581
582        // Use the universal interface
583        let protocol = self.get_protocol().await?;
584        match protocol {
585            Protocol::JsonRpc => {
586                // New servers: use service_status RPC call
587                let params = serde_json::json!([service_name]);
588                self.execute_command("service_status", &[], Some(params))
589                    .await
590            }
591            Protocol::RawCommands => {
592                // Old servers: use status command
593                self.execute_command("status", &[service_name], None).await
594            }
595        }
596    }
597
598    /// Create a new service
599    pub async fn create_service(
600        &self,
601        name: impl AsRef<str>,
602        config: serde_json::Value,
603    ) -> Result<()> {
604        let service_name = name.as_ref();
605        debug!("Creating service: {}", service_name);
606
607        // Check if the server supports dynamic service creation
608        let capabilities = self.get_capabilities().await?;
609        if !capabilities.supports_create {
610            return Err(ZinitError::FeatureNotSupported(format!(
611                "Dynamic service creation is not supported by this zinit server ({}). \
612                     Please create a service configuration file manually in /etc/zinit/{}.yaml",
613                capabilities.protocol, service_name
614            )));
615        }
616
617        // Use the appropriate protocol
618        let protocol = self.get_protocol().await?;
619        match protocol {
620            Protocol::JsonRpc => {
621                // New servers: use service_create RPC call
622                let params = serde_json::json!([service_name, config]);
623                self.execute_command("service_create", &[], Some(params))
624                    .await?;
625            }
626            Protocol::RawCommands => {
627                // This should not happen since we checked capabilities above,
628                // but handle it gracefully
629                return Err(ZinitError::FeatureNotSupported(
630                    "Dynamic service creation requires zinit v0.2.25+".to_string(),
631                ));
632            }
633        }
634
635        Ok(())
636    }
637
638    /// Delete a service
639    pub async fn delete_service(&self, name: impl AsRef<str>) -> Result<()> {
640        let service_name = name.as_ref();
641        debug!("Deleting service: {}", service_name);
642
643        // Try to get status, but don't fail if it doesn't work
644        match self.status(service_name).await {
645            Ok(status) => {
646                if status.state == ServiceState::Running || status.target == ServiceTarget::Up {
647                    // Stop the service first
648                    if let Err(e) = self.stop(service_name).await {
649                        debug!("Warning: Failed to stop service {}: {}", service_name, e);
650                    }
651
652                    // Wait for the service to stop
653                    let mut attempts = 0;
654                    let max_attempts = 10;
655
656                    while attempts < max_attempts {
657                        match self.status(service_name).await {
658                            Ok(status) => {
659                                if status.pid == 0 && status.target == ServiceTarget::Down {
660                                    break;
661                                }
662                            }
663                            Err(_) => {
664                                // If status fails, assume service is stopped
665                                break;
666                            }
667                        }
668
669                        attempts += 1;
670                        tokio::time::sleep(Duration::from_millis(500)).await;
671                    }
672                }
673            }
674            Err(e) => {
675                debug!("Warning: Could not get status for {}: {}", service_name, e);
676                // Continue with deletion anyway
677            }
678        }
679
680        // Now forget the service and delete the config file
681        self.forget(service_name).await?;
682
683        // For new servers, also delete the service configuration file
684        let protocol = self.get_protocol().await?;
685        if let Protocol::JsonRpc = protocol {
686            let params = serde_json::json!([service_name]);
687            if let Err(e) = self
688                .execute_command("service_delete", &[], Some(params))
689                .await
690            {
691                debug!(
692                    "Warning: Could not delete service config file for {}: {}",
693                    service_name, e
694                );
695                // Don't fail the whole operation if config file deletion fails
696            }
697        }
698
699        Ok(())
700    }
701}
702
703/// Parse a log line into a LogEntry
704fn parse_log_line(line: &str, filter: &Option<String>) -> Option<LogEntry> {
705    // Example log line: "zinit: INFO (service) message"
706    let parts: Vec<&str> = line.splitn(4, ' ').collect();
707
708    if parts.len() < 4 || !parts[0].starts_with("zinit:") {
709        return None;
710    }
711
712    let level = parts[1];
713    let service = parts[2].trim_start_matches('(').trim_end_matches(')');
714
715    // Apply filter if provided
716    if let Some(filter_str) = filter {
717        if service != filter_str {
718            return None;
719        }
720    }
721
722    let message = parts[3];
723    let timestamp = Utc::now(); // Zinit doesn't include timestamps, so we use current time
724
725    Some(LogEntry {
726        timestamp,
727        service: service.to_string(),
728        message: format!("[{level}] {message}"),
729    })
730}