Skip to main content

sm4_ff1/
lib.rs

1pub mod ff1error;
2
3use num_bigint::{BigUint, ToBigUint};
4use num_traits::{Zero, Pow, ToPrimitive};
5use std::convert::TryInto;
6use ff1error::FF1Error;
7use num_integer::Integer;
8use neuedu_cryptos::block_ciphers::sm4_cbc_encrypt;
9
/// XORs every byte of `b` into the corresponding byte of `a`, in place.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn xor_bytes(a: &mut [u8], b: &[u8]) {
    assert_eq!(a.len(), b.len(), "XOR slices must have equal length");
    a.iter_mut().zip(b).for_each(|(dst, &src)| *dst ^= src);
}
16
17fn ciph(key: &[u8; 16], input: &[u8; 16]) -> Result<[u8; 16], FF1Error> {
18    let iv = [0u8; 16];
19    let ciphertext = sm4_cbc_encrypt(*key, None, iv, input).unwrap();
20
21    if ciphertext.len() >= 16 {
22        Ok(ciphertext[0..16].try_into().unwrap())
23    } else {
24        Err(FF1Error::CipherLengthError)
25    }
26}
27
28fn prf(key: &[u8; 16], input: &[u8]) -> Result<[u8; 16], FF1Error> {
29    let block_size = 16;
30    let num_blocks = (input.len() + block_size - 1) / block_size;
31    let padded_len = num_blocks * block_size;
32    let mut padded_input = Vec::with_capacity(padded_len);
33    padded_input.extend_from_slice(input);
34    padded_input.resize(padded_len, 0u8);
35
36    let mut mac = [0u8; 16];
37
38    for block in padded_input.chunks_exact(block_size) {
39        let mut block_array: [u8; 16] = block.try_into().expect("Chunk size is 16");
40        xor_bytes(&mut block_array, &mac); // XOR with previous MAC
41        mac = ciph(key, &block_array)?;    // Encrypt using adapted CIPH
42    }
43
44    Ok(mac)
45}
46
47fn num_radix(s: &[u32], radix: u32) -> Result<BigUint, FF1Error> {
48    let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
49    let mut num = BigUint::zero();
50    for &digit in s {
51        if digit >= radix {
52            return Err(FF1Error::InvalidDigit(digit, radix));
53        }
54        num = num * &big_radix + digit.to_biguint().ok_or(FF1Error::BigUintConversion)?;
55    }
56    Ok(num)
57}
58
59fn str_radix(mut c: BigUint, radix: u32, len: usize) -> Result<Vec<u32>, FF1Error> {
60    if c.is_zero() {
61        return Ok(vec![0u32; len]);
62    }
63
64    let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
65    let mut s = Vec::with_capacity(len);
66
67    while !c.is_zero() {
68        let (quotient, remainder) = c.div_rem(&big_radix);
69        s.push(remainder.to_u32().ok_or(FF1Error::BigUintConversion)?);
70        c = quotient;
71    }
72
73    if s.len() > len {
74        // overflow, c >= radix^len, which shouldn't happen after mod op
75        return Err(FF1Error::StrLenMismatch);
76    }
77
78    // Pad with leading zeros
79    s.resize(len, 0u32);
80    s.reverse(); // Digits are generated in reverse order
81    Ok(s)
82}
83
84fn num_bytes(bytes: &[u8]) -> BigUint {
85    BigUint::from_bytes_be(bytes)
86}
87
88fn bytes_radix(n: &BigUint, len: usize) -> Result<Vec<u8>, FF1Error> {
89    let bytes = n.to_bytes_be();
90    if bytes.len() > len {
91        Err(FF1Error::NumToBytesConversion)
92    } else {
93        let padding_len = len - bytes.len();
94        let mut result = Vec::with_capacity(len);
95        result.extend(std::iter::repeat(0u8).take(padding_len));
96        result.extend(bytes);
97        Ok(result)
98    }
99}
100
101
102// --- FF1 Encrypt Function ---
103
104/// Implements FF1 Encryption according to the provided algorithm image.
105///
106/// # Arguments
107/// * `key` - The 128-bit (16 byte) key for SM4.
108/// * `radix` - The base of the numeral string X (2 <= radix <= 2^16).
109/// * `minlen` - Minimum allowed length for X.
110/// * `maxlen` - Maximum allowed length for X.
111/// * `max_tlen` - Maximum allowed byte length for the tweak T.
112/// * `tweak` - The tweak T (byte string).
113/// * `x_digits` - The input numeral string X as a slice of digits (u32 values 0..radix-1).
114///
115/// # Returns
116/// * `Ok(Vec<u32>)` - The encrypted numeral string Y as a vector of digits.
117/// * `Err(FF1Error)` - An error if input parameters are invalid or crypto operations fail.
118pub fn ff1_encrypt(
119    key: &[u8; 16],
120    radix: u32,
121    minlen: usize,
122    maxlen: usize,
123    max_tlen: usize,
124    tweak: &[u8],
125    x_digits: &[u32],
126) -> Result<Vec<u32>, FF1Error> {
127    let n = x_digits.len();
128    let t = tweak.len();
129
130    // Validate input parameters
131    if !(2..=65536).contains(&radix) { // 2^16 = 65536
132        return Err(FF1Error::InvalidRadix { radix });
133    }
134    if !(minlen..=maxlen).contains(&n) {
135        return Err(FF1Error::InvalidLength { n });
136    }
137    if t > max_tlen {
138        return Err(FF1Error::InvalidTweakLength { t });
139    }
140
141    // Check radix^minlen >= 100 constraint
142    let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
143    let min_radix_pow = big_radix.pow(minlen);
144    if min_radix_pow < 100u32.to_biguint().unwrap() {
145        return Err(FF1Error::ConstraintViolation { radix, minlen });
146    }
147
148    for &digit in x_digits {
149        if digit >= radix {
150            return Err(FF1Error::InvalidDigit(digit, radix));
151        }
152    }
153
154    // 1. Let u = floor(n/2); v = n - u.
155    let u = n / 2;
156    let v = n - u;
157
158    // 2. Let A = X[1..u]; B = X[u+1..n].
159    let mut a_digits: Vec<u32> = x_digits[0..u].to_vec();
160    let mut b_digits: Vec<u32> = x_digits[u..n].to_vec();
161
162    // 3. Let b = ceil( ceil(v * log2(radix)) / 8 ).
163    let v_log_radix = v as f64 * (radix as f64).log2();
164    let ceil_v_log_radix = v_log_radix.ceil() as usize;
165    let b = (ceil_v_log_radix + 7) / 8; // ceil(x / 8) = (x + 7) / 8 for integers
166
167    // 4. Let d = 4 * ceil(b / 4) + 4.
168    let ceil_b_div_4 = (b + 3) / 4; // ceil(b / 4)
169    let d = 4 * ceil_b_div_4 + 4;
170
171    // 5. Let P = [1]^1 || [2]^1 || [1]^1 || [radix]^3 || [10]^1 || [u mod 256]^1 || [n]^4 || [t]^4.
172    // P is always 16 bytes.
173    let mut p = Vec::with_capacity(16);
174    p.push(1); // [1]^1
175    p.push(2); // [2]^1
176    p.push(1); // [1]^1
177    let radix_bytes = radix.to_be_bytes(); // u32 -> [u8; 4]
178    p.extend_from_slice(&radix_bytes[1..4]); // Take bytes 1, 2, 3 (total 3 bytes)
179
180    p.push(10); // [10]^1
181    p.push((u % 256) as u8); // [u mod 256]^1
182    p.extend_from_slice(&(n as u32).to_be_bytes()); // [n]^4
183    p.extend_from_slice(&(t as u32).to_be_bytes()); // [t]^4
184    let p_array: [u8; 16] = p.try_into().expect("P should be 16 bytes");
185
186
187    // 6. For i from 0 to 9:
188    for i in 0..10 {
189        // i. Let Q = T || [0]^(-t-b-1) mod 16 || [i]^1 || [NUM_radix(B)]^b.
190        let num_b = num_radix(&b_digits, radix)?;
191        let num_b_bytes = bytes_radix(&num_b, b)?;
192
193        let q_len_before_padding = t + 1 + b; // T + [i]^1 + [NUM_radix(B)]^b
194        let num_zeros = (16 - (q_len_before_padding % 16)) % 16; // (-t-b-1) mod 16
195
196        let mut q = Vec::with_capacity(t + num_zeros + 1 + b);
197        q.extend_from_slice(tweak); // T
198        q.extend(std::iter::repeat(0u8).take(num_zeros)); // [0] padding
199        q.push(i as u8); // [i]^1
200        q.extend_from_slice(&num_b_bytes); // [NUM_radix(B)]^b
201
202        // ii. Let R = PRF(P || Q).
203        let mut prf_input = Vec::with_capacity(16 + q.len());
204        prf_input.extend_from_slice(&p_array);
205        prf_input.extend_from_slice(&q);
206        let r = prf(key, &prf_input)?; // R is 16 bytes
207
208        // iii. Let S be the first d bytes of R || CIPH_K(R ^ [1]^16) || ... || CIPH_K(R ^ [ceil(d/16)-1]^16).
209        let num_s_blocks = (d + 15) / 16; // ceil(d/16)
210        let mut s_bytes = Vec::with_capacity(num_s_blocks * 16);
211
212        s_bytes.extend_from_slice(&r);
213
214        if num_s_blocks > 1 {
215            let mut r_xor_j = r;
216            for j in 1..num_s_blocks { // Loop from 1 up to ceil(d/16) - 1
217                // Prepare R ^ [j]^16
218                // [j]^16 means j represented as 16 bytes, big-endian.
219                // Since j is small (max 9 for i=9, d depends on b), only last bytes matter.
220                let j_bytes = (j as u32).to_be_bytes(); // Use u32 for j
221                let mut j_block = [0u8; 16];
222                j_block[12..16].copy_from_slice(&j_bytes); // Place j in the last 4 bytes
223
224                xor_bytes(&mut r_xor_j, &j_block); // R ^ [j]^16 (modifies r_xor_j)
225
226                let s_block = ciph(key, &r_xor_j)?; // CIPH_K(R ^ [j]^16)
227                s_bytes.extend_from_slice(&s_block);
228
229                // Important: Reset r_xor_j for the next iteration's XOR
230                // Or, more simply, XOR the *original* R with j_block each time
231                r_xor_j = r; // Reset to original R before next XOR
232                // Alternative: XOR again with j_block to undo, but starting fresh is cleaner.
233            }
234        }
235        // Truncate S to exactly d bytes
236        s_bytes.truncate(d);
237
238        // iv. Let y = NUM(S).
239        let y = num_bytes(&s_bytes);
240
241        // v. If i is even, let m = u; else, let m = v.
242        let m = if i % 2 == 0 { u } else { v };
243
244        // vi. Let c = (NUM_radix(A) + y) mod radix^m.
245        let num_a = num_radix(&a_digits, radix)?;
246        let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
247        let modulus = big_radix.pow(m);
248        let c = (num_a + y) % modulus;
249
250        // vii. Let C = STR_radix^m(c).
251        let c_digits = str_radix(c, radix, m)?;
252
253        // viii. Let A = B.
254        a_digits = b_digits; // Move B to A
255
256        // ix. Let B = C.
257        b_digits = c_digits; // Move C to B
258    }
259
260    // 7. Return A || B.
261    let mut result_digits = a_digits;
262    result_digits.extend(b_digits);
263
264    Ok(result_digits)
265}
266
267/// Implements FF1 Decryption according to the provided Algorithm 8 image.
268///
269/// # Arguments
270/// * `key` - The 128-bit (16 byte) key for SM4.
271/// * `radix` - The base of the numeral string X (2 <= radix <= 2^16).
272/// * `minlen` - Minimum allowed length for X.
273/// * `maxlen` - Maximum allowed length for X.
274/// * `max_tlen` - Maximum allowed byte length for the tweak T.
275/// * `tweak` - The tweak T (byte string).
276/// * `x_digits` - The input ciphertext numeral string X as a slice of digits (u32 values 0..radix-1).
277///
278/// # Returns
279/// * `Ok(Vec<u32>)` - The decrypted plaintext numeral string Y as a vector of digits.
280/// * `Err(FF1Error)` - An error if input parameters are invalid or crypto operations fail.
281pub fn ff1_decrypt(
282    key: &[u8; 16],
283    radix: u32,
284    minlen: usize,
285    maxlen: usize,
286    max_tlen: usize,
287    tweak: &[u8],
288    x_digits: &[u32],
289) -> Result<Vec<u32>, FF1Error> {
290    let n = x_digits.len();
291    let t = tweak.len();
292
293    // Validate input parameters
294    if !(2..=65536).contains(&radix) {
295        return Err(FF1Error::InvalidRadix { radix });
296    }
297    if !(minlen..=maxlen).contains(&n) {
298        return Err(FF1Error::InvalidLength { n });
299    }
300    if t > max_tlen {
301        return Err(FF1Error::InvalidTweakLength { t });
302    }
303    let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
304    let min_radix_pow = big_radix.pow(minlen as u32);
305    if min_radix_pow < 100u32.to_biguint().unwrap() {
306        return Err(FF1Error::ConstraintViolation { radix, minlen });
307    }
308    for &digit in x_digits {
309        if digit >= radix {
310            return Err(FF1Error::InvalidDigit(digit, radix));
311        }
312    }
313
314    // 1. Let u = floor(n/2); v = n - u.
315    let u = n / 2;
316    let v = n - u;
317
318    // 2. Let A = X[1..u]; B = X[u+1..n].
319    let mut a_digits: Vec<u32> = x_digits[0..u].to_vec();
320    let mut b_digits: Vec<u32> = x_digits[u..n].to_vec();
321
322    // 3. Let b = ceil( ceil(v * log2(radix)) / 8 ).
323    let v_log_radix = v as f64 * (radix as f64).log2();
324    let ceil_v_log_radix = v_log_radix.ceil();
325    let b = if ceil_v_log_radix <= 0.0 { 0 } else { (ceil_v_log_radix as usize + 7) / 8 };
326
327    // 4. Let d = 4 * ceil(b / 4) + 4.
328    let ceil_b_div_4 = (b + 3) / 4;
329    let d = 4 * ceil_b_div_4 + 4;
330
331    // 5. Let P = [1]^1 || [2]^1 || [1]^1 || [radix]^3 || [10]^1 || [u mod 256]^1 || [n]^4 || [t]^4.
332    let mut p = Vec::with_capacity(16);
333    p.push(1); p.push(2); p.push(1);
334    let radix_bytes = radix.to_be_bytes();
335    p.extend_from_slice(&radix_bytes[1..4]);
336    p.push(10);
337    p.push((u % 256) as u8);
338    p.extend_from_slice(&(n as u32).to_be_bytes());
339    p.extend_from_slice(&(t as u32).to_be_bytes());
340    let p_array: [u8; 16] = p.try_into().expect("P should be 16 bytes");
341
342    // 6. For i from 9 down to 0:
343    for i in (0..10).rev() {
344        // v. If i is even, let m = u; else, let m = v.
345        let m = if i % 2 == 0 { u } else { v };
346
347        // i. Let Q = T || [0]^(-t-b-1) mod 16 || [i]^1 || [NUM_radix(A)]^b.
348        let num_a = num_radix(&a_digits, radix)?; // Use A here
349        let num_a_bytes = bytes_radix(&num_a, b)?;
350
351        let q_len_before_padding = t + 1 + b;
352        let num_zeros = (16 - (q_len_before_padding % 16)) % 16;
353
354        let mut q = Vec::with_capacity(t + num_zeros + 1 + b);
355        q.extend_from_slice(tweak);
356        q.extend(std::iter::repeat(0u8).take(num_zeros));
357        q.push(i as u8);
358        q.extend_from_slice(&num_a_bytes); // Use NUM_radix(A) bytes
359
360        // ii. Let R = PRF(P || Q).
361        let mut prf_input = Vec::with_capacity(16 + q.len());
362        prf_input.extend_from_slice(&p_array);
363        prf_input.extend_from_slice(&q);
364        let r = prf(key, &prf_input)?;
365
366        // iii. Let S be the first d bytes of R || CIPH_K(R ^ [1]^16) || ...
367        let num_s_blocks_total = (d + 15) / 16;
368        let mut s_bytes = Vec::with_capacity(num_s_blocks_total * 16);
369        s_bytes.extend_from_slice(&r);
370        if num_s_blocks_total > 1 {
371            let r_xor_j_base = r;
372            for j_val in 1..num_s_blocks_total {
373                let j_bytes = (j_val as u32).to_be_bytes();
374                let mut j_block = [0u8; 16];
375                j_block[12..16].copy_from_slice(&j_bytes);
376                let mut r_xor_j = r_xor_j_base;
377                xor_bytes(&mut r_xor_j, &j_block);
378                let s_block = ciph(key, &r_xor_j)?;
379                s_bytes.extend_from_slice(&s_block);
380            }
381        }
382        s_bytes.truncate(d);
383
384        // iv. Let y = NUM(S).
385        let y = num_bytes(&s_bytes);
386
387        // vi. Let c = (NUM_radix(B) - y) mod radix^m.
388        let num_b = num_radix(&b_digits, radix)?;
389        let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
390        let modulus = big_radix.pow(m as u32);
391
392        // (a - b) mod n == (a - (b mod n) + n) mod n
393        let y_mod = y % &modulus;
394        let c = if num_b >= y_mod {
395            (num_b - y_mod) % &modulus // Standard case
396        } else {
397            // num_b < y_mod, so num_b - y_mod is negative
398            // Add modulus before taking the final modulo
399            (num_b + &modulus - y_mod) % &modulus
400        };
401
402
403        // vii. Let C = STR_radix^m(c).
404        let c_digits = str_radix(c, radix, m)?;
405
406        // viii. Let B = A.
407        b_digits = a_digits;
408
409        // ix. Let A = C.
410        a_digits = c_digits;
411    }
412
413    // 7. Return A || B.
414    let mut result_digits = a_digits;
415    result_digits.extend(b_digits);
416
417    Ok(result_digits)
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Key used by the CSDN reference vector and the round-trip tests.
    const KEY: [u8; 16] = [
        0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f,
        0x3c,
    ];

    /// Parses a string of base-`radix` characters into digit values, panicking on bad input.
    fn parse_digits(s: &str, radix: u32) -> Vec<u32> {
        s.chars()
            .map(|c| c.to_digit(radix).ok_or_else(|| FF1Error::InvalidCharDigit(c, radix)))
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    #[test]
    fn round_trip_test() {
        let radix: u32 = 10;
        let tweak = "1329999".as_bytes();
        let x_digits = parse_digits("3216", radix);

        let ciphertext = ff1_encrypt(&KEY, radix, 2, 100, 32, tweak, &x_digits).unwrap();
        let decrypted = ff1_decrypt(&KEY, radix, 2, 100, 32, tweak, &ciphertext).unwrap();
        assert_eq!(
            decrypted, x_digits,
            "Decryption did not recover the original plaintext"
        );
    }

    /// Round trips across odd lengths, several radixes, an empty tweak and a tweak longer than
    /// one 16-byte block.
    #[test]
    fn round_trip_various_parameters() {
        // (radix, minlen, plaintext, tweak)
        let cases: [(u32, usize, &str, &str); 3] = [
            (10, 2, "12345", ""),
            (36, 2, "zx9k0a1", "tweak"),
            (2, 7, "10110011", "a tweak that spans more than one 16-byte block"),
        ];
        for &(radix, minlen, pt, tweak) in cases.iter() {
            let x = parse_digits(pt, radix);
            let y = ff1_encrypt(&KEY, radix, minlen, 100, 64, tweak.as_bytes(), &x).unwrap();
            assert_eq!(y.len(), x.len(), "length not preserved for radix {}", radix);
            assert!(y.iter().all(|&d| d < radix), "digit out of range for radix {}", radix);
            let back = ff1_decrypt(&KEY, radix, minlen, 100, 64, tweak.as_bytes(), &y).unwrap();
            assert_eq!(back, x, "round trip failed for radix {}", radix);
        }
    }

    #[test]
    fn rejects_invalid_parameters() {
        let x = parse_digits("3216", 10);
        assert!(matches!(
            ff1_encrypt(&KEY, 1, 2, 100, 32, b"", &x),
            Err(FF1Error::InvalidRadix { .. })
        ));
        assert!(matches!(
            ff1_encrypt(&KEY, 10, 5, 100, 32, b"", &x),
            Err(FF1Error::InvalidLength { .. })
        ));
        assert!(matches!(
            ff1_encrypt(&KEY, 10, 2, 100, 4, b"12345", &x),
            Err(FF1Error::InvalidTweakLength { .. })
        ));
        // 10^1 < 100 violates the domain-size constraint.
        assert!(matches!(
            ff1_decrypt(&KEY, 10, 1, 100, 32, b"", &x),
            Err(FF1Error::ConstraintViolation { .. })
        ));
        assert!(matches!(
            ff1_encrypt(&KEY, 10, 2, 100, 32, b"", &[3, 2, 1, 10]),
            Err(FF1Error::InvalidDigit(10, 10))
        ));
    }

    #[test]
    fn csdn_case_test() {
        let radix: u32 = 10;
        let tweak = "1329999".as_bytes();
        let x_digits = parse_digits("3216", radix);
        let expected_digits = parse_digits("8956", radix);

        let result_digits = ff1_encrypt(&KEY, radix, 2, 100, 32, tweak, &x_digits).unwrap();
        assert_eq!(
            result_digits, expected_digits,
            "Encryption result does not match expected ciphertext"
        );
    }

    /// NOTE(review): the expected value comes from an external source not shown here, and the
    /// test name suggests it is a known-failing case — confirm the provenance of this vector.
    #[test]
    fn failed_test() {
        let radix: u32 = 10;
        let key = b"6666666600000000";
        let tweak = "4601000000004101LS6A2E0F4NA000030".as_bytes();
        let x_digits = parse_digits("620805", radix);
        let expected_digits = parse_digits("003131", radix);

        let result_digits = ff1_encrypt(key, radix, 2, 100, 50, tweak, &x_digits).unwrap();
        assert_eq!(
            result_digits, expected_digits,
            "Encryption result does not match expected ciphertext"
        );
    }
}