lib

Core libraries for Radroots
git clone https://radroots.dev/git/lib.git
Log | Files | Refs | README | LICENSE

handshake.rs (18295B)


      1 use crate::error::RadrootsSimplexSmpTransportError;
      2 use crate::frame::{
      3     RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE, RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE,
      4     decode_padded_bytes, encode_padded_bytes,
      5 };
      6 use alloc::string::{String, ToString};
      7 use alloc::vec::Vec;
      8 use radroots_simplex_smp_proto::prelude::{
      9     RADROOTS_SIMPLEX_SMP_CURRENT_TRANSPORT_VERSION, RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION,
     10     RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION,
     11     RADROOTS_SIMPLEX_SMP_SERVICE_CERTS_TRANSPORT_VERSION, RadrootsSimplexSmpVersionRange,
     12 };
     13 
/// ALPN protocol identifier for the current SMP handshake.
///
/// When TLS reports this value as the confirmed ALPN, `negotiate_transport_version`
/// picks the highest mutually supported transport version. `validate_tls_handshake`
/// also requires it when the policy sets `require_current_alpn`.
pub const RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1: &str = "smp/1";
/// TLS 1.3 cipher suite name for SMP connections.
// NOTE(review): not referenced in this file; presumably consumed by the TLS
// configuration layer — confirm at the call sites.
pub const RADROOTS_SIMPLEX_SMP_TLS_V1_3_CIPHER_SUITE: &str = "TLS_CHACHA20_POLY1305_SHA256";
/// Certificate signature algorithm name for SMP connections.
// NOTE(review): not referenced in this file — confirm where it is enforced.
pub const RADROOTS_SIMPLEX_SMP_TLS_SIGNATURE_ALGORITHM: &str = "ed25519";
/// TLS key-exchange group name for SMP connections.
// NOTE(review): not referenced in this file — confirm where it is enforced.
pub const RADROOTS_SIMPLEX_SMP_TLS_KEY_EXCHANGE_GROUP: &str = "x25519";
     18 
/// Optional server proof carried in the server hello after the session identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RadrootsSimplexSmpTransportServerProof {
    /// Certificate chain bytes, written verbatim by `RadrootsSimplexSmpServerHello::encode`.
    ///
    /// On decode this holds the one-byte certificate count followed by each
    /// certificate with a two-byte big-endian length prefix.
    pub certificate_payload: Vec<u8>,
    /// Signed server key, encoded with a two-byte big-endian length prefix.
    pub signed_server_key: Vec<u8>,
}
     24 
/// Server hello message sent at the start of the SMP transport handshake.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RadrootsSimplexSmpServerHello {
    /// Transport versions the server offers, encoded as big-endian `min`, `max`.
    pub version_range: RadrootsSimplexSmpVersionRange,
    /// Session identifier with a one-byte length prefix (at most 255 bytes).
    ///
    /// `validate_tls_handshake` compares it with the tls-unique channel binding
    /// when the policy requires that binding.
    pub session_identifier: Vec<u8>,
    /// Optional certificate chain and signed key; `None` when absent or unparsable.
    pub server_proof: Option<RadrootsSimplexSmpTransportServerProof>,
    /// Trailing bytes not understood by this implementation, preserved verbatim.
    pub ignored_part: Vec<u8>,
}
     32 
/// Client hello message sent in reply to the server hello.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RadrootsSimplexSmpClientHello {
    /// Transport version selected by the client, encoded as a big-endian u16.
    pub chosen_version: u16,
    /// Hash of the server key, encoded with a one-byte length prefix.
    pub server_key_hash: Vec<u8>,
    /// Optional client key with a one-byte length prefix.
    ///
    /// No presence tag is written on the wire.
    pub client_key: Option<Vec<u8>>,
    /// Proxy-server flag, encoded as `'T'`/`'F'`.
    ///
    /// Only present on the wire when `chosen_version` is at least the
    /// proxy-server handshake version.
    pub proxy_server: bool,
    /// Trailing bytes not understood by this implementation, preserved verbatim.
    pub ignored_part: Vec<u8>,
}
     41 
