1pub mod ff1error;
2
3use num_bigint::{BigUint, ToBigUint};
4use num_traits::{Zero, Pow, ToPrimitive};
5use std::convert::TryInto;
6use ff1error::FF1Error;
7use num_integer::Integer;
8use neuedu_cryptos::block_ciphers::sm4_cbc_encrypt;
9
/// XORs every byte of `b` into the corresponding byte of `a`, in place.
///
/// # Panics
/// Panics if the two slices do not have the same length.
fn xor_bytes(a: &mut [u8], b: &[u8]) {
    assert_eq!(a.len(), b.len(), "XOR slices must have equal length");
    a.iter_mut().zip(b).for_each(|(dst, &src)| *dst ^= src);
}
16
17fn ciph(key: &[u8; 16], input: &[u8; 16]) -> Result<[u8; 16], FF1Error> {
18 let iv = [0u8; 16];
19 let ciphertext = sm4_cbc_encrypt(*key, None, iv, input).unwrap();
20
21 if ciphertext.len() >= 16 {
22 Ok(ciphertext[0..16].try_into().unwrap())
23 } else {
24 Err(FF1Error::CipherLengthError)
25 }
26}
27
28fn prf(key: &[u8; 16], input: &[u8]) -> Result<[u8; 16], FF1Error> {
29 let block_size = 16;
30 let num_blocks = (input.len() + block_size - 1) / block_size;
31 let padded_len = num_blocks * block_size;
32 let mut padded_input = Vec::with_capacity(padded_len);
33 padded_input.extend_from_slice(input);
34 padded_input.resize(padded_len, 0u8);
35
36 let mut mac = [0u8; 16];
37
38 for block in padded_input.chunks_exact(block_size) {
39 let mut block_array: [u8; 16] = block.try_into().expect("Chunk size is 16");
40 xor_bytes(&mut block_array, &mac); mac = ciph(key, &block_array)?; }
43
44 Ok(mac)
45}
46
47fn num_radix(s: &[u32], radix: u32) -> Result<BigUint, FF1Error> {
48 let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
49 let mut num = BigUint::zero();
50 for &digit in s {
51 if digit >= radix {
52 return Err(FF1Error::InvalidDigit(digit, radix));
53 }
54 num = num * &big_radix + digit.to_biguint().ok_or(FF1Error::BigUintConversion)?;
55 }
56 Ok(num)
57}
58
59fn str_radix(mut c: BigUint, radix: u32, len: usize) -> Result<Vec<u32>, FF1Error> {
60 if c.is_zero() {
61 return Ok(vec![0u32; len]);
62 }
63
64 let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
65 let mut s = Vec::with_capacity(len);
66
67 while !c.is_zero() {
68 let (quotient, remainder) = c.div_rem(&big_radix);
69 s.push(remainder.to_u32().ok_or(FF1Error::BigUintConversion)?);
70 c = quotient;
71 }
72
73 if s.len() > len {
74 return Err(FF1Error::StrLenMismatch);
76 }
77
78 s.resize(len, 0u32);
80 s.reverse(); Ok(s)
82}
83
84fn num_bytes(bytes: &[u8]) -> BigUint {
85 BigUint::from_bytes_be(bytes)
86}
87
88fn bytes_radix(n: &BigUint, len: usize) -> Result<Vec<u8>, FF1Error> {
89 let bytes = n.to_bytes_be();
90 if bytes.len() > len {
91 Err(FF1Error::NumToBytesConversion)
92 } else {
93 let padding_len = len - bytes.len();
94 let mut result = Vec::with_capacity(len);
95 result.extend(std::iter::repeat(0u8).take(padding_len));
96 result.extend(bytes);
97 Ok(result)
98 }
99}
100
101
102pub fn ff1_encrypt(
119 key: &[u8; 16],
120 radix: u32,
121 minlen: usize,
122 maxlen: usize,
123 max_tlen: usize,
124 tweak: &[u8],
125 x_digits: &[u32],
126) -> Result<Vec<u32>, FF1Error> {
127 let n = x_digits.len();
128 let t = tweak.len();
129
130 if !(2..=65536).contains(&radix) { return Err(FF1Error::InvalidRadix { radix });
133 }
134 if !(minlen..=maxlen).contains(&n) {
135 return Err(FF1Error::InvalidLength { n });
136 }
137 if t > max_tlen {
138 return Err(FF1Error::InvalidTweakLength { t });
139 }
140
141 let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
143 let min_radix_pow = big_radix.pow(minlen);
144 if min_radix_pow < 100u32.to_biguint().unwrap() {
145 return Err(FF1Error::ConstraintViolation { radix, minlen });
146 }
147
148 for &digit in x_digits {
149 if digit >= radix {
150 return Err(FF1Error::InvalidDigit(digit, radix));
151 }
152 }
153
154 let u = n / 2;
156 let v = n - u;
157
158 let mut a_digits: Vec<u32> = x_digits[0..u].to_vec();
160 let mut b_digits: Vec<u32> = x_digits[u..n].to_vec();
161
162 let v_log_radix = v as f64 * (radix as f64).log2();
164 let ceil_v_log_radix = v_log_radix.ceil() as usize;
165 let b = (ceil_v_log_radix + 7) / 8; let ceil_b_div_4 = (b + 3) / 4; let d = 4 * ceil_b_div_4 + 4;
170
171 let mut p = Vec::with_capacity(16);
174 p.push(1); p.push(2); p.push(1); let radix_bytes = radix.to_be_bytes(); p.extend_from_slice(&radix_bytes[1..4]); p.push(10); p.push((u % 256) as u8); p.extend_from_slice(&(n as u32).to_be_bytes()); p.extend_from_slice(&(t as u32).to_be_bytes()); let p_array: [u8; 16] = p.try_into().expect("P should be 16 bytes");
185
186
187 for i in 0..10 {
189 let num_b = num_radix(&b_digits, radix)?;
191 let num_b_bytes = bytes_radix(&num_b, b)?;
192
193 let q_len_before_padding = t + 1 + b; let num_zeros = (16 - (q_len_before_padding % 16)) % 16; let mut q = Vec::with_capacity(t + num_zeros + 1 + b);
197 q.extend_from_slice(tweak); q.extend(std::iter::repeat(0u8).take(num_zeros)); q.push(i as u8); q.extend_from_slice(&num_b_bytes); let mut prf_input = Vec::with_capacity(16 + q.len());
204 prf_input.extend_from_slice(&p_array);
205 prf_input.extend_from_slice(&q);
206 let r = prf(key, &prf_input)?; let num_s_blocks = (d + 15) / 16; let mut s_bytes = Vec::with_capacity(num_s_blocks * 16);
211
212 s_bytes.extend_from_slice(&r);
213
214 if num_s_blocks > 1 {
215 let mut r_xor_j = r;
216 for j in 1..num_s_blocks { let j_bytes = (j as u32).to_be_bytes(); let mut j_block = [0u8; 16];
222 j_block[12..16].copy_from_slice(&j_bytes); xor_bytes(&mut r_xor_j, &j_block); let s_block = ciph(key, &r_xor_j)?; s_bytes.extend_from_slice(&s_block);
228
229 r_xor_j = r; }
234 }
235 s_bytes.truncate(d);
237
238 let y = num_bytes(&s_bytes);
240
241 let m = if i % 2 == 0 { u } else { v };
243
244 let num_a = num_radix(&a_digits, radix)?;
246 let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
247 let modulus = big_radix.pow(m);
248 let c = (num_a + y) % modulus;
249
250 let c_digits = str_radix(c, radix, m)?;
252
253 a_digits = b_digits; b_digits = c_digits; }
259
260 let mut result_digits = a_digits;
262 result_digits.extend(b_digits);
263
264 Ok(result_digits)
265}
266
267pub fn ff1_decrypt(
282 key: &[u8; 16],
283 radix: u32,
284 minlen: usize,
285 maxlen: usize,
286 max_tlen: usize,
287 tweak: &[u8],
288 x_digits: &[u32],
289) -> Result<Vec<u32>, FF1Error> {
290 let n = x_digits.len();
291 let t = tweak.len();
292
293 if !(2..=65536).contains(&radix) {
295 return Err(FF1Error::InvalidRadix { radix });
296 }
297 if !(minlen..=maxlen).contains(&n) {
298 return Err(FF1Error::InvalidLength { n });
299 }
300 if t > max_tlen {
301 return Err(FF1Error::InvalidTweakLength { t });
302 }
303 let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
304 let min_radix_pow = big_radix.pow(minlen as u32);
305 if min_radix_pow < 100u32.to_biguint().unwrap() {
306 return Err(FF1Error::ConstraintViolation { radix, minlen });
307 }
308 for &digit in x_digits {
309 if digit >= radix {
310 return Err(FF1Error::InvalidDigit(digit, radix));
311 }
312 }
313
314 let u = n / 2;
316 let v = n - u;
317
318 let mut a_digits: Vec<u32> = x_digits[0..u].to_vec();
320 let mut b_digits: Vec<u32> = x_digits[u..n].to_vec();
321
322 let v_log_radix = v as f64 * (radix as f64).log2();
324 let ceil_v_log_radix = v_log_radix.ceil();
325 let b = if ceil_v_log_radix <= 0.0 { 0 } else { (ceil_v_log_radix as usize + 7) / 8 };
326
327 let ceil_b_div_4 = (b + 3) / 4;
329 let d = 4 * ceil_b_div_4 + 4;
330
331 let mut p = Vec::with_capacity(16);
333 p.push(1); p.push(2); p.push(1);
334 let radix_bytes = radix.to_be_bytes();
335 p.extend_from_slice(&radix_bytes[1..4]);
336 p.push(10);
337 p.push((u % 256) as u8);
338 p.extend_from_slice(&(n as u32).to_be_bytes());
339 p.extend_from_slice(&(t as u32).to_be_bytes());
340 let p_array: [u8; 16] = p.try_into().expect("P should be 16 bytes");
341
342 for i in (0..10).rev() {
344 let m = if i % 2 == 0 { u } else { v };
346
347 let num_a = num_radix(&a_digits, radix)?; let num_a_bytes = bytes_radix(&num_a, b)?;
350
351 let q_len_before_padding = t + 1 + b;
352 let num_zeros = (16 - (q_len_before_padding % 16)) % 16;
353
354 let mut q = Vec::with_capacity(t + num_zeros + 1 + b);
355 q.extend_from_slice(tweak);
356 q.extend(std::iter::repeat(0u8).take(num_zeros));
357 q.push(i as u8);
358 q.extend_from_slice(&num_a_bytes); let mut prf_input = Vec::with_capacity(16 + q.len());
362 prf_input.extend_from_slice(&p_array);
363 prf_input.extend_from_slice(&q);
364 let r = prf(key, &prf_input)?;
365
366 let num_s_blocks_total = (d + 15) / 16;
368 let mut s_bytes = Vec::with_capacity(num_s_blocks_total * 16);
369 s_bytes.extend_from_slice(&r);
370 if num_s_blocks_total > 1 {
371 let r_xor_j_base = r;
372 for j_val in 1..num_s_blocks_total {
373 let j_bytes = (j_val as u32).to_be_bytes();
374 let mut j_block = [0u8; 16];
375 j_block[12..16].copy_from_slice(&j_bytes);
376 let mut r_xor_j = r_xor_j_base;
377 xor_bytes(&mut r_xor_j, &j_block);
378 let s_block = ciph(key, &r_xor_j)?;
379 s_bytes.extend_from_slice(&s_block);
380 }
381 }
382 s_bytes.truncate(d);
383
384 let y = num_bytes(&s_bytes);
386
387 let num_b = num_radix(&b_digits, radix)?;
389 let big_radix = radix.to_biguint().ok_or(FF1Error::BigUintConversion)?;
390 let modulus = big_radix.pow(m as u32);
391
392 let y_mod = y % &modulus;
394 let c = if num_b >= y_mod {
395 (num_b - y_mod) % &modulus } else {
397 (num_b + &modulus - y_mod) % &modulus
400 };
401
402
403 let c_digits = str_radix(c, radix, m)?;
405
406 b_digits = a_digits;
408
409 a_digits = c_digits;
411 }
412
413 let mut result_digits = a_digits;
415 result_digits.extend(b_digits);
416
417 Ok(result_digits)
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by the original round-trip and CSDN vectors.
    const KEY: [u8; 16] = [
        0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
        0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c,
    ];
    const MINLEN: usize = 2;
    const MAXLEN: usize = 100;

    /// Parses `s` into digits of `radix`, panicking on any character that is not a digit.
    fn parse_digits(s: &str, radix: u32) -> Vec<u32> {
        s.chars()
            .map(|c| c.to_digit(radix).ok_or_else(|| FF1Error::InvalidCharDigit(c, radix)))
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    /// Encrypts then decrypts the decimal string `pt` and asserts the plaintext comes back.
    fn assert_round_trip(key: &[u8; 16], tweak: &[u8], max_tlen: usize, pt: &str) {
        let radix: u32 = 10;
        let x_digits = parse_digits(pt, radix);
        let ct = ff1_encrypt(key, radix, MINLEN, MAXLEN, max_tlen, tweak, &x_digits).unwrap();
        assert_eq!(ct.len(), x_digits.len(), "FF1 must preserve the numeral length");
        let recovered = ff1_decrypt(key, radix, MINLEN, MAXLEN, max_tlen, tweak, &ct).unwrap();
        assert_eq!(recovered, x_digits, "Decryption did not recover the original plaintext");
    }

    #[test]
    fn round_trip_test() {
        assert_round_trip(&KEY, b"1329999", 32, "3216");
    }

    #[test]
    fn round_trip_odd_length_long_tweak() {
        // Odd n makes u != v, so both round moduli are exercised; the 33-byte tweak
        // produces a multi-block Q with non-zero padding.
        assert_round_trip(b"6666666600000000", b"4601000000004101LS6A2E0F4NA000030", 50, "12345");
    }

    #[test]
    fn csdn_case_test() {
        let radix: u32 = 10;
        let x_digits = parse_digits("3216", radix);
        let expected_digits = parse_digits("8956", radix);

        let result_digits =
            ff1_encrypt(&KEY, radix, MINLEN, MAXLEN, 32, b"1329999", &x_digits).unwrap();
        assert_eq!(result_digits, expected_digits, "Encryption result does not match expected ciphertext");
    }

    #[test]
    fn failed_test() {
        let radix: u32 = 10;
        let key = b"6666666600000000";
        let tweak = b"4601000000004101LS6A2E0F4NA000030";
        let x_digits = parse_digits("620805", radix);
        let expected_digits = parse_digits("003131", radix);

        let result_digits =
            ff1_encrypt(key, radix, MINLEN, MAXLEN, 50, tweak, &x_digits).unwrap();
        assert_eq!(result_digits, expected_digits, "Encryption result does not match expected ciphertext");
    }

    #[test]
    fn rejects_out_of_range_digit() {
        let result = ff1_encrypt(&KEY, 10, MINLEN, MAXLEN, 32, b"", &[1, 10, 3, 4]);
        assert!(matches!(result, Err(FF1Error::InvalidDigit(10, 10))));
    }

    #[test]
    fn rejects_invalid_radix() {
        let result = ff1_encrypt(&KEY, 1, MINLEN, MAXLEN, 32, b"", &[0, 0]);
        assert!(matches!(result, Err(FF1Error::InvalidRadix { radix: 1 })));
    }
}