1use crate::connection::ConnectionManager;
2use crate::error::{Result, ZinitError};
3use crate::models::{
4 LogEntry, LogStream, Protocol, ServerCapabilities, ServiceState, ServiceStatus, ServiceTarget,
5};
6use crate::protocol::ProtocolHandler;
7use crate::retry::RetryStrategy;
8use chrono::Utc;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::io::{AsyncBufReadExt, BufReader};
15use tokio::sync::OnceCell;
16use tracing::{debug, trace};
17
/// Settings used to build a [`ZinitClient`]: where the zinit socket lives and
/// the timeout/retry parameters handed to the connection layer.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Filesystem path of the zinit control socket.
    pub socket_path: PathBuf,
    /// Connection timeout, forwarded to the connection manager.
    pub connection_timeout: Duration,
    /// Per-operation timeout, forwarded to the connection manager.
    pub operation_timeout: Duration,
    /// Maximum retry count, forwarded to `RetryStrategy::new`.
    pub max_retries: usize,
    /// Base retry delay, forwarded to `RetryStrategy::new`.
    pub retry_delay: Duration,
    /// Upper bound on the retry delay, forwarded to `RetryStrategy::new`.
    pub max_retry_delay: Duration,
    /// Whether retry delays are jittered, forwarded to `RetryStrategy::new`.
    pub retry_jitter: bool,
}