/// Rules that `validate_tls_handshake` enforces on a completed TLS handshake.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RadrootsSimplexSmpTlsPolicy {
    /// Value that the online certificate fingerprint must equal (exact string comparison).
    pub expected_server_identity: String,
    /// Transport versions this side is willing to negotiate.
    pub supported_versions: RadrootsSimplexSmpVersionRange,
    /// Reject the handshake unless the confirmed ALPN is `smp/1`.
    pub require_current_alpn: bool,
    /// Whether a resumed TLS session is acceptable.
    pub allow_session_resumption: bool,
    /// Certificate chain lengths that are accepted.
    pub allowed_certificate_chain_lengths: [usize; 3],
    /// Require a tls-unique binding equal to the server hello's session identifier.
    pub require_tls_unique_binding: bool,
    /// Reject a server hello that carries no server proof.
    pub require_server_proof: bool,
}
     52 
     53 impl RadrootsSimplexSmpTlsPolicy {
     54     pub fn modern(expected_server_identity: impl Into<String>) -> Self {
     55         Self {
     56             expected_server_identity: expected_server_identity.into(),
     57             supported_versions: RadrootsSimplexSmpVersionRange::single(
     58                 RADROOTS_SIMPLEX_SMP_CURRENT_TRANSPORT_VERSION,
     59             ),
     60             require_current_alpn: true,
     61             allow_session_resumption: false,
     62             allowed_certificate_chain_lengths: [2, 3, 4],
     63             require_tls_unique_binding: true,
     64             require_server_proof: false,
     65         }
     66     }
     67 }
     68 
/// Facts observed from the TLS layer, checked by `validate_tls_handshake`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RadrootsSimplexSmpTlsHandshakeEvidence {
    /// ALPN protocol confirmed by TLS, if any.
    pub confirmed_alpn: Option<String>,
    /// Whether TLS resumed an earlier session.
    pub session_resumed: bool,
    /// Number of certificates in the presented chain.
    pub certificate_chain_length: usize,
    /// Fingerprint of the online certificate.
    ///
    /// Compared by string equality against `expected_server_identity`.
    pub online_certificate_fingerprint: String,
    /// tls-unique channel binding value, if the TLS layer exposed one.
    pub tls_unique_channel_binding: Option<Vec<u8>>,
}
     77 
     78 impl RadrootsSimplexSmpServerHello {
     79     pub fn encode(&self) -> Result<Vec<u8>, RadrootsSimplexSmpTransportError> {
     80         let mut payload = Vec::new();
     81         payload.extend_from_slice(&self.version_range.min.to_be_bytes());
     82         payload.extend_from_slice(&self.version_range.max.to_be_bytes());
     83         push_short_bytes(&mut payload, &self.session_identifier)?;
     84         if let Some(proof) = &self.server_proof {
     85             payload.extend_from_slice(&proof.certificate_payload);
     86             push_large_bytes(&mut payload, &proof.signed_server_key)?;
     87         }
     88         payload.extend_from_slice(&self.ignored_part);
     89         encode_padded_bytes(
     90             &payload,
     91             RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE,
     92             RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE,
     93         )
     94     }
     95 
     96     pub fn decode(bytes: &[u8]) -> Result<Self, RadrootsSimplexSmpTransportError> {
     97         let payload = decode_padded_bytes(
     98             bytes,
     99             RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE,
    100             RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE,
    101         )?;
    102         let Some(version_bytes) = payload.get(..4) else {
    103             return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
    104                 "smp_version_range",
    105             ));
    106         };
    107         let min = u16::from_be_bytes([version_bytes[0], version_bytes[1]]);
    108         let max = u16::from_be_bytes([version_bytes[2], version_bytes[3]]);
    109         let version_range = RadrootsSimplexSmpVersionRange::new(min, max)
    110             .map_err(RadrootsSimplexSmpTransportError::from)?;
    111         let (session_identifier, cursor) = read_short_bytes(&payload, 4)?;
    112         if session_identifier.len() > u8::MAX as usize {
    113             return Err(
    114                 RadrootsSimplexSmpTransportError::InvalidSessionIdentifierLength(
    115                     session_identifier.len(),
    116                 ),
    117             );
    118         }
    119         let (server_proof, ignored_part) = parse_optional_server_proof(&payload[cursor..]);
    120 
    121         Ok(Self {
    122             version_range,
    123             session_identifier,
    124             server_proof,
    125             ignored_part,
    126         })
    127     }
    128 }
    129 
