remote_prove.rs (16788B)
1 #![forbid(unsafe_code)] 2 #![cfg_attr(coverage_nightly, coverage(off))] 3 4 use crate::cli::Command; 5 use radroots_sp1_guest_trade::{ 6 RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET, RADROOTS_SP1_TRADE_PROTOCOL_VERSION, 7 RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RADROOTS_SP1_TRADE_WITNESS_VERSION, 8 }; 9 use radroots_sp1_host_trade::{ 10 RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE, 11 RadrootsSp1TradeProofEngine, RadrootsSp1TradeProofMode, RadrootsSp1TradeRemoteProverRequest, 12 RadrootsSp1TradeRemoteProverResponse, RadrootsSp1TradeRemoteProverStatus, 13 }; 14 use std::path::Path; 15 16 pub async fn run_cli_command(command: Command) -> anyhow::Result<()> { 17 let Command::RemoteProve { 18 input, 19 output, 20 proof_engine, 21 } = command 22 else { 23 return Err(anyhow::anyhow!("remote-prove command expected")); 24 }; 25 let engine = RadrootsSp1TradeProofEngine::from_label(proof_engine.as_str()) 26 .ok_or_else(|| anyhow::anyhow!("invalid proof engine"))?; 27 let request_bytes = read_input(input.as_deref())?; 28 let response = handle_request_bytes(&request_bytes, engine).await; 29 let response_bytes = serde_json::to_vec_pretty(&response)?; 30 write_output(output.as_deref(), &response_bytes)?; 31 if response.status == RadrootsSp1TradeRemoteProverStatus::Completed { 32 Ok(()) 33 } else { 34 Err(anyhow::anyhow!( 35 "{}", 36 response 37 .message 38 .as_deref() 39 .unwrap_or("remote proof request did not complete") 40 )) 41 } 42 } 43 44 pub async fn handle_request_bytes( 45 bytes: &[u8], 46 engine: RadrootsSp1TradeProofEngine, 47 ) -> RadrootsSp1TradeRemoteProverResponse { 48 let request_id = request_id_from_bytes(bytes); 49 let request = match serde_json::from_slice::<RadrootsSp1TradeRemoteProverRequest>(bytes) { 50 Ok(request) => request, 51 Err(error) => { 52 return rejected_response(request_id, "invalid_json", error.to_string()); 53 } 54 }; 55 match validate_request(&request) { 56 Ok(()) => prove_request(request, engine).await, 57 Err(rejection) => { 58 rejected_response(request.request_id, rejection.reason, rejection.message) 59 } 60 } 61 } 62 63 #[derive(Clone, Copy, Debug, PartialEq, Eq)] 64 struct RemoteProveRejection { 65 reason: &'static str, 66 message: &'static str, 67 } 68 69 fn validate_request( 70 request: &RadrootsSp1TradeRemoteProverRequest, 71 ) -> Result<(), RemoteProveRejection> { 72 if request.schema_version != RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION { 73 return Err(rejection( 74 "invalid_schema_version", 75 "invalid schema_version", 76 )); 77 } 78 if request.request_id.trim().is_empty() { 79 return Err(rejection("invalid_request_id", "invalid request_id")); 80 } 81 if request.proof_target != RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET { 82 return Err(rejection( 83 "unsupported_proof_target", 84 "unsupported proof_target", 85 )); 86 } 87 if request.proof_mode != RadrootsSp1TradeProofMode::Core { 88 return Err(rejection( 89 "unsupported_proof_mode", 90 "unsupported proof_mode", 91 )); 92 } 93 if request.sp1_version_line != RADROOTS_SP1_TRADE_SP1_VERSION_LINE { 94 return Err(rejection( 95 "unsupported_sp1_version", 96 "unsupported sp1 version", 97 )); 98 } 99 if request.expected_reducer_program_hash != RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH { 100 return Err(rejection( 101 "reducer_program_hash_mismatch", 102 "expected reducer program hash mismatch", 103 )); 104 } 105 if request.expected_protocol_version != RADROOTS_SP1_TRADE_PROTOCOL_VERSION { 106 return Err(rejection( 107 "protocol_version_mismatch", 108 "expected protocol version mismatch", 109 )); 110 } 111 if request.expected_witness_version != RADROOTS_SP1_TRADE_WITNESS_VERSION { 112 return Err(rejection( 113 "witness_version_mismatch", 114 "expected witness version mismatch", 115 )); 116 } 117 if request.witness.witness_version != request.expected_witness_version { 118 return Err(rejection( 119 "witness_version_mismatch", 120 "witness version mismatch", 121 )); 122 } 123 if request.witness.proof_target != request.proof_target { 124 return Err(rejection( 125 "proof_target_mismatch", 126 "witness proof target mismatch", 127 )); 128 } 129 if request.witness.reducer_program_hash != request.expected_reducer_program_hash { 130 return Err(rejection( 131 "reducer_program_hash_mismatch", 132 "witness reducer program hash mismatch", 133 )); 134 } 135 if request.witness.radroots_protocol_version != request.expected_protocol_version { 136 return Err(rejection( 137 "protocol_version_mismatch", 138 "witness protocol version mismatch", 139 )); 140 } 141 if request.witness.sp1_program_hash.as_deref() 142 != Some(request.expected_sp1_program_hash.as_str()) 143 { 144 return Err(rejection( 145 "sp1_program_hash_mismatch", 146 "witness SP1 program hash mismatch", 147 )); 148 } 149 if request.witness.sp1_verifying_key_hash.as_deref() 150 != Some(request.expected_sp1_verifying_key_hash.as_str()) 151 { 152 return Err(rejection( 153 "sp1_verifying_key_hash_mismatch", 154 "witness SP1 verifying key hash mismatch", 155 )); 156 } 157 for value in [ 158 request.expected_sp1_program_hash.as_str(), 159 request.expected_sp1_verifying_key_hash.as_str(), 160 request.expected_public_values_hash.as_str(), 161 ] { 162 if !is_hash32(value) { 163 return Err(rejection("invalid_hash", "expected hash field is invalid")); 164 } 165 } 166 let execution = radroots_sp1_host_trade::execute_order_acceptance_public_values( 167 &request.witness, 168 ) 169 .map_err(|_| { 170 rejection( 171 "public_values_execution_failed", 172 "public values execution failed", 173 ) 174 })?; 175 if execution.public_values_hash != request.expected_public_values_hash { 176 return Err(rejection( 177 "public_values_hash_mismatch", 178 "expected public values hash mismatch", 179 )); 180 } 181 Ok(()) 182 } 183 184 fn rejection(reason: &'static str, message: &'static str) -> RemoteProveRejection { 185 RemoteProveRejection { reason, message } 186 } 187 188 #[cfg(feature = "sp1_proving")] 189 async fn prove_request( 190 request: RadrootsSp1TradeRemoteProverRequest, 191 engine: RadrootsSp1TradeProofEngine, 192 ) -> RadrootsSp1TradeRemoteProverResponse { 193 match radroots_sp1_host_trade::generate_order_acceptance_sp1_proof_with_engine( 194 &request.witness, 195 request.proof_mode, 196 engine, 197 ) 198 .await 199 { 200 Ok(bundle) => completed_response(request, bundle), 201 Err(error) => failed_response( 202 request.request_id, 203 "proof_generation_failed", 204 error.to_string(), 205 ), 206 } 207 } 208 209 #[cfg(not(feature = "sp1_proving"))] 210 async fn prove_request( 211 request: RadrootsSp1TradeRemoteProverRequest, 212 _engine: RadrootsSp1TradeProofEngine, 213 ) -> RadrootsSp1TradeRemoteProverResponse { 214 failed_response( 215 request.request_id, 216 "proof_generation_unavailable", 217 "remote-prove requires the sp1_proving feature".to_string(), 218 ) 219 } 220 221 #[cfg(feature = "sp1_proving")] 222 fn completed_response( 223 request: RadrootsSp1TradeRemoteProverRequest, 224 bundle: radroots_sp1_host_trade::RadrootsSp1TradeProofBundle, 225 ) -> RadrootsSp1TradeRemoteProverResponse { 226 if bundle.execution.public_values_hash != request.expected_public_values_hash { 227 return failed_response( 228 request.request_id, 229 "public_values_hash_mismatch", 230 "generated public values hash mismatch".to_string(), 231 ); 232 } 233 if bundle.proof.program_hash.as_deref() != Some(request.expected_sp1_program_hash.as_str()) { 234 return failed_response( 235 request.request_id, 236 "sp1_program_hash_mismatch", 237 "generated SP1 program hash mismatch".to_string(), 238 ); 239 } 240 if bundle.proof.verifying_key_hash.as_deref() 241 != Some(request.expected_sp1_verifying_key_hash.as_str()) 242 { 243 return failed_response( 244 request.request_id, 245 "sp1_verifying_key_hash_mismatch", 246 "generated SP1 verifying key hash mismatch".to_string(), 247 ); 248 } 249 RadrootsSp1TradeRemoteProverResponse { 250 schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, 251 request_id: request.request_id, 252 status: RadrootsSp1TradeRemoteProverStatus::Completed, 253 status_url: None, 254 status_path: None, 255 proof_system: Some(request.proof_mode.proof_system()), 256 proof_mode: Some(request.proof_mode), 257 public_values_hash: Some(bundle.execution.public_values_hash), 258 sp1_program_hash: bundle.proof.program_hash.clone(), 259 sp1_verifying_key_hash: bundle.proof.verifying_key_hash.clone(), 260 proof_artifact: Some(bundle.proof), 261 resolved_proof_envelope_base64: None, 262 reason_code: None, 263 message: None, 264 detail: None, 265 } 266 } 267 268 fn rejected_response( 269 request_id: String, 270 reason: impl Into<String>, 271 message: impl Into<String>, 272 ) -> RadrootsSp1TradeRemoteProverResponse { 273 terminal_response( 274 request_id, 275 RadrootsSp1TradeRemoteProverStatus::Rejected, 276 reason, 277 message, 278 ) 279 } 280 281 fn failed_response( 282 request_id: String, 283 reason: impl Into<String>, 284 message: impl Into<String>, 285 ) -> RadrootsSp1TradeRemoteProverResponse { 286 terminal_response( 287 request_id, 288 RadrootsSp1TradeRemoteProverStatus::Failed, 289 reason, 290 message, 291 ) 292 } 293 294 fn terminal_response( 295 request_id: String, 296 status: RadrootsSp1TradeRemoteProverStatus, 297 reason: impl Into<String>, 298 message: impl Into<String>, 299 ) -> RadrootsSp1TradeRemoteProverResponse { 300 RadrootsSp1TradeRemoteProverResponse { 301 schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, 302 request_id, 303 status, 304 status_url: None, 305 status_path: None, 306 proof_system: None, 307 proof_mode: None, 308 public_values_hash: None, 309 sp1_program_hash: None, 310 sp1_verifying_key_hash: None, 311 proof_artifact: None, 312 resolved_proof_envelope_base64: None, 313 reason_code: Some(reason.into()), 314 message: Some(message.into()), 315 detail: None, 316 } 317 } 318 319 fn request_id_from_bytes(bytes: &[u8]) -> String { 320 serde_json::from_slice::<serde_json::Value>(bytes) 321 .ok() 322 .and_then(|value| { 323 let request_id = value 324 .get("request_id") 325 .and_then(serde_json::Value::as_str) 326 .map(str::trim)?; 327 if request_id.is_empty() { 328 None 329 } else { 330 Some(request_id.to_owned()) 331 } 332 }) 333 .unwrap_or_else(|| "invalid-request".to_string()) 334 } 335 336 fn is_hash32(value: &str) -> bool { 337 value.len() == 66 338 && value.starts_with("0x") 339 && value[2..] 340 .bytes() 341 .all(|byte| byte.is_ascii_digit() || (b'a'..=b'f').contains(&byte)) 342 } 343 344 fn read_input(input: Option<&Path>) -> anyhow::Result<Vec<u8>> { 345 match input { 346 Some(path) => Ok(std::fs::read(path)?), 347 None => { 348 use std::io::Read; 349 let mut bytes = Vec::new(); 350 std::io::stdin().read_to_end(&mut bytes)?; 351 Ok(bytes) 352 } 353 } 354 } 355 356 fn write_output(output: Option<&Path>, bytes: &[u8]) -> anyhow::Result<()> { 357 match output { 358 Some(path) => { 359 std::fs::write(path, bytes)?; 360 Ok(()) 361 } 362 None => { 363 println!("{}", String::from_utf8_lossy(bytes)); 364 Ok(()) 365 } 366 } 367 } 368 369 #[cfg(test)] 370 mod tests { 371 use super::handle_request_bytes; 372 use radroots_sp1_guest_trade::{ 373 RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET, RADROOTS_SP1_TRADE_PROTOCOL_VERSION, 374 RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RADROOTS_SP1_TRADE_WITNESS_VERSION, 375 }; 376 use radroots_sp1_host_trade::{ 377 RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE, 378 RadrootsSp1TradeProofEngine, RadrootsSp1TradeProofMode, 379 RadrootsSp1TradeRemoteProverRequest, RadrootsSp1TradeRemoteProverStatus, 380 }; 381 382 fn hash32(ch: char) -> String { 383 format!("0x{}", ch.to_string().repeat(64)) 384 } 385 386 fn request() -> RadrootsSp1TradeRemoteProverRequest { 387 let mut witness = crate::proof_smoke::order_acceptance_tiny_witness(); 388 witness.sp1_program_hash = Some(hash32('a')); 389 witness.sp1_verifying_key_hash = Some(hash32('b')); 390 let execution = radroots_sp1_host_trade::execute_order_acceptance_public_values(&witness) 391 .expect("public values"); 392 RadrootsSp1TradeRemoteProverRequest { 393 schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, 394 request_id: "request-1".to_string(), 395 proof_target: RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET.to_string(), 396 proof_mode: RadrootsSp1TradeProofMode::Core, 397 sp1_version_line: RADROOTS_SP1_TRADE_SP1_VERSION_LINE.to_string(), 398 witness, 399 expected_sp1_program_hash: hash32('a'), 400 expected_sp1_verifying_key_hash: hash32('b'), 401 expected_public_values_hash: execution.public_values_hash, 402 expected_reducer_program_hash: RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH.to_string(), 403 expected_protocol_version: RADROOTS_SP1_TRADE_PROTOCOL_VERSION.to_string(), 404 expected_witness_version: RADROOTS_SP1_TRADE_WITNESS_VERSION, 405 } 406 } 407 408 #[tokio::test] 409 async fn remote_prove_rejects_unknown_provider_fields() { 410 let response = handle_request_bytes( 411 br#"{"schema_version":1,"request_id":"request-1","proof_target":"order_acceptance_v1","proof_mode":"core","sp1_version_line":"sp1-sdk-6.2.1","expected_sp1_program_hash":"0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa","expected_sp1_verifying_key_hash":"0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb","expected_public_values_hash":"0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc","expected_reducer_program_hash":"0xdddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd","expected_protocol_version":"radroots.sp1.trade.v1","expected_witness_version":1,"provider":"runpod"}"#, 412 RadrootsSp1TradeProofEngine::Cpu, 413 ) 414 .await; 415 assert_eq!(response.request_id, "request-1"); 416 assert_eq!( 417 response.status, 418 RadrootsSp1TradeRemoteProverStatus::Rejected 419 ); 420 assert_eq!(response.reason_code.as_deref(), Some("invalid_json")); 421 } 422 423 #[tokio::test] 424 async fn remote_prove_rejects_expected_public_values_hash_mismatch() { 425 let mut request = request(); 426 request.expected_public_values_hash = hash32('c'); 427 let response = handle_request_bytes( 428 &serde_json::to_vec(&request).expect("request json"), 429 RadrootsSp1TradeProofEngine::Cpu, 430 ) 431 .await; 432 assert_eq!( 433 response.status, 434 RadrootsSp1TradeRemoteProverStatus::Rejected 435 ); 436 assert_eq!( 437 response.reason_code.as_deref(), 438 Some("public_values_hash_mismatch") 439 ); 440 } 441 442 #[tokio::test] 443 async fn remote_prove_rejects_unsupported_modes_before_generation() { 444 let mut request = request(); 445 request.proof_mode = RadrootsSp1TradeProofMode::Compressed; 446 let response = handle_request_bytes( 447 &serde_json::to_vec(&request).expect("request json"), 448 RadrootsSp1TradeProofEngine::Cpu, 449 ) 450 .await; 451 assert_eq!( 452 response.status, 453 RadrootsSp1TradeRemoteProverStatus::Rejected 454 ); 455 assert_eq!( 456 response.reason_code.as_deref(), 457 Some("unsupported_proof_mode") 458 ); 459 } 460 461 #[cfg(not(feature = "sp1_proving"))] 462 #[tokio::test] 463 async fn remote_prove_reports_generation_unavailable_without_proving_feature() { 464 let request = request(); 465 let response = handle_request_bytes( 466 &serde_json::to_vec(&request).expect("request json"), 467 RadrootsSp1TradeProofEngine::Cpu, 468 ) 469 .await; 470 assert_eq!(response.status, RadrootsSp1TradeRemoteProverStatus::Failed); 471 assert_eq!( 472 response.reason_code.as_deref(), 473 Some("proof_generation_unavailable") 474 ); 475 } 476 }