impl Default for ClientConfig {
    /// Socket at `/var/run/zinit.sock`, 5 s connection timeout, 30 s operation
    /// timeout, 3 retries with a 100 ms base delay capped at 5 s, jitter on.
    fn default() -> Self {
        const DEFAULT_SOCKET: &str = "/var/run/zinit.sock";

        let retry_delay = Duration::from_millis(100);
        let max_retry_delay = Duration::from_secs(5);

        Self {
            retry_jitter: true,
            max_retries: 3,
            retry_delay,
            max_retry_delay,
            operation_timeout: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(5),
            socket_path: PathBuf::from(DEFAULT_SOCKET),
        }
    }
}
50
/// Async client for a zinit server.
///
/// The client can talk either JSON-RPC or the legacy raw-command protocol.
/// Which one is used is detected on first use and then cached for the lifetime
/// of the client.
#[derive(Debug)]
pub struct ZinitClient {
    // Sends requests to the socket; built from the socket path, both timeouts
    // and the retry strategy in `with_config`.
    connection_manager: ConnectionManager,
    // The configuration the client was built from. No method in this file reads
    // it back after construction, which is why `dead_code` is allowed here.
    #[allow(dead_code)]
    config: ClientConfig,
    // Detected wire protocol, filled in lazily by `get_protocol`.
    protocol: OnceCell<Protocol>,
    // Capabilities derived from the protocol, filled in lazily by `get_capabilities`.
    capabilities: OnceCell<ServerCapabilities>,
    // Counter for JSON-RPC request ids; starts at 1 and is bumped atomically.
    request_id: Arc<AtomicU64>,
}
66
67impl ZinitClient {
68 pub fn new(socket_path: impl AsRef<Path>) -> Self {
70 Self::with_config(ClientConfig {
71 socket_path: socket_path.as_ref().to_path_buf(),
72 ..Default::default()
73 })
74 }
75
76 pub fn with_config(config: ClientConfig) -> Self {
78 let retry_strategy = RetryStrategy::new(
79 config.max_retries,
80 config.retry_delay,
81 config.max_retry_delay,
82 config.retry_jitter,
83 );
84
85 let connection_manager = ConnectionManager::new(
86 &config.socket_path,
87 config.connection_timeout,
88 config.operation_timeout,
89 retry_strategy,
90 );
91
92 Self {
93 connection_manager,
94 config,
95 protocol: OnceCell::new(),
96 capabilities: OnceCell::new(),
97 request_id: Arc::new(AtomicU64::new(1)),
98 }
99 }
100
101 fn next_request_id(&self) -> u64 {
103 self.request_id.fetch_add(1, Ordering::SeqCst)
104 }
105
106 async fn detect_protocol(&self) -> Result<Protocol> {
108 debug!("Detecting server protocol");
109
110 let request_id = self.next_request_id();
112 let json_rpc_request = ProtocolHandler::format_json_rpc_request(
113 "service_list",
114 serde_json::Value::Array(vec![]),
115 request_id,
116 )?;
117
118 match self
119 .connection_manager
120 .send_command(&json_rpc_request)
121 .await
122 {
123 Ok(response) => {
124 if response.contains("\"jsonrpc\":\"2.0\"") {
126 debug!("Detected JSON-RPC protocol (new server)");
127 return Ok(Protocol::JsonRpc);
128 }
129 }
130 Err(_) => {
131 }
133 }
134
135 let raw_command = ProtocolHandler::format_raw_command("list", &[]);
137 match self.connection_manager.send_command(&raw_command).await {
138 Ok(response) => {
139 if response.contains("\"state\":\"ok\"") || response.contains("\"state\":\"error\"")
141 {
142 debug!("Detected raw command protocol (old server)");
143 return Ok(Protocol::RawCommands);
144 }
145 }
146 Err(e) => {
147 return Err(ZinitError::ProtocolDetectionFailed(format!(
148 "Failed to detect protocol: {e}"
149 )));
150 }
151 }
152
153 Err(ZinitError::ProtocolDetectionFailed(
154 "Unable to determine server protocol".to_string(),
155 ))
156 }
157
158 async fn detect_capabilities(&self) -> Result<ServerCapabilities> {
160 let protocol = self.get_protocol().await?;
161 debug!("Detecting server capabilities for protocol: {}", protocol);
162
163 let capabilities = match protocol {
164 Protocol::JsonRpc => {
165 ServerCapabilities::full()
167 }
168 Protocol::RawCommands => {
169 ServerCapabilities::legacy()
171 }
172 };
173
174 debug!("Detected capabilities: {:?}", capabilities);
175 Ok(capabilities)
176 }
177
178 async fn get_protocol(&self) -> Result<Protocol> {
180 if let Some(protocol) = self.protocol.get() {
181 return Ok(*protocol);
182 }
183
184 let protocol = self.detect_protocol().await?;
185 let _ = self.protocol.set(protocol);
186 Ok(protocol)
187 }
188
189 async fn get_capabilities(&self) -> Result<&ServerCapabilities> {
191 if let Some(capabilities) = self.capabilities.get() {
192 return Ok(capabilities);
193 }
194
195 let capabilities = self.detect_capabilities().await?;
196 let _ = self.capabilities.set(capabilities);
197 Ok(self.capabilities.get().unwrap())
198 }
199
200 async fn execute_command(
202 &self,
203 method: &str,
204 args: &[&str],
205 params: Option<serde_json::Value>,
206 ) -> Result<serde_json::Value> {
207 let protocol = self.get_protocol().await?;
208 let request_id = self.next_request_id();
209
210 let request = ProtocolHandler::format_request(protocol, method, args, params, request_id)?;
211 let response = self.connection_manager.send_command(&request).await?;
212 ProtocolHandler::parse_response_by_protocol(protocol, &response)
213 }
214
215 pub async fn list(&self) -> Result<HashMap<String, ServiceState>> {
217 debug!("Listing all services");
218
219 let protocol = self.get_protocol().await?;
220 let response = match protocol {
221 Protocol::JsonRpc => self.execute_command("service_list", &[], None).await?,
222 Protocol::RawCommands => self.execute_command("list", &[], None).await?,
223 };
224
225 let map: HashMap<String, String> = serde_json::from_value(response)?;
226 let result = map
227 .into_iter()
228 .map(|(name, state_str)| {
229 let state = match state_str.as_str() {
230 "Unknown" => ServiceState::Unknown,
231 "Blocked" => ServiceState::Blocked,
232 "Spawned" => ServiceState::Spawned,
233 "Running" => ServiceState::Running,
234 "Success" => ServiceState::Success,
235 "Error" => ServiceState::Error,
236 "TestFailure" => ServiceState::TestFailure,
237 _ => ServiceState::Unknown,
238 };
239 (name, state)
240 })
241 .collect();
242
243 Ok(result)
244 }
245
246 pub async fn status(&self, service: impl AsRef<str>) -> Result<ServiceStatus> {
248 let service_name = service.as_ref();
249 debug!("Getting status for service: {}", service_name);
250
251 let protocol = self.get_protocol().await?;
252 let response = match protocol {
253 Protocol::JsonRpc => {
254 let params = serde_json::json!([service_name]);
255 self.execute_command("service_status", &[], Some(params))
256 .await?
257 }
258 Protocol::RawCommands => {
259 self.execute_command("status", &[service_name], None)
260 .await?
261 }
262 };
263
264 let status = self.parse_status_response(response, service_name).await?;
266 Ok(status)
267 }
268
269 async fn parse_status_response(
271 &self,
272 response: serde_json::Value,
273 service_name: &str,
274 ) -> Result<ServiceStatus> {
275 let protocol = self.get_protocol().await?;
276
277 match protocol {
278 Protocol::JsonRpc => {
279 let name = response
281 .get("name")
282 .and_then(|v| v.as_str())
283 .unwrap_or(service_name)
284 .to_string();
285
286 let pid = response.get("pid").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
287
288 let state_str = response
289 .get("state")
290 .and_then(|v| v.as_str())
291 .unwrap_or("Unknown");
292
293 let target_str = response
294 .get("target")
295 .and_then(|v| v.as_str())
296 .unwrap_or("Down");
297
298 let after = response
299 .get("after")
300 .and_then(|v| v.as_object())
301 .map(|obj| {
302 obj.iter()
303 .map(|(k, v)| (k.clone(), v.as_str().unwrap_or("Unknown").to_string()))
304 .collect()
305 })
306 .unwrap_or_default();
307
308 Ok(ServiceStatus {
309 name,
310 pid,
311 state: self.parse_service_state(state_str),
312 target: self.parse_service_target(target_str),
313 after,
314 })
315 }
316 Protocol::RawCommands => {
317 match serde_json::from_value::<ServiceStatus>(response.clone()) {
319 Ok(mut status) => {
320 status.state = self.parse_service_state(&status.state.to_string());
322 status.target = self.parse_service_target(&status.target.to_string());
323 Ok(status)
324 }
325 Err(_) => {
326 let name = service_name.to_string();
328 let pid = response.get("pid").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
329
330 let state_str = response
331 .get("state")
332 .and_then(|v| v.as_str())
333 .unwrap_or("Unknown");
334
335 let target_str = response
336 .get("target")
337 .and_then(|v| v.as_str())
338 .unwrap_or("Down");
339
340 let after = response
341 .get("after")
342 .and_then(|v| v.as_object())
343 .map(|obj| {
344 obj.iter()
345 .map(|(k, v)| {
346 (k.clone(), v.as_str().unwrap_or("Unknown").to_string())
347 })
348 .collect()
349 })
350 .unwrap_or_default();
351
352 Ok(ServiceStatus {
353 name,
354 pid,
355 state: self.parse_service_state(state_str),
356 target: self.parse_service_target(target_str),
357 after,
358 })
359 }
360 }
361 }
362 }
363 }
364
365 fn parse_service_state(&self, state_str: &str) -> ServiceState {
367 match state_str {
368 "Unknown" => ServiceState::Unknown,
369 "Blocked" => ServiceState::Blocked,
370 "Spawned" => ServiceState::Spawned,
371 "Running" => ServiceState::Running,
372 "Success" => ServiceState::Success,
373 "Error" => ServiceState::Error,
374 "TestFailure" => ServiceState::TestFailure,
375 _ => ServiceState::Unknown,
376 }
377 }
378
379 fn parse_service_target(&self, target_str: &str) -> ServiceTarget {
381 match target_str {
382 "Up" => ServiceTarget::Up,
383 "Down" => ServiceTarget::Down,
384 _ => ServiceTarget::Down,
385 }
386 }
387
388 pub async fn start(&self, service: impl AsRef<str>) -> Result<()> {
390 let service_name = service.as_ref();
391 debug!("Starting service: {}", service_name);
392
393 let protocol = self.get_protocol().await?;
394 match protocol {
395 Protocol::JsonRpc => {
396 let params = serde_json::json!([service_name]);
397 self.execute_command("service_start", &[], Some(params))
398 .await?;
399 }
400 Protocol::RawCommands => {
401 self.execute_command("start", &[service_name], None).await?;
402 }
403 }
404
405 Ok(())
406 }
407
408 pub async fn stop(&self, service: impl AsRef<str>) -> Result<()> {
410 let service_name = service.as_ref();
411 debug!("Stopping service: {}", service_name);
412
413 let protocol = self.get_protocol().await?;
414 match protocol {
415 Protocol::JsonRpc => {
416 let params = serde_json::json!([service_name]);
417 self.execute_command("service_stop", &[], Some(params))
418 .await?;
419 }
420 Protocol::RawCommands => {
421 self.execute_command("stop", &[service_name], None).await?;
422 }
423 }
424
425 Ok(())
426 }
427
428 pub async fn restart(&self, service: impl AsRef<str>) -> Result<()> {
430 let service_name = service.as_ref();
431 debug!("Restarting service: {}", service_name);
432
433 self.stop(service_name).await?;
435
436 let mut attempts = 0;
438 let max_attempts = 20;
439
440 while attempts < max_attempts {
441 let status = self.status(service_name).await?;
442 if status.pid == 0 && status.target == ServiceTarget::Down {
443 return self.start(service_name).await;
445 }
446
447 attempts += 1;
448 tokio::time::sleep(Duration::from_secs(1)).await;
449 }
450
451 self.kill(service_name, "SIGKILL").await?;
453 self.start(service_name).await
454 }
455
456 pub async fn monitor(&self, service: impl AsRef<str>) -> Result<()> {
458 let service_name = service.as_ref();
459 debug!("Monitoring service: {}", service_name);
460
461 let protocol = self.get_protocol().await?;
462 match protocol {
463 Protocol::JsonRpc => {
464 let params = serde_json::json!([service_name]);
465 self.execute_command("service_monitor", &[], Some(params))
466 .await?;
467 }
468 Protocol::RawCommands => {
469 self.execute_command("monitor", &[service_name], None)
470 .await?;
471 }
472 }
473
474 Ok(())
475 }
476
477 pub async fn forget(&self, service: impl AsRef<str>) -> Result<()> {
479 let service_name = service.as_ref();
480 debug!("Forgetting service: {}", service_name);
481
482 let protocol = self.get_protocol().await?;
483 match protocol {
484 Protocol::JsonRpc => {
485 let params = serde_json::json!([service_name]);
486 self.execute_command("service_forget", &[], Some(params))
487 .await?;
488 }
489 Protocol::RawCommands => {
490 self.execute_command("forget", &[service_name], None)
491 .await?;
492 }
493 }
494
495 Ok(())
496 }
497
498 pub async fn kill(&self, service: impl AsRef<str>, signal: impl AsRef<str>) -> Result<()> {
500 let service_name = service.as_ref();
501 let signal_name = signal.as_ref();
502 debug!(
503 "Sending signal {} to service: {}",
504 signal_name, service_name
505 );
506
507 let protocol = self.get_protocol().await?;
508 match protocol {
509 Protocol::JsonRpc => {
510 let params = serde_json::json!([service_name, signal_name]);
511 self.execute_command("service_kill", &[], Some(params))
512 .await?;
513 }
514 Protocol::RawCommands => {
515 self.execute_command("kill", &[service_name, signal_name], None)
516 .await?;
517 }
518 }
519
520 Ok(())
521 }
522
523 pub async fn logs(&self, follow: bool, filter: Option<impl AsRef<str>>) -> Result<LogStream> {
525 let command = if follow {
526 "log".to_string()
527 } else {
528 "log snapshot".to_string()
529 };
530
531 debug!("Streaming logs with command: {}", command);
532 let stream = self.connection_manager.stream_logs(&command).await?;
533 let reader = BufReader::new(stream);
534 let mut lines = reader.lines();
535
536 let filter_str = filter.as_ref().map(|f| f.as_ref().to_string());
538
539 let log_stream = async_stream::stream! {
540 while let Some(line_result) = lines.next_line().await.transpose() {
541 match line_result {
542 Ok(line) => {
543 trace!("Received log line: {}", line);
544
545 if let Some(entry) = parse_log_line(&line, &filter_str) {
547 yield Ok(entry);
548 }
549 }
550 Err(e) => {
551 yield Err(ZinitError::ConnectionError(e));
552 break;
553 }
554 }
555 }
556 };
557
558 Ok(LogStream {
559 inner: Box::pin(log_stream),
560 })
561 }
562
563 pub async fn shutdown(&self) -> Result<()> {
565 debug!("Shutting down the system");
566 self.connection_manager.execute_command("shutdown").await?;
567 Ok(())
568 }
569
570 pub async fn reboot(&self) -> Result<()> {
572 debug!("Rebooting the system");
573 self.connection_manager.execute_command("reboot").await?;
574 Ok(())
575 }
576
577 pub async fn get_service(&self, service: impl AsRef<str>) -> Result<serde_json::Value> {
579 let service_name = service.as_ref();
580 debug!("Getting raw service info for: {}", service_name);
581
582 let protocol = self.get_protocol().await?;
584 match protocol {
585 Protocol::JsonRpc => {
586 let params = serde_json::json!([service_name]);
588 self.execute_command("service_status", &[], Some(params))
589 .await
590 }
591 Protocol::RawCommands => {
592 self.execute_command("status", &[service_name], None).await
594 }
595 }
596 }
597
598 pub async fn create_service(
600 &self,
601 name: impl AsRef<str>,
602 config: serde_json::Value,
603 ) -> Result<()> {
604 let service_name = name.as_ref();
605 debug!("Creating service: {}", service_name);
606
607 let capabilities = self.get_capabilities().await?;
609 if !capabilities.supports_create {
610 return Err(ZinitError::FeatureNotSupported(format!(
611 "Dynamic service creation is not supported by this zinit server ({}). \
612 Please create a service configuration file manually in /etc/zinit/{}.yaml",
613 capabilities.protocol, service_name
614 )));
615 }
616
617 let protocol = self.get_protocol().await?;
619 match protocol {
620 Protocol::JsonRpc => {
621 let params = serde_json::json!([service_name, config]);
623 self.execute_command("service_create", &[], Some(params))
624 .await?;
625 }
626 Protocol::RawCommands => {
627 return Err(ZinitError::FeatureNotSupported(
630 "Dynamic service creation requires zinit v0.2.25+".to_string(),
631 ));
632 }
633 }
634
635 Ok(())
636 }
637
638 pub async fn delete_service(&self, name: impl AsRef<str>) -> Result<()> {
640 let service_name = name.as_ref();
641 debug!("Deleting service: {}", service_name);
642
643 match self.status(service_name).await {
645 Ok(status) => {
646 if status.state == ServiceState::Running || status.target == ServiceTarget::Up {
647 if let Err(e) = self.stop(service_name).await {
649 debug!("Warning: Failed to stop service {}: {}", service_name, e);
650 }
651
652 let mut attempts = 0;
654 let max_attempts = 10;
655
656 while attempts < max_attempts {
657 match self.status(service_name).await {
658 Ok(status) => {
659 if status.pid == 0 && status.target == ServiceTarget::Down {
660 break;
661 }
662 }
663 Err(_) => {
664 break;
666 }
667 }
668
669 attempts += 1;
670 tokio::time::sleep(Duration::from_millis(500)).await;
671 }
672 }
673 }
674 Err(e) => {
675 debug!("Warning: Could not get status for {}: {}", service_name, e);
676 }
678 }
679
680 self.forget(service_name).await?;
682
683 let protocol = self.get_protocol().await?;
685 if let Protocol::JsonRpc = protocol {
686 let params = serde_json::json!([service_name]);
687 if let Err(e) = self
688 .execute_command("service_delete", &[], Some(params))
689 .await
690 {
691 debug!(
692 "Warning: Could not delete service config file for {}: {}",
693 service_name, e
694 );
695 }
697 }
698
699 Ok(())
700 }
701}
702
703fn parse_log_line(line: &str, filter: &Option<String>) -> Option<LogEntry> {
705 let parts: Vec<&str> = line.splitn(4, ' ').collect();
707
708 if parts.len() < 4 || !parts[0].starts_with("zinit:") {
709 return None;
710 }
711
712 let level = parts[1];
713 let service = parts[2].trim_start_matches('(').trim_end_matches(')');
714
715 if let Some(filter_str) = filter {
717 if service != filter_str {
718 return None;
719 }
720 }
721
722 let message = parts[3];
723 let timestamp = Utc::now(); Some(LogEntry {
726 timestamp,
727 service: service.to_string(),
728 message: format!("[{level}] {message}"),
729 })
730}