impl RadrootsSimplexSmpClientHello {
    /// Serializes the client hello into one padded transport block.
    ///
    /// Layout before padding:
    /// 1. `chosen_version` (u16 big-endian);
    /// 2. `server_key_hash` with a one-byte length prefix;
    /// 3. `client_key` with a one-byte length prefix, only if present
    ///    (no presence tag is written);
    /// 4. a `'T'`/`'F'` proxy-server flag when `chosen_version` is at least the
    ///    proxy-server handshake version;
    /// 5. a literal `'0'` client-service tag when `chosen_version` is at least
    ///    the service-certs version;
    /// 6. `ignored_part`, verbatim.
    ///
    /// # Errors
    /// Fails in these cases:
    /// - `server_key_hash` or `client_key` is longer than 255 bytes
    ///   (reported by `push_short_bytes`);
    /// - `encode_padded_bytes` rejects the payload.
    pub fn encode(&self) -> Result<Vec<u8>, RadrootsSimplexSmpTransportError> {
        let mut payload = Vec::new();
        payload.extend_from_slice(&self.chosen_version.to_be_bytes());
        push_short_bytes(&mut payload, &self.server_key_hash)?;
        if let Some(client_key) = &self.client_key {
            push_short_bytes(&mut payload, client_key)?;
        }
        if self.chosen_version >= RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION {
            payload.push(if self.proxy_server { b'T' } else { b'F' });
        }
        if self.chosen_version >= RADROOTS_SIMPLEX_SMP_SERVICE_CERTS_TRANSPORT_VERSION {
            // Always encodes the "no client service" tag.
            payload.push(b'0');
        }
        payload.extend_from_slice(&self.ignored_part);
        encode_padded_bytes(
            &payload,
            RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE,
            RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE,
        )
    }

