Skip to main content

rust_expect/
encoding.rs

//! Encoding detection and conversion utilities.
//!
//! This module provides utilities for handling text encoding in terminal I/O,
//! including UTF-8 validation, encoding detection, and line ending normalization.
5
6use std::borrow::Cow;
7use std::fmt::Write;
8
/// Outcome of decoding a byte sequence into text.
#[derive(Debug, Clone)]
pub struct EncodedText {
    /// Text produced by the decoder.
    pub text: String,
    /// How many input bytes were consumed.
    pub bytes_consumed: usize,
    /// `true` when at least one encoding error was recorded.
    pub had_errors: bool,
    /// Count of replacements recorded while decoding.
    pub replacements: usize,
}

impl EncodedText {
    /// Build a result for input that decoded cleanly: no errors, zero replacements.
    #[must_use]
    pub fn ok(text: impl Into<String>, bytes_consumed: usize) -> Self {
        Self::with_errors(text, bytes_consumed, 0)
    }

    /// Build a result that records `replacements` substitutions.
    ///
    /// `had_errors` is derived from the count, so passing `0` yields a result
    /// identical to [`EncodedText::ok`].
    #[must_use]
    pub fn with_errors(
        text: impl Into<String>,
        bytes_consumed: usize,
        replacements: usize,
    ) -> Self {
        let had_errors = replacements != 0;
        Self {
            text: text.into(),
            bytes_consumed,
            had_errors,
            replacements,
        }
    }
}
49
50/// Decode bytes as UTF-8, replacing invalid sequences.
51///
52/// This is the default behavior for rust-expect. Invalid UTF-8 sequences
53/// are replaced with the Unicode replacement character (U+FFFD).
54#[must_use]
55pub fn decode_utf8_lossy(bytes: &[u8]) -> EncodedText {
56    let text = String::from_utf8_lossy(bytes);
57    let replacements = text.matches('\u{FFFD}').count();
58
59    EncodedText {
60        text: text.into_owned(),
61        bytes_consumed: bytes.len(),
62        had_errors: replacements > 0,
63        replacements,
64    }
65}
66
67/// Decode bytes as UTF-8, returning an error on invalid sequences.
68///
69/// # Errors
70///
71/// Returns an error if the input is not valid UTF-8.
72pub fn decode_utf8_strict(bytes: &[u8]) -> Result<EncodedText, std::str::Utf8Error> {
73    let text = std::str::from_utf8(bytes)?;
74    Ok(EncodedText::ok(text, bytes.len()))
75}
76
77/// Decode bytes as UTF-8, escaping invalid bytes as hex.
78///
79/// Invalid bytes are replaced with `\xHH` escape sequences.
80#[must_use]
81#[allow(unsafe_code)]
82pub fn decode_utf8_escape(bytes: &[u8]) -> EncodedText {
83    let mut result = String::with_capacity(bytes.len());
84    let mut replacements = 0;
85    let mut i = 0;
86
87    while i < bytes.len() {
88        match std::str::from_utf8(&bytes[i..]) {
89            Ok(valid) => {
90                result.push_str(valid);
91                break;
92            }
93            Err(e) => {
94                // Add the valid prefix
95                let valid_up_to = e.valid_up_to();
96                if valid_up_to > 0 {
97                    // SAFETY: `from_utf8` returned `Utf8Error::valid_up_to() == valid_up_to`,
98                    // which guarantees `bytes[i..i + valid_up_to]` is well-formed UTF-8.
99                    result.push_str(unsafe {
100                        std::str::from_utf8_unchecked(&bytes[i..i + valid_up_to])
101                    });
102                }
103                i += valid_up_to;
104
105                // Handle the invalid byte(s)
106                let error_len = e.error_len().unwrap_or(1);
107                for byte in &bytes[i..i + error_len] {
108                    let _ = write!(result, "\\x{byte:02x}");
109                    replacements += 1;
110                }
111                i += error_len;
112            }
113        }
114    }
115
116    EncodedText::with_errors(result, bytes.len(), replacements)
117}
118
119/// Skip invalid UTF-8 sequences.
120///
121/// Invalid bytes are simply removed from the output.
122#[must_use]
123#[allow(unsafe_code)]
124pub fn decode_utf8_skip(bytes: &[u8]) -> EncodedText {
125    let mut result = String::with_capacity(bytes.len());
126    let mut replacements = 0;
127    let mut i = 0;
128
129    while i < bytes.len() {
130        match std::str::from_utf8(&bytes[i..]) {
131            Ok(valid) => {
132                result.push_str(valid);
133                break;
134            }
135            Err(e) => {
136                let valid_up_to = e.valid_up_to();
137                if valid_up_to > 0 {
138                    // SAFETY: `valid_up_to` slice is guaranteed well-formed UTF-8 by
139                    // the contract of `std::str::Utf8Error::valid_up_to`.
140                    result.push_str(unsafe {
141                        std::str::from_utf8_unchecked(&bytes[i..i + valid_up_to])
142                    });
143                }
144                i += valid_up_to;
145
146                let error_len = e.error_len().unwrap_or(1);
147                replacements += error_len;
148                i += error_len;
149            }
150        }
151    }
152
153    EncodedText::with_errors(result, bytes.len(), replacements)
154}
155
156/// Normalize line endings in text.
157///
158/// Converts all line endings (CRLF, CR, LF) to the specified style.
159#[must_use]
160pub fn normalize_line_endings(text: &str, ending: LineEndingStyle) -> Cow<'_, str> {
161    let target = ending.as_str();
162
163    // Check if normalization is needed
164    let needs_crlf = text.contains("\r\n");
165    let needs_cr = text.contains('\r') && !needs_crlf;
166    let needs_lf = text.contains('\n') && !needs_crlf;
167
168    // If already normalized, return as-is
169    match ending {
170        LineEndingStyle::Lf if !needs_crlf && !needs_cr => return Cow::Borrowed(text),
171        LineEndingStyle::CrLf if needs_crlf && !needs_cr && !needs_lf => {
172            return Cow::Borrowed(text);
173        }
174        LineEndingStyle::Cr if needs_cr && !needs_crlf && !needs_lf => return Cow::Borrowed(text),
175        _ => {}
176    }
177
178    // First normalize all endings to LF
179    let normalized = if needs_crlf {
180        text.replace("\r\n", "\n")
181    } else {
182        text.to_string()
183    };
184
185    let normalized = if normalized.contains('\r') {
186        normalized.replace('\r', "\n")
187    } else {
188        normalized
189    };
190
191    // Then convert to target if not LF
192    let result = if target == "\n" {
193        normalized
194    } else {
195        normalized.replace('\n', target)
196    };
197
198    Cow::Owned(result)
199}
200
/// The line-terminator conventions this module knows about.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LineEndingStyle {
    /// A single line feed (`\n`), as used on Unix.
    #[default]
    Lf,
    /// Carriage return followed by line feed (`\r\n`), as used on Windows.
    CrLf,
    /// A single carriage return (`\r`), as used on classic Mac OS.
    Cr,
}

