handshake.rs (18295B)
1 use crate::error::RadrootsSimplexSmpTransportError; 2 use crate::frame::{ 3 RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE, RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE, 4 decode_padded_bytes, encode_padded_bytes, 5 }; 6 use alloc::string::{String, ToString}; 7 use alloc::vec::Vec; 8 use radroots_simplex_smp_proto::prelude::{ 9 RADROOTS_SIMPLEX_SMP_CURRENT_TRANSPORT_VERSION, RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION, 10 RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION, 11 RADROOTS_SIMPLEX_SMP_SERVICE_CERTS_TRANSPORT_VERSION, RadrootsSimplexSmpVersionRange, 12 }; 13 14 pub const RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1: &str = "smp/1"; 15 pub const RADROOTS_SIMPLEX_SMP_TLS_V1_3_CIPHER_SUITE: &str = "TLS_CHACHA20_POLY1305_SHA256"; 16 pub const RADROOTS_SIMPLEX_SMP_TLS_SIGNATURE_ALGORITHM: &str = "ed25519"; 17 pub const RADROOTS_SIMPLEX_SMP_TLS_KEY_EXCHANGE_GROUP: &str = "x25519"; 18 19 #[derive(Debug, Clone, PartialEq, Eq)] 20 pub struct RadrootsSimplexSmpTransportServerProof { 21 pub certificate_payload: Vec<u8>, 22 pub signed_server_key: Vec<u8>, 23 } 24 25 #[derive(Debug, Clone, PartialEq, Eq)] 26 pub struct RadrootsSimplexSmpServerHello { 27 pub version_range: RadrootsSimplexSmpVersionRange, 28 pub session_identifier: Vec<u8>, 29 pub server_proof: Option<RadrootsSimplexSmpTransportServerProof>, 30 pub ignored_part: Vec<u8>, 31 } 32 33 #[derive(Debug, Clone, PartialEq, Eq)] 34 pub struct RadrootsSimplexSmpClientHello { 35 pub chosen_version: u16, 36 pub server_key_hash: Vec<u8>, 37 pub client_key: Option<Vec<u8>>, 38 pub proxy_server: bool, 39 pub ignored_part: Vec<u8>, 40 } 41 42 #[derive(Debug, Clone, PartialEq, Eq)] 43 pub struct RadrootsSimplexSmpTlsPolicy { 44 pub expected_server_identity: String, 45 pub supported_versions: RadrootsSimplexSmpVersionRange, 46 pub require_current_alpn: bool, 47 pub allow_session_resumption: bool, 48 pub allowed_certificate_chain_lengths: [usize; 3], 49 pub require_tls_unique_binding: bool, 50 pub require_server_proof: bool, 51 } 52 53 impl RadrootsSimplexSmpTlsPolicy { 54 pub fn modern(expected_server_identity: impl Into<String>) -> Self { 55 Self { 56 expected_server_identity: expected_server_identity.into(), 57 supported_versions: RadrootsSimplexSmpVersionRange::single( 58 RADROOTS_SIMPLEX_SMP_CURRENT_TRANSPORT_VERSION, 59 ), 60 require_current_alpn: true, 61 allow_session_resumption: false, 62 allowed_certificate_chain_lengths: [2, 3, 4], 63 require_tls_unique_binding: true, 64 require_server_proof: false, 65 } 66 } 67 } 68 69 #[derive(Debug, Clone, PartialEq, Eq)] 70 pub struct RadrootsSimplexSmpTlsHandshakeEvidence { 71 pub confirmed_alpn: Option<String>, 72 pub session_resumed: bool, 73 pub certificate_chain_length: usize, 74 pub online_certificate_fingerprint: String, 75 pub tls_unique_channel_binding: Option<Vec<u8>>, 76 } 77 78 impl RadrootsSimplexSmpServerHello { 79 pub fn encode(&self) -> Result<Vec<u8>, RadrootsSimplexSmpTransportError> { 80 let mut payload = Vec::new(); 81 payload.extend_from_slice(&self.version_range.min.to_be_bytes()); 82 payload.extend_from_slice(&self.version_range.max.to_be_bytes()); 83 push_short_bytes(&mut payload, &self.session_identifier)?; 84 if let Some(proof) = &self.server_proof { 85 payload.extend_from_slice(&proof.certificate_payload); 86 push_large_bytes(&mut payload, &proof.signed_server_key)?; 87 } 88 payload.extend_from_slice(&self.ignored_part); 89 encode_padded_bytes( 90 &payload, 91 RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE, 92 RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE, 93 ) 94 } 95 96 pub fn decode(bytes: &[u8]) -> Result<Self, RadrootsSimplexSmpTransportError> { 97 let payload = decode_padded_bytes( 98 bytes, 99 RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE, 100 RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE, 101 )?; 102 let Some(version_bytes) = payload.get(..4) else { 103 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 104 "smp_version_range", 105 )); 106 }; 107 let min = u16::from_be_bytes([version_bytes[0], version_bytes[1]]); 108 let max = u16::from_be_bytes([version_bytes[2], version_bytes[3]]); 109 let version_range = RadrootsSimplexSmpVersionRange::new(min, max) 110 .map_err(RadrootsSimplexSmpTransportError::from)?; 111 let (session_identifier, cursor) = read_short_bytes(&payload, 4)?; 112 if session_identifier.len() > u8::MAX as usize { 113 return Err( 114 RadrootsSimplexSmpTransportError::InvalidSessionIdentifierLength( 115 session_identifier.len(), 116 ), 117 ); 118 } 119 let (server_proof, ignored_part) = parse_optional_server_proof(&payload[cursor..]); 120 121 Ok(Self { 122 version_range, 123 session_identifier, 124 server_proof, 125 ignored_part, 126 }) 127 } 128 } 129 130 impl RadrootsSimplexSmpClientHello { 131 pub fn encode(&self) -> Result<Vec<u8>, RadrootsSimplexSmpTransportError> { 132 let mut payload = Vec::new(); 133 payload.extend_from_slice(&self.chosen_version.to_be_bytes()); 134 push_short_bytes(&mut payload, &self.server_key_hash)?; 135 if let Some(client_key) = &self.client_key { 136 push_short_bytes(&mut payload, client_key)?; 137 } 138 if self.chosen_version >= RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION { 139 payload.push(if self.proxy_server { b'T' } else { b'F' }); 140 } 141 if self.chosen_version >= RADROOTS_SIMPLEX_SMP_SERVICE_CERTS_TRANSPORT_VERSION { 142 payload.push(b'0'); 143 } 144 payload.extend_from_slice(&self.ignored_part); 145 encode_padded_bytes( 146 &payload, 147 RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE, 148 RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE, 149 ) 150 } 151 152 pub fn decode(bytes: &[u8]) -> Result<Self, RadrootsSimplexSmpTransportError> { 153 let payload = decode_padded_bytes( 154 bytes, 155 RADROOTS_SIMPLEX_SMP_TRANSPORT_BLOCK_SIZE, 156 RADROOTS_SIMPLEX_SMP_TRANSPORT_PAD_BYTE, 157 )?; 158 let Some(version_bytes) = payload.get(..2) else { 159 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 160 "chosen_version", 161 )); 162 }; 163 let chosen_version = u16::from_be_bytes([version_bytes[0], version_bytes[1]]); 164 let (server_key_hash, mut cursor) = read_short_bytes(&payload, 2)?; 165 let client_key = if chosen_version 166 >= RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION 167 && matches!(payload.get(cursor), Some(b'T' | b'F')) 168 { 169 None 170 } else { 171 let (client_key, consumed) = parse_optional_client_key(&payload[cursor..]); 172 cursor += consumed; 173 client_key 174 }; 175 let proxy_server = 176 if chosen_version >= RADROOTS_SIMPLEX_SMP_PROXY_SERVER_HANDSHAKE_TRANSPORT_VERSION { 177 let Some(value) = payload.get(cursor) else { 178 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 179 "proxy_server", 180 )); 181 }; 182 cursor += 1; 183 match *value { 184 b'T' => true, 185 b'F' => false, 186 _ => { 187 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 188 "proxy_server", 189 )); 190 } 191 } 192 } else { 193 false 194 }; 195 if chosen_version >= RADROOTS_SIMPLEX_SMP_SERVICE_CERTS_TRANSPORT_VERSION { 196 let Some(tag) = payload.get(cursor) else { 197 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 198 "client_service", 199 )); 200 }; 201 cursor += 1; 202 if *tag != b'0' { 203 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 204 "client_service", 205 )); 206 } 207 } 208 let ignored_part = payload[cursor..].to_vec(); 209 210 Ok(Self { 211 chosen_version, 212 server_key_hash, 213 client_key, 214 proxy_server, 215 ignored_part, 216 }) 217 } 218 } 219 220 pub fn negotiate_transport_version( 221 offered: RadrootsSimplexSmpVersionRange, 222 supported: RadrootsSimplexSmpVersionRange, 223 confirmed_alpn: Option<&str>, 224 ) -> Result<u16, RadrootsSimplexSmpTransportError> { 225 if confirmed_alpn == Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1) { 226 let min = offered.min.max(supported.min); 227 let max = offered.max.min(supported.max); 228 if min > max { 229 return Err(RadrootsSimplexSmpTransportError::NoMutualTransportVersion { 230 offered: offered.to_string(), 231 supported: supported.to_string(), 232 }); 233 } 234 return Ok(max); 235 } 236 237 if offered.contains(RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION) 238 && supported.contains(RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION) 239 { 240 return Ok(RADROOTS_SIMPLEX_SMP_INITIAL_TRANSPORT_VERSION); 241 } 242 243 Err(RadrootsSimplexSmpTransportError::NoMutualTransportVersion { 244 offered: offered.to_string(), 245 supported: supported.to_string(), 246 }) 247 } 248 249 pub fn validate_tls_handshake( 250 policy: &RadrootsSimplexSmpTlsPolicy, 251 server_hello: &RadrootsSimplexSmpServerHello, 252 evidence: &RadrootsSimplexSmpTlsHandshakeEvidence, 253 ) -> Result<u16, RadrootsSimplexSmpTransportError> { 254 if policy.require_current_alpn 255 && evidence.confirmed_alpn.as_deref() != Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1) 256 { 257 return Err(RadrootsSimplexSmpTransportError::UnsupportedAlpn( 258 evidence.confirmed_alpn.clone().unwrap_or_default(), 259 )); 260 } 261 if !policy.allow_session_resumption && evidence.session_resumed { 262 return Err(RadrootsSimplexSmpTransportError::SessionResumptionNotAllowed); 263 } 264 if !policy 265 .allowed_certificate_chain_lengths 266 .contains(&evidence.certificate_chain_length) 267 { 268 return Err( 269 RadrootsSimplexSmpTransportError::InvalidCertificateChainLength( 270 evidence.certificate_chain_length, 271 ), 272 ); 273 } 274 if evidence.online_certificate_fingerprint != policy.expected_server_identity { 275 return Err(RadrootsSimplexSmpTransportError::ServerIdentityMismatch { 276 expected: policy.expected_server_identity.clone(), 277 actual: evidence.online_certificate_fingerprint.clone(), 278 }); 279 } 280 if policy.require_server_proof && server_hello.server_proof.is_none() { 281 return Err(RadrootsSimplexSmpTransportError::MissingServerProof); 282 } 283 if policy.require_tls_unique_binding { 284 let Some(binding) = evidence.tls_unique_channel_binding.as_ref() else { 285 return Err(RadrootsSimplexSmpTransportError::MissingChannelBinding); 286 }; 287 if binding.as_slice() != server_hello.session_identifier.as_slice() { 288 return Err(RadrootsSimplexSmpTransportError::SessionBindingMismatch); 289 } 290 } 291 292 negotiate_transport_version( 293 server_hello.version_range, 294 policy.supported_versions, 295 evidence.confirmed_alpn.as_deref(), 296 ) 297 } 298 299 fn push_short_bytes( 300 buffer: &mut Vec<u8>, 301 bytes: &[u8], 302 ) -> Result<(), RadrootsSimplexSmpTransportError> { 303 if bytes.len() > u8::MAX as usize { 304 return Err(RadrootsSimplexSmpTransportError::InvalidSessionIdentifierLength(bytes.len())); 305 } 306 buffer.push(bytes.len() as u8); 307 buffer.extend_from_slice(bytes); 308 Ok(()) 309 } 310 311 fn push_large_bytes( 312 buffer: &mut Vec<u8>, 313 bytes: &[u8], 314 ) -> Result<(), RadrootsSimplexSmpTransportError> { 315 let len = u16::try_from(bytes.len()).map_err(|_| { 316 RadrootsSimplexSmpTransportError::InvalidSessionIdentifierLength(bytes.len()) 317 })?; 318 buffer.extend_from_slice(&len.to_be_bytes()); 319 buffer.extend_from_slice(bytes); 320 Ok(()) 321 } 322 323 fn read_short_bytes( 324 payload: &[u8], 325 offset: usize, 326 ) -> Result<(Vec<u8>, usize), RadrootsSimplexSmpTransportError> { 327 let Some(&length) = payload.get(offset) else { 328 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 329 "short_field", 330 )); 331 }; 332 let start = offset + 1; 333 let end = start + length as usize; 334 let Some(value) = payload.get(start..end) else { 335 return Err( 336 radroots_simplex_smp_proto::prelude::RadrootsSimplexSmpProtoError::UnexpectedEof.into(), 337 ); 338 }; 339 Ok((value.to_vec(), end)) 340 } 341 342 fn read_large_bytes( 343 payload: &[u8], 344 offset: usize, 345 ) -> Result<(Vec<u8>, usize), RadrootsSimplexSmpTransportError> { 346 let Some(length_bytes) = payload.get(offset..offset + 2) else { 347 return Err(RadrootsSimplexSmpTransportError::MissingHandshakeField( 348 "large_field", 349 )); 350 }; 351 let length = u16::from_be_bytes([length_bytes[0], length_bytes[1]]) as usize; 352 let start = offset + 2; 353 let end = start + length; 354 let Some(value) = payload.get(start..end) else { 355 return Err( 356 radroots_simplex_smp_proto::prelude::RadrootsSimplexSmpProtoError::UnexpectedEof.into(), 357 ); 358 }; 359 Ok((value.to_vec(), end)) 360 } 361 362 fn parse_optional_server_proof( 363 remainder: &[u8], 364 ) -> (Option<RadrootsSimplexSmpTransportServerProof>, Vec<u8>) { 365 let Some(&cert_count) = remainder.first() else { 366 return (None, remainder.to_vec()); 367 }; 368 if cert_count == 0 { 369 return (None, remainder.to_vec()); 370 } 371 let mut cursor = 1; 372 for _ in 0..cert_count { 373 let Ok((_, next_cursor)) = read_large_bytes(remainder, cursor) else { 374 return (None, remainder.to_vec()); 375 }; 376 cursor = next_cursor; 377 } 378 let Ok((signed_server_key, cursor)) = read_large_bytes(remainder, cursor) else { 379 return (None, remainder.to_vec()); 380 }; 381 ( 382 Some(RadrootsSimplexSmpTransportServerProof { 383 certificate_payload: remainder[..cursor - signed_server_key.len() - 2].to_vec(), 384 signed_server_key, 385 }), 386 remainder[cursor..].to_vec(), 387 ) 388 } 389 390 fn parse_optional_client_key(remainder: &[u8]) -> (Option<Vec<u8>>, usize) { 391 let Some(&length) = remainder.first() else { 392 return (None, 0); 393 }; 394 let end = 1 + length as usize; 395 if length == 0 || end > remainder.len() { 396 return (None, remainder.len()); 397 } 398 (Some(remainder[1..end].to_vec()), end) 399 } 400 401 #[cfg(test)] 402 mod tests { 403 use super::*; 404 405 #[test] 406 fn roundtrips_server_hello_and_validates_binding() { 407 let hello = RadrootsSimplexSmpServerHello { 408 version_range: RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(), 409 session_identifier: b"tls-unique-binding".to_vec(), 410 server_proof: Some(RadrootsSimplexSmpTransportServerProof { 411 certificate_payload: encode_certificate_chain_payload([b"cert-chain".as_slice()]), 412 signed_server_key: b"signed-key".to_vec(), 413 }), 414 ignored_part: b"ignored".to_vec(), 415 }; 416 417 let decoded = RadrootsSimplexSmpServerHello::decode(&hello.encode().unwrap()).unwrap(); 418 assert_eq!(decoded, hello); 419 420 let policy = RadrootsSimplexSmpTlsPolicy { 421 expected_server_identity: "fingerprint".to_string(), 422 supported_versions: RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(), 423 require_current_alpn: false, 424 allow_session_resumption: false, 425 allowed_certificate_chain_lengths: [2, 3, 4], 426 require_tls_unique_binding: true, 427 require_server_proof: true, 428 }; 429 let version = validate_tls_handshake( 430 &policy, 431 &decoded, 432 &RadrootsSimplexSmpTlsHandshakeEvidence { 433 confirmed_alpn: Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1.to_string()), 434 session_resumed: false, 435 certificate_chain_length: 3, 436 online_certificate_fingerprint: "fingerprint".to_string(), 437 tls_unique_channel_binding: Some(b"tls-unique-binding".to_vec()), 438 }, 439 ) 440 .unwrap(); 441 assert_eq!(version, 17); 442 } 443 444 #[test] 445 fn falls_back_to_initial_transport_version_without_current_alpn() { 446 let version = negotiate_transport_version( 447 RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(), 448 RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(), 449 None, 450 ) 451 .unwrap(); 452 assert_eq!(version, 6); 453 } 454 455 #[test] 456 fn rejects_mismatched_server_identity() { 457 let hello = RadrootsSimplexSmpServerHello { 458 version_range: RadrootsSimplexSmpVersionRange::new(6, 17).unwrap(), 459 session_identifier: b"bind".to_vec(), 460 server_proof: None, 461 ignored_part: Vec::new(), 462 }; 463 let policy = RadrootsSimplexSmpTlsPolicy::modern("expected"); 464 let error = validate_tls_handshake( 465 &policy, 466 &hello, 467 &RadrootsSimplexSmpTlsHandshakeEvidence { 468 confirmed_alpn: Some(RADROOTS_SIMPLEX_SMP_TLS_ALPN_V1.to_string()), 469 session_resumed: false, 470 certificate_chain_length: 2, 471 online_certificate_fingerprint: "actual".to_string(), 472 tls_unique_channel_binding: Some(b"bind".to_vec()), 473 }, 474 ) 475 .unwrap_err(); 476 assert!(matches!( 477 error, 478 RadrootsSimplexSmpTransportError::ServerIdentityMismatch { .. } 479 )); 480 } 481 482 fn encode_certificate_chain_payload<'a, I>(certificates: I) -> Vec<u8> 483 where 484 I: IntoIterator<Item = &'a [u8]>, 485 { 486 let certificates: Vec<&[u8]> = certificates.into_iter().collect(); 487 let mut payload = vec![certificates.len() as u8]; 488 for certificate in certificates { 489 payload.extend_from_slice(&(certificate.len() as u16).to_be_bytes()); 490 payload.extend_from_slice(certificate); 491 } 492 payload 493 } 494 }