rhi

Coordinated trade for connected markets
git clone https://radroots.dev/git/rhi.git
Log | Files | Refs | README | LICENSE

remote_prove.rs (16788B)


      1 #![forbid(unsafe_code)]
      2 #![cfg_attr(coverage_nightly, coverage(off))]
      3 
      4 use crate::cli::Command;
      5 use radroots_sp1_guest_trade::{
      6     RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET, RADROOTS_SP1_TRADE_PROTOCOL_VERSION,
      7     RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RADROOTS_SP1_TRADE_WITNESS_VERSION,
      8 };
      9 use radroots_sp1_host_trade::{
     10     RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE,
     11     RadrootsSp1TradeProofEngine, RadrootsSp1TradeProofMode, RadrootsSp1TradeRemoteProverRequest,
     12     RadrootsSp1TradeRemoteProverResponse, RadrootsSp1TradeRemoteProverStatus,
     13 };
     14 use std::path::Path;
     15 
     16 pub async fn run_cli_command(command: Command) -> anyhow::Result<()> {
     17     let Command::RemoteProve {
     18         input,
     19         output,
     20         proof_engine,
     21     } = command
     22     else {
     23         return Err(anyhow::anyhow!("remote-prove command expected"));
     24     };
     25     let engine = RadrootsSp1TradeProofEngine::from_label(proof_engine.as_str())
     26         .ok_or_else(|| anyhow::anyhow!("invalid proof engine"))?;
     27     let request_bytes = read_input(input.as_deref())?;
     28     let response = handle_request_bytes(&request_bytes, engine).await;
     29     let response_bytes = serde_json::to_vec_pretty(&response)?;
     30     write_output(output.as_deref(), &response_bytes)?;
     31     if response.status == RadrootsSp1TradeRemoteProverStatus::Completed {
     32         Ok(())
     33     } else {
     34         Err(anyhow::anyhow!(
     35             "{}",
     36             response
     37                 .message
     38                 .as_deref()
     39                 .unwrap_or("remote proof request did not complete")
     40         ))
     41     }
     42 }
     43 
     44 pub async fn handle_request_bytes(
     45     bytes: &[u8],
     46     engine: RadrootsSp1TradeProofEngine,
     47 ) -> RadrootsSp1TradeRemoteProverResponse {
     48     let request_id = request_id_from_bytes(bytes);
     49     let request = match serde_json::from_slice::<RadrootsSp1TradeRemoteProverRequest>(bytes) {
     50         Ok(request) => request,
     51         Err(error) => {
     52             return rejected_response(request_id, "invalid_json", error.to_string());
     53         }
     54     };
     55     match validate_request(&request) {
     56         Ok(()) => prove_request(request, engine).await,
     57         Err(rejection) => {
     58             rejected_response(request.request_id, rejection.reason, rejection.message)
     59         }
     60     }
     61 }
     62 
     63 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
     64 struct RemoteProveRejection {
     65     reason: &'static str,
     66     message: &'static str,
     67 }
     68 
     69 fn validate_request(
     70     request: &RadrootsSp1TradeRemoteProverRequest,
     71 ) -> Result<(), RemoteProveRejection> {
     72     if request.schema_version != RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION {
     73         return Err(rejection(
     74             "invalid_schema_version",
     75             "invalid schema_version",
     76         ));
     77     }
     78     if request.request_id.trim().is_empty() {
     79         return Err(rejection("invalid_request_id", "invalid request_id"));
     80     }
     81     if request.proof_target != RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET {
     82         return Err(rejection(
     83             "unsupported_proof_target",
     84             "unsupported proof_target",
     85         ));
     86     }
     87     if request.proof_mode != RadrootsSp1TradeProofMode::Core {
     88         return Err(rejection(
     89             "unsupported_proof_mode",
     90             "unsupported proof_mode",
     91         ));
     92     }
     93     if request.sp1_version_line != RADROOTS_SP1_TRADE_SP1_VERSION_LINE {
     94         return Err(rejection(
     95             "unsupported_sp1_version",
     96             "unsupported sp1 version",
     97         ));
     98     }
     99     if request.expected_reducer_program_hash != RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH {
    100         return Err(rejection(
    101             "reducer_program_hash_mismatch",
    102             "expected reducer program hash mismatch",
    103         ));
    104     }
    105     if request.expected_protocol_version != RADROOTS_SP1_TRADE_PROTOCOL_VERSION {
    106         return Err(rejection(
    107             "protocol_version_mismatch",
    108             "expected protocol version mismatch",
    109         ));
    110     }
    111     if request.expected_witness_version != RADROOTS_SP1_TRADE_WITNESS_VERSION {
    112         return Err(rejection(
    113             "witness_version_mismatch",
    114             "expected witness version mismatch",
    115         ));
    116     }
    117     if request.witness.witness_version != request.expected_witness_version {
    118         return Err(rejection(
    119             "witness_version_mismatch",
    120             "witness version mismatch",
    121         ));
    122     }
    123     if request.witness.proof_target != request.proof_target {
    124         return Err(rejection(
    125             "proof_target_mismatch",
    126             "witness proof target mismatch",
    127         ));
    128     }
    129     if request.witness.reducer_program_hash != request.expected_reducer_program_hash {
    130         return Err(rejection(
    131             "reducer_program_hash_mismatch",
    132             "witness reducer program hash mismatch",
    133         ));
    134     }
    135     if request.witness.radroots_protocol_version != request.expected_protocol_version {
    136         return Err(rejection(
    137             "protocol_version_mismatch",
    138             "witness protocol version mismatch",
    139         ));
    140     }
    141     if request.witness.sp1_program_hash.as_deref()
    142         != Some(request.expected_sp1_program_hash.as_str())
    143     {
    144         return Err(rejection(
    145             "sp1_program_hash_mismatch",
    146             "witness SP1 program hash mismatch",
    147         ));
    148     }
    149     if request.witness.sp1_verifying_key_hash.as_deref()
    150         != Some(request.expected_sp1_verifying_key_hash.as_str())
    151     {
    152         return Err(rejection(
    153             "sp1_verifying_key_hash_mismatch",
    154             "witness SP1 verifying key hash mismatch",
    155         ));
    156     }
    157     for value in [
    158         request.expected_sp1_program_hash.as_str(),
    159         request.expected_sp1_verifying_key_hash.as_str(),
    160         request.expected_public_values_hash.as_str(),
    161     ] {
    162         if !is_hash32(value) {
    163             return Err(rejection("invalid_hash", "expected hash field is invalid"));
    164         }
    165     }
    166     let execution = radroots_sp1_host_trade::execute_order_acceptance_public_values(
    167         &request.witness,
    168     )
    169     .map_err(|_| {
    170         rejection(
    171             "public_values_execution_failed",
    172             "public values execution failed",
    173         )
    174     })?;
    175     if execution.public_values_hash != request.expected_public_values_hash {
    176         return Err(rejection(
    177             "public_values_hash_mismatch",
    178             "expected public values hash mismatch",
    179         ));
    180     }
    181     Ok(())
    182 }
    183 
    184 fn rejection(reason: &'static str, message: &'static str) -> RemoteProveRejection {
    185     RemoteProveRejection { reason, message }
    186 }
    187 
    188 #[cfg(feature = "sp1_proving")]
    189 async fn prove_request(
    190     request: RadrootsSp1TradeRemoteProverRequest,
    191     engine: RadrootsSp1TradeProofEngine,
    192 ) -> RadrootsSp1TradeRemoteProverResponse {
    193     match radroots_sp1_host_trade::generate_order_acceptance_sp1_proof_with_engine(
    194         &request.witness,
    195         request.proof_mode,
    196         engine,
    197     )
    198     .await
    199     {
    200         Ok(bundle) => completed_response(request, bundle),
    201         Err(error) => failed_response(
    202             request.request_id,
    203             "proof_generation_failed",
    204             error.to_string(),
    205         ),
    206     }
    207 }
    208 
    209 #[cfg(not(feature = "sp1_proving"))]
    210 async fn prove_request(
    211     request: RadrootsSp1TradeRemoteProverRequest,
    212     _engine: RadrootsSp1TradeProofEngine,
    213 ) -> RadrootsSp1TradeRemoteProverResponse {
    214     failed_response(
    215         request.request_id,
    216         "proof_generation_unavailable",
    217         "remote-prove requires the sp1_proving feature".to_string(),
    218     )
    219 }
    220 
    221 #[cfg(feature = "sp1_proving")]
    222 fn completed_response(
    223     request: RadrootsSp1TradeRemoteProverRequest,
    224     bundle: radroots_sp1_host_trade::RadrootsSp1TradeProofBundle,
    225 ) -> RadrootsSp1TradeRemoteProverResponse {
    226     if bundle.execution.public_values_hash != request.expected_public_values_hash {
    227         return failed_response(
    228             request.request_id,
    229             "public_values_hash_mismatch",
    230             "generated public values hash mismatch".to_string(),
    231         );
    232     }
    233     if bundle.proof.program_hash.as_deref() != Some(request.expected_sp1_program_hash.as_str()) {
    234         return failed_response(
    235             request.request_id,
    236             "sp1_program_hash_mismatch",
    237             "generated SP1 program hash mismatch".to_string(),
    238         );
    239     }
    240     if bundle.proof.verifying_key_hash.as_deref()
    241         != Some(request.expected_sp1_verifying_key_hash.as_str())
    242     {
    243         return failed_response(
    244             request.request_id,
    245             "sp1_verifying_key_hash_mismatch",
    246             "generated SP1 verifying key hash mismatch".to_string(),
    247         );
    248     }
    249     RadrootsSp1TradeRemoteProverResponse {
    250         schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION,
    251         request_id: request.request_id,
    252         status: RadrootsSp1TradeRemoteProverStatus::Completed,
    253         status_url: None,
    254         status_path: None,
    255         proof_system: Some(request.proof_mode.proof_system()),
    256         proof_mode: Some(request.proof_mode),
    257         public_values_hash: Some(bundle.execution.public_values_hash),
    258         sp1_program_hash: bundle.proof.program_hash.clone(),
    259         sp1_verifying_key_hash: bundle.proof.verifying_key_hash.clone(),
    260         proof_artifact: Some(bundle.proof),
    261         resolved_proof_envelope_base64: None,
    262         reason_code: None,
    263         message: None,
    264         detail: None,
    265     }
    266 }
    267 
    268 fn rejected_response(
    269     request_id: String,
    270     reason: impl Into<String>,
    271     message: impl Into<String>,
    272 ) -> RadrootsSp1TradeRemoteProverResponse {
    273     terminal_response(
    274         request_id,
    275         RadrootsSp1TradeRemoteProverStatus::Rejected,
    276         reason,
    277         message,
    278     )
    279 }
    280 
    281 fn failed_response(
    282     request_id: String,
    283     reason: impl Into<String>,
    284     message: impl Into<String>,
    285 ) -> RadrootsSp1TradeRemoteProverResponse {
    286     terminal_response(
    287         request_id,
    288         RadrootsSp1TradeRemoteProverStatus::Failed,
    289         reason,
    290         message,
    291     )
    292 }
    293 
    294 fn terminal_response(
    295     request_id: String,
    296     status: RadrootsSp1TradeRemoteProverStatus,
    297     reason: impl Into<String>,
    298     message: impl Into<String>,
    299 ) -> RadrootsSp1TradeRemoteProverResponse {
    300     RadrootsSp1TradeRemoteProverResponse {
    301         schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION,
    302         request_id,
    303         status,
    304         status_url: None,
    305         status_path: None,
    306         proof_system: None,
    307         proof_mode: None,
    308         public_values_hash: None,
    309         sp1_program_hash: None,
    310         sp1_verifying_key_hash: None,
    311         proof_artifact: None,
    312         resolved_proof_envelope_base64: None,
    313         reason_code: Some(reason.into()),
    314         message: Some(message.into()),
    315         detail: None,
    316     }
    317 }
    318 
    319 fn request_id_from_bytes(bytes: &[u8]) -> String {
    320     serde_json::from_slice::<serde_json::Value>(bytes)
    321         .ok()
    322         .and_then(|value| {
    323             let request_id = value
    324                 .get("request_id")
    325                 .and_then(serde_json::Value::as_str)
    326                 .map(str::trim)?;
    327             if request_id.is_empty() {
    328                 None
    329             } else {
    330                 Some(request_id.to_owned())
    331             }
    332         })
    333         .unwrap_or_else(|| "invalid-request".to_string())
    334 }
    335 
    336 fn is_hash32(value: &str) -> bool {
    337     value.len() == 66
    338         && value.starts_with("0x")
    339         && value[2..]
    340             .bytes()
    341             .all(|byte| byte.is_ascii_digit() || (b'a'..=b'f').contains(&byte))
    342 }
    343 
    344 fn read_input(input: Option<&Path>) -> anyhow::Result<Vec<u8>> {
    345     match input {
    346         Some(path) => Ok(std::fs::read(path)?),
    347         None => {
    348             use std::io::Read;
    349             let mut bytes = Vec::new();
    350             std::io::stdin().read_to_end(&mut bytes)?;
    351             Ok(bytes)
    352         }
    353     }
    354 }
    355 
    356 fn write_output(output: Option<&Path>, bytes: &[u8]) -> anyhow::Result<()> {
    357     match output {
    358         Some(path) => {
    359             std::fs::write(path, bytes)?;
    360             Ok(())
    361         }
    362         None => {
    363             println!("{}", String::from_utf8_lossy(bytes));
    364             Ok(())
    365         }
    366     }
    367 }
    368 
    369 #[cfg(test)]
    370 mod tests {
    371     use super::handle_request_bytes;
    372     use radroots_sp1_guest_trade::{
    373         RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET, RADROOTS_SP1_TRADE_PROTOCOL_VERSION,
    374         RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RADROOTS_SP1_TRADE_WITNESS_VERSION,
    375     };
    376     use radroots_sp1_host_trade::{
    377         RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE,
    378         RadrootsSp1TradeProofEngine, RadrootsSp1TradeProofMode,
    379         RadrootsSp1TradeRemoteProverRequest, RadrootsSp1TradeRemoteProverStatus,
    380     };
    381 
    382     fn hash32(ch: char) -> String {
    383         format!("0x{}", ch.to_string().repeat(64))
    384     }
    385 
    386     fn request() -> RadrootsSp1TradeRemoteProverRequest {
    387         let mut witness = crate::proof_smoke::order_acceptance_tiny_witness();
    388         witness.sp1_program_hash = Some(hash32('a'));
    389         witness.sp1_verifying_key_hash = Some(hash32('b'));
    390         let execution = radroots_sp1_host_trade::execute_order_acceptance_public_values(&witness)
    391             .expect("public values");
    392         RadrootsSp1TradeRemoteProverRequest {
    393             schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION,
    394             request_id: "request-1".to_string(),
    395             proof_target: RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET.to_string(),
    396             proof_mode: RadrootsSp1TradeProofMode::Core,
    397             sp1_version_line: RADROOTS_SP1_TRADE_SP1_VERSION_LINE.to_string(),
    398             witness,
    399             expected_sp1_program_hash: hash32('a'),
    400             expected_sp1_verifying_key_hash: hash32('b'),
    401             expected_public_values_hash: execution.public_values_hash,
    402             expected_reducer_program_hash: RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH.to_string(),
    403             expected_protocol_version: RADROOTS_SP1_TRADE_PROTOCOL_VERSION.to_string(),
    404             expected_witness_version: RADROOTS_SP1_TRADE_WITNESS_VERSION,
    405         }
    406     }
    407 
    408     #[tokio::test]
    409     async fn remote_prove_rejects_unknown_provider_fields() {
    410         let response = handle_request_bytes(
    411             br#"{"schema_version":1,"request_id":"request-1","proof_target":"order_acceptance_v1","proof_mode":"core","sp1_version_line":"sp1-sdk-6.2.1","expected_sp1_program_hash":"0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa","expected_sp1_verifying_key_hash":"0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb","expected_public_values_hash":"0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc","expected_reducer_program_hash":"0xdddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd","expected_protocol_version":"radroots.sp1.trade.v1","expected_witness_version":1,"provider":"runpod"}"#,
    412             RadrootsSp1TradeProofEngine::Cpu,
    413         )
    414         .await;
    415         assert_eq!(response.request_id, "request-1");
    416         assert_eq!(
    417             response.status,
    418             RadrootsSp1TradeRemoteProverStatus::Rejected
    419         );
    420         assert_eq!(response.reason_code.as_deref(), Some("invalid_json"));
    421     }
    422 
    423     #[tokio::test]
    424     async fn remote_prove_rejects_expected_public_values_hash_mismatch() {
    425         let mut request = request();
    426         request.expected_public_values_hash = hash32('c');
    427         let response = handle_request_bytes(
    428             &serde_json::to_vec(&request).expect("request json"),
    429             RadrootsSp1TradeProofEngine::Cpu,
    430         )
    431         .await;
    432         assert_eq!(
    433             response.status,
    434             RadrootsSp1TradeRemoteProverStatus::Rejected
    435         );
    436         assert_eq!(
    437             response.reason_code.as_deref(),
    438             Some("public_values_hash_mismatch")
    439         );
    440     }
    441 
    442     #[tokio::test]
    443     async fn remote_prove_rejects_unsupported_modes_before_generation() {
    444         let mut request = request();
    445         request.proof_mode = RadrootsSp1TradeProofMode::Compressed;
    446         let response = handle_request_bytes(
    447             &serde_json::to_vec(&request).expect("request json"),
    448             RadrootsSp1TradeProofEngine::Cpu,
    449         )
    450         .await;
    451         assert_eq!(
    452             response.status,
    453             RadrootsSp1TradeRemoteProverStatus::Rejected
    454         );
    455         assert_eq!(
    456             response.reason_code.as_deref(),
    457             Some("unsupported_proof_mode")
    458         );
    459     }
    460 
    461     #[cfg(not(feature = "sp1_proving"))]
    462     #[tokio::test]
    463     async fn remote_prove_reports_generation_unavailable_without_proving_feature() {
    464         let request = request();
    465         let response = handle_request_bytes(
    466             &serde_json::to_vec(&request).expect("request json"),
    467             RadrootsSp1TradeProofEngine::Cpu,
    468         )
    469         .await;
    470         assert_eq!(response.status, RadrootsSp1TradeRemoteProverStatus::Failed);
    471         assert_eq!(
    472             response.reason_code.as_deref(),
    473             Some("proof_generation_unavailable")
    474         );
    475     }
    476 }