impl LineEndingStyle {
    /// The terminator for this style as a string slice.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Cr => "\r",
            Self::CrLf => "\r\n",
            Self::Lf => "\n",
        }
    }

    /// The terminator for this style as raw bytes.
    ///
    /// These are exactly the bytes of [`LineEndingStyle::as_str`].
    #[must_use]
    pub const fn as_bytes(self) -> &'static [u8] {
        self.as_str().as_bytes()
    }

    /// The style selected by the compilation target: CRLF when compiled for
    /// Windows, LF otherwise.
    #[must_use]
    pub const fn from_env() -> Self {
        if cfg!(windows) {
            Self::CrLf
        } else {
            Self::Lf
        }
    }
}
240
241/// Detect the predominant line ending in text.
242#[must_use]
243pub fn detect_line_ending(text: &str) -> Option<LineEndingStyle> {
244    let crlf_count = text.matches("\r\n").count();
245    let lf_only_count = text.matches('\n').count().saturating_sub(crlf_count);
246    let cr_only_count = text
247        .chars()
248        .zip(text.chars().skip(1).chain(std::iter::once('\0')))
249        .filter(|&(c, next)| c == '\r' && next != '\n')
250        .count();
251
252    if crlf_count == 0 && lf_only_count == 0 && cr_only_count == 0 {
253        return None;
254    }
255
256    if crlf_count >= lf_only_count && crlf_count >= cr_only_count {
257        Some(LineEndingStyle::CrLf)
258    } else if lf_only_count >= cr_only_count {
259        Some(LineEndingStyle::Lf)
260    } else {
261        Some(LineEndingStyle::Cr)
262    }
263}
264
265/// Detect encoding from environment variables.
266///
267/// Checks `LC_ALL`, `LC_CTYPE`, and `LANG` in order.
268#[must_use]
269pub fn detect_encoding_from_env() -> DetectedEncoding {
270    let locale = std::env::var("LC_ALL")
271        .or_else(|_| std::env::var("LC_CTYPE"))
272        .or_else(|_| std::env::var("LANG"))
273        .unwrap_or_default();
274
275    let locale_lower = locale.to_lowercase();
276
277    if locale_lower.contains("utf-8") || locale_lower.contains("utf8") {
278        DetectedEncoding::Utf8
279    } else if locale_lower.contains("iso-8859-1") || locale_lower.contains("iso8859-1") {
280        DetectedEncoding::Latin1
281    } else if locale_lower.contains("1252") {
282        DetectedEncoding::Windows1252
283    } else if locale.is_empty() {
284        // Default to UTF-8 for modern systems
285        DetectedEncoding::Utf8
286    } else {
287        DetectedEncoding::Unknown(locale)
288    }
289}
290
/// Character encoding inferred from the process environment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DetectedEncoding {
    /// UTF-8.
    Utf8,
    /// ISO-8859-1, also known as Latin-1.
    Latin1,
    /// Windows code page 1252.
    Windows1252,
    /// An encoding that was not recognized; carries the locale string.
    Unknown(String),
}