    /// Parses a padded transport block produced by [`Self::encode`].
    ///
    /// # Errors
    /// Fails in these cases:
    /// - the padding is invalid;
    /// - the version bytes are missing;
    /// - `server_key_hash` is truncated;
    /// - a version-gated field (proxy flag or client-service tag) is missing or
    ///   has an unexpected value. Both of these are reported as
    ///   `MissingHandshakeField`.
    pub fn decode(bytes: &[u8]) -> Result<Self, RadrootsSimplexSmpTransportError> {
        let payload = decode_padded_bytes(
            bytes,
            RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE,
            RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE,
        )?;
        let Some(version_bytes) = payload.get(..2) else {
            return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
                "chosen_version",
            ));
        };
        let chosen_version = u16::from_be_bytes([version_bytes[0], version_bytes[1]]);
        let (server_key_hash, mut cursor) = read_short_bytes(&payload, 2)?;
        // `encode` writes no presence tag for the client key, so its presence is
        // inferred. For versions that carry the proxy flag, a next byte of
        // 'T'/'F' is taken to be that flag, meaning no client key was sent.
        // Otherwise `parse_optional_client_key` reports a key (if any) and how
        // many bytes to advance.
        // NOTE(review): a real client key whose length byte is 0x54 ('T') or
        // 0x46 ('F') would be misread as the proxy flag at those versions.
        let client_key = if chosen_version
            >= RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION
            && matches!(payload.get(cursor), Some(b'T' | b'F'))
        {
            None
        } else {
            let (client_key, consumed) = parse_optional_client_key(&payload[cursor..]);
            cursor += consumed;
            client_key
        };
        // The proxy flag exists only from the proxy-server handshake version
        // onward; older versions default it to `false`.
        let proxy_server =
            if chosen_version >= RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION {
                let Some(value) = payload.get(cursor) else {
                    return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
                        "proxy_server",
                    ));
                };
                cursor += 1;
                match *value {
                    b'T' => true,
                    b'F' => false,
                    _ => {
                        return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
                            "proxy_server",
                        ));
                    }
                }
            } else {
                false
            };
        // From the service-certs version onward, only the '0' ("no client
        // service") tag is accepted.
        if chosen_version >= RADROOTS_SIMPLEX_SMP_SERVICE_CERTS_TRANSPORT_VERSION {
            let Some(tag) = payload.get(cursor) else {
                return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
                    "client_service",
                ));
            };
            cursor += 1;
            if *tag != b'0' {
                return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
                    "client_service",
                ));
            }
        }
        // Everything after the known fields is kept verbatim.
        let ignored_part = payload[cursor..].to_vec();

        Ok(Self {
            chosen_version,
            server_key_hash,
            client_key,
            proxy_server,
            ignored_part,
        })
    }
}
    219 
    220 pub fn negotiate_transport_version(
    221     offered: RadrootsSimplexSmpVersionRange,
    222     supported: RadrootsSimplexSmpVersionRange,
    223     confirmed_alpn: Option<&str>,
    224 ) -> Result<u16, RadrootsSimplexSmpTransportError> {
    225     if confirmed_alpn == Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1) {
    226         let min = offered.min.max(supported.min);
    227         let max = offered.max.min(supported.max);
    228         if min > max {
    229             return Err(RadrootsSimplexSmpTransportError::NoMutualTransportVersion {
    230                 offered: offered.to_string(),
    231                 supported: supported.to_string(),
    232             });
    233         }
    234         return Ok(max);
    235     }
    236 
    237     if offered.contains(RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION)
    238         && supported.contains(RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION)
    239     {
    240         return Ok(RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION);
    241     }
    242 
    243     Err(RadrootsSimplexSmpTransportError::NoMutualTransportVersion {
    244         offered: offered.to_string(),
    245         supported: supported.to_string(),
    246     })
    247 }
    248 
    249 pub fn validate_tls_handshake(
    250     policy: &RadrootsSimplexSmpTlsPolicy,
    251     server_hello: &RadrootsSimplexSmpServerHello,
    252     evidence: &RadrootsSimplexSmpTlsHandshakeEvidence,
    253 ) -> Result<u16, RadrootsSimplexSmpTransportError> {
    254     if policy.require_current_alpn
    255         && evidence.confirmed_alpn.as_deref() != Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1)
    256     {
    257         return Err(RadrootsSimplexSmpTransportError::UnsupportedAlpn(
    258             evidence.confirmed_alpn.clone().unwrap_or_default(),
    259         ));
    260     }
    261     if !policy.allow_session_resumption && evidence.session_resumed {
    262         return Err(RadrootsSimplexSmpTransportError::SessionResumptionNotAllowed);
    263     }
    264     if !policy
    265         .allowed_certificate_chain_lengths
    266         .contains(&evidence.certificate_chain_length)
    267     {
    268         return Err(
    269             RadrootsSimplexSmpTransportError::InvalidCertificateChainLength(
    270                 evidence.certificate_chain_length,
    271             ),
    272         );
    273     }
    274     if evidence.online_certificate_fingerprint != policy.expected_server_identity {
    275         return Err(RadrootsSimplexSmpTransportError::ServerIdentityMismatch {
    276             expected: policy.expected_server_identity.clone(),
    277             actual: evidence.online_certificate_fingerprint.clone(),
    278         });
    279     }
    280     if policy.require_server_proof && server_hello.server_proof.is_none() {
    281         return Err(RadrootsSimplexSmpTransportError::MissingServerProof);
    282     }
    283     if policy.require_tls_unique_binding {
    284         let Some(binding) = evidence.tls_unique_channel_binding.as_ref() else {
    285             return Err(RadrootsSimplexSmpTransportError::MissingChannelBinding);
    286         };
    287         if binding.as_slice() != server_hello.session_identifier.as_slice() {
    288             return Err(RadrootsSimplexSmpTransportError::SessionBindingMismatch);
    289         }
    290     }
    291 
    292     negotiate_transport_version(
    293         server_hello.version_range,
    294         policy.supported_versions,
    295         evidence.confirmed_alpn.as_deref(),
    296     )
    297 }
    298 
    299 fn push_short_bytes(
    300     buffer: &mut Vec<u8>,
    301     bytes: &[u8],
    302 ) -> Result<(), RadrootsSimplexSmpTransportError> {
    303     if bytes.len() > u8::MAX as usize {
    304         return Err(RadrootsSimplexSmpTransportError::InvalidSessionIdentifierLength(bytes.len()));
    305     }
    306     buffer.push(bytes.len() as u8);
    307     buffer.extend_from_slice(bytes);
    308     Ok(())
    309 }
    310 
    311 fn push_large_bytes(
    312     buffer: &mut Vec<u8>,
    313     bytes: &[u8],
    314 ) -> Result<(), RadrootsSimplexSmpTransportError> {
    315     let len = u16::try_from(bytes.len()).map_err(|_| {
    316         RadrootsSimplexSmpTransportError::InvalidSessionIdentifierLength(bytes.len())
    317     })?;
    318     buffer.extend_from_slice(&len.to_be_bytes());
    319     buffer.extend_from_slice(bytes);
    320     Ok(())
    321 }
    322 
    323 fn read_short_bytes(
    324     payload: &[u8],
    325     offset: usize,
    326 ) -> Result<(Vec<u8>, usize), RadrootsSimplexSmpTransportError> {
    327     let Some(&length) = payload.get(offset) else {
    328         return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
    329             "short_field",
    330         ));
    331     };
    332     let start = offset + 1;
    333     let end = start + length as usize;
    334     let Some(value) = payload.get(start..end) else {
    335         return Err(
    336             radroots_simplex_smp_proto::prelude::RadrootsSimplexSmpProtoError::UnexpectedEof.into(),
    337         );
    338     };
    339     Ok((value.to_vec(), end))
    340 }
    341 
    342 fn read_large_bytes(
    343     payload: &[u8],
    344     offset: usize,
    345 ) -> Result<(Vec<u8>, usize), RadrootsSimplexSmpTransportError> {
    346     let Some(length_bytes) = payload.get(offset..offset + 2) else {
    347         return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField(
    348             "large_field",
    349         ));
    350     };
    351     let length = u16::from_be_bytes([length_bytes[0], length_bytes[1]]) as usize;
    352     let start = offset + 2;
    353     let end = start + length;
    354     let Some(value) = payload.get(start..end) else {
    355         return Err(
    356             radroots_simplex_smp_proto::prelude::RadrootsSimplexSmpProtoError::UnexpectedEof.into(),
    357         );
    358     };
    359     Ok((value.to_vec(), end))
    360 }
    361 
    362 fn parse_optional_server_proof(
    363     remainder: &[u8],
    364 ) -> (Option<RadrootsSimplexSmpTransportServerProof>, Vec<u8>) {
    365     let Some(&cert_count) = remainder.first() else {
    366         return (None, remainder.to_vec());
    367     };
    368     if cert_count == 0 {
    369         return (None, remainder.to_vec());
    370     }
    371     let mut cursor = 1;
    372     for _ in 0..cert_count {
    373         let Ok((_, next_cursor)) = read_large_bytes(remainder, cursor) else {
    374             return (None, remainder.to_vec());
    375         };
    376         cursor = next_cursor;
    377     }
    378     let Ok((signed_server_key, cursor)) = read_large_bytes(remainder, cursor) else {
    379         return (None, remainder.to_vec());
    380     };
    381     (
    382         Some(RadrootsSimplexSmpTransportServerProof {
    383             certificate_payload: remainder[..cursor - signed_server_key.len() - 2].to_vec(),
    384             signed_server_key,
    385         }),
    386         remainder[cursor..].to_vec(),
    387     )
    388 }
    389 
    390 fn parse_optional_client_key(remainder: &[u8]) -> (Option<Vec<u8>>, usize) {
    391     let Some(&length) = remainder.first() else {
    392         return (None, 0);
    393     };
    394     let end = 1 + length as usize;
    395     if length == 0 || end > remainder.len() {
    396         return (None, remainder.len());
    397     }
    398     (Some(remainder[1..end].to_vec()), end)
    399 }
    400 
    401 #[cfg(test)]
    402 mod tests {
    403     use super::*;
    404 
    405     #[test]
    406     fn roundtrips_server_hello_and_validates_binding() {
    407         let hello = RadrootsSimplexSmpServerHello {
    408             version_range: RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(),
    409             session_identifier: b"tls-unique-binding".to_vec(),
    410             server_proof: Some(RadrootsSimplexSmpTransportServerProof {
    411                 certificate_payload: encode_certificate_chain_payload([b"cert-chain".as_slice()]),
    412                 signed_server_key: b"signed-key".to_vec(),
    413             }),
    414             ignored_part: b"ignored".to_vec(),
    415         };
    416 
    417         let decoded = RadrootsSimplexSmpServerHello::decode(&hello.encode().unwrap()).unwrap();
    418         assert_eq!(decoded, hello);
    419 
    420         let policy = RadrootsSimplexSmpTlsPolicy {
    421             expected_server_identity: "fingerprint".to_string(),
    422             supported_versions: RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(),
    423             require_current_alpn: false,
    424             allow_session_resumption: false,
    425             allowed_certificate_chain_lengths: [2, 3, 4],
    426             require_tls_unique_binding: true,
    427             require_server_proof: true,
    428         };
    429         let version = validate_tls_handshake(
    430             &policy,
    431             &decoded,
    432             &RadrootsSimplexSmpTlsHandshakeEvidence {
    433                 confirmed_alpn: Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1.to_string()),
    434                 session_resumed: false,
    435                 certificate_chain_length: 3,
    436                 online_certificate_fingerprint: "fingerprint".to_string(),
    437                 tls_unique_channel_binding: Some(b"tls-unique-binding".to_vec()),
    438             },
    439         )
    440         .unwrap();
    441         assert_eq!(version, 17);
    442     }
    443 
    444     #[test]
    445     fn falls_back_to_initial_transport_version_without_current_alpn() {
    446         let version = negotiate_transport_version(
    447             RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(),
    448             RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(),
    449             None,
    450         )
    451         .unwrap();
    452         assert_eq!(version, 6);
    453     }
    454 
    455     #[test]
    456     fn rejects_mismatched_server_identity() {
    457         let hello = RadrootsSimplexSmpServerHello {
    458             version_range: RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(),
    459             session_identifier: b"bind".to_vec(),
    460             server_proof: None,
    461             ignored_part: Vec::new(),
    462         };
    463         let policy = RadrootsSimplexSmpTlsPolicy::modern("expected");
    464         let error = validate_tls_handshake(
    465             &policy,
    466             &hello,
    467             &RadrootsSimplexSmpTlsHandshakeEvidence {
    468                 confirmed_alpn: Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1.to_string()),
    469                 session_resumed: false,
    470                 certificate_chain_length: 2,
    471                 online_certificate_fingerprint: "actual".to_string(),
    472                 tls_unique_channel_binding: Some(b"bind".to_vec()),
    473             },
    474         )
    475         .unwrap_err();
    476         assert!(matches!(
    477             error,
    478             RadrootsSimplexSmpTransportError::ServerIdentityMismatch { .. }
    479         ));
    480     }
    481 
    482     fn encode_certificate_chain_payload<'a, I>(certificates: I) -> Vec<u8>
    483     where
    484         I: IntoIterator<Item = &'a [u8]>,
    485     {
    486         let certificates: Vec<&[u8]> = certificates.into_iter().collect();
    487         let mut payload = vec![certificates.len() as u8];
    488         for certificate in certificates {
    489             payload.extend_from_slice(&(certificate.len() as u16).to_be_bytes());
    490             payload.extend_from_slice(certificate);
    491         }
    492         payload
    493     }
    494 }