impl DetectedEncoding {
    /// Returns `true` only for the [`DetectedEncoding::Utf8`] variant.
    #[must_use]
    pub const fn is_utf8(&self) -> bool {
        match self {
            Self::Utf8 => true,
            Self::Latin1 | Self::Windows1252 | Self::Unknown(_) => false,
        }
    }
}
311
/// Strip ANSI escape sequences from text.
///
/// Returns `Cow::Borrowed` when the input contains no ESC character.
/// Otherwise the following sequences are removed:
///
/// - **CSI** (`ESC [`): any run of parameter/intermediate bytes
///   (`0x20..=0x3F`) followed by one final byte (`0x40..=0x7E`). This covers
///   finals such as `~` in addition to letters. A malformed sequence ends at
///   the first character outside those ranges, which is kept.
/// - **OSC** (`ESC ]`): everything up to BEL or ST (`ESC \`). If an ESC that
///   does not start ST is met first, the OSC ends there and that ESC is
///   processed as the start of the next sequence.
/// - **Character-set designation** (`ESC (`, `ESC )`, `ESC *`, `ESC +`): the
///   introducer and the one character after it.
/// - **Two-character escapes**: `ESC` followed by an ASCII uppercase letter,
///   `=`, `>`, `7`, `8`, or `c`.
///
/// Any other ESC is dropped on its own and the following character is kept.
#[must_use]
pub fn strip_ansi(text: &str) -> Cow<'_, str> {
    // Quick check: if no escape character, return as-is
    if !text.contains('\x1b') {
        return Cow::Borrowed(text);
    }

    let mut result = String::with_capacity(text.len());
    let mut chars = text.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\x1b' {
            result.push(c);
            continue;
        }

        match chars.peek().copied() {
            Some('[') => {
                chars.next(); // consume '['
                while let Some(&part) = chars.peek() {
                    match part {
                        // Parameter and intermediate bytes.
                        '\x20'..='\x3f' => {
                            chars.next();
                        }
                        // Final byte: ends the sequence and belongs to it.
                        '\x40'..='\x7e' => {
                            chars.next();
                            break;
                        }
                        // Not part of a CSI sequence: stop without consuming,
                        // so ordinary text (or a new ESC) is not swallowed.
                        _ => break,
                    }
                }
            }
            Some(']') => {
                chars.next(); // consume ']'
                while let Some(&osc_char) = chars.peek() {
                    match osc_char {
                        '\x07' => {
                            chars.next(); // BEL terminator
                            break;
                        }
                        '\x1b' => {
                            let mut lookahead = chars.clone();
                            lookahead.next();
                            if lookahead.peek() == Some(&'\\') {
                                chars.next(); // ESC
                                chars.next(); // '\\' — together they form ST
                            }
                            // Otherwise leave the ESC in place: it starts a
                            // new sequence that the outer loop will strip.
                            break;
                        }
                        _ => {
                            chars.next();
                        }
                    }
                }
            }
            Some('(' | ')' | '*' | '+') => {
                // Designate character set: ESC ( X
                chars.next();
                chars.next();
            }
            Some(next)
                if next.is_ascii_uppercase() || matches!(next, '=' | '>' | '7' | '8' | 'c') =>
            {
                // Two-character escape: ESC X. `7`/`8` are cursor
                // save/restore (DECSC/DECRC) and `c` is full reset (RIS).
                chars.next();
            }
            // Unknown introducer or ESC at end of input: drop only the ESC.
            _ => {}
        }
    }

    Cow::Owned(result)
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed input passes through the lossy decoder untouched.
    #[test]
    fn decode_valid_utf8() {
        let decoded = decode_utf8_lossy(b"hello world");
        assert_eq!(decoded.text, "hello world");
        assert_eq!(decoded.replacements, 0);
        assert!(!decoded.had_errors);
    }

    /// Invalid bytes show up as U+FFFD and are reported as errors.
    #[test]
    fn decode_invalid_utf8_lossy() {
        let decoded = decode_utf8_lossy(b"hello\xff\xfeworld");
        assert!(decoded.had_errors);
        assert!(decoded.replacements > 0);
        assert!(decoded.text.contains('\u{FFFD}'));
    }

    /// Invalid bytes are rendered as `\xHH` by the escaping decoder.
    #[test]
    fn decode_invalid_utf8_escape() {
        let decoded = decode_utf8_escape(b"hello\xffworld");
        assert!(decoded.had_errors);
        assert!(decoded.text.contains("\\xff"));
    }

    /// Invalid bytes vanish from the output of the skipping decoder.
    #[test]
    fn decode_invalid_utf8_skip() {
        let decoded = decode_utf8_skip(b"hello\xff\xfeworld");
        assert!(decoded.had_errors);
        assert_eq!(decoded.text, "helloworld");
    }

    /// CRLF input converted to LF.
    #[test]
    fn normalize_crlf_to_lf() {
        let normalized = normalize_line_endings("line1\r\nline2\r\nline3", LineEndingStyle::Lf);
        assert_eq!(normalized, "line1\nline2\nline3");
    }

    /// LF input converted to CRLF.
    #[test]
    fn normalize_lf_to_crlf() {
        let normalized = normalize_line_endings("line1\nline2\nline3", LineEndingStyle::CrLf);
        assert_eq!(normalized, "line1\r\nline2\r\nline3");
    }

    /// Pure-LF text is detected as LF.
    #[test]
    fn detect_line_ending_lf() {
        let detected = detect_line_ending("line1\nline2\n");
        assert_eq!(detected, Some(LineEndingStyle::Lf));
    }

    /// Pure-CRLF text is detected as CRLF.
    #[test]
    fn detect_line_ending_crlf() {
        let detected = detect_line_ending("line1\r\nline2\r\n");
        assert_eq!(detected, Some(LineEndingStyle::CrLf));
    }

    /// SGR color sequences are removed.
    #[test]
    fn strip_ansi_csi() {
        let stripped = strip_ansi("\x1b[32mgreen\x1b[0m text");
        assert_eq!(stripped, "green text");
    }

    /// Text without ESC is returned borrowed and unchanged.
    #[test]
    fn strip_ansi_no_escape() {
        let stripped = strip_ansi("plain text");
        assert_eq!(stripped, "plain text");
        assert!(matches!(stripped, Cow::Borrowed(_)));
    }

    /// A BEL-terminated OSC title sequence is removed.
    #[test]
    fn strip_ansi_osc() {
        let stripped = strip_ansi("\x1b]0;Window Title\x07normal text");
        assert_eq!(stripped, "normal text");
    }
}