rhi

Coordinated trade for connected markets
git clone https://radroots.dev/git/rhi.git
Log | Files | Refs | README | LICENSE

commit ab5b066e3a09d740fd5997e9390db2769e2ca8f9
parent 8901cb8522ad92c5ce1d37f5dc8f9de06e8dce55
Author: triesap <tyson@radroots.org>
Date:   Thu, 21 May 2026 21:24:21 +0000

remote_prove: add remote proof worker command

- add provider-neutral remote-prove request handling
- verify expected public values before proof generation
- move remote HTTP client coverage onto sp1_verify
- update lock state for SP1 CUDA worker builds

Diffstat:
MCargo.lock | 24++++++++++++++++++++++++
MCargo.toml | 4+++-
Msrc/cli.rs | 12++++++++++++
Msrc/features/trade_validation_receipt.rs | 141+++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------
Msrc/lib.rs | 1+
Msrc/main.rs | 9+++++++--
Msrc/proof_smoke.rs | 6++++--
Asrc/remote_prove.rs | 476+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 files changed, 618 insertions(+), 55 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock @@ -3404,6 +3404,7 @@ version = "0.1.0-alpha.2" dependencies = [ "base64 0.22.1", "bincode", + "futures", "radroots_sp1_guest_trade", "radroots_trade", "serde", @@ -4741,6 +4742,28 @@ dependencies = [ ] [[package]] +name = "sp1-cuda" +version = "6.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b00f787fa4b5cbd29e9baddee1e590c5e689333f2b01e5f704293b7f6f17570c" +dependencies = [ + "bincode", + "bytes", + "reqwest", + "serde", + "serde_json", + "sp1-core-executor", + "sp1-core-machine", + "sp1-hypercube", + "sp1-primitives", + "sp1-prover", + "sp1-prover-types", + "thiserror 1.0.69", + "tokio", + "tracing", +] + +[[package]] name = "sp1-curves" version = "6.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5105,6 +5128,7 @@ dependencies = [ "sp1-core-executor", "sp1-core-executor-runner", "sp1-core-machine", + "sp1-cuda", "sp1-hypercube", "sp1-primitives", "sp1-prover", diff --git a/Cargo.toml b/Cargo.toml @@ -24,7 +24,9 @@ radroots_trade = { path = "../lib/crates/trade" } [features] default = [] -sp1_proving = ["radroots_sp1_host_trade/sp1_proving"] +sp1_verify = ["radroots_sp1_host_trade/sp1_verify"] +sp1_proving = ["sp1_verify", "radroots_sp1_host_trade/sp1_proving"] +sp1_cuda_proving = ["sp1_proving", "radroots_sp1_host_trade/sp1_cuda"] [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage_nightly)'] } diff --git a/src/cli.rs b/src/cli.rs @@ -27,4 +27,16 @@ pub enum Command { #[arg(long)] output: Option<PathBuf>, }, + #[command( + name = "remote-prove", + about = "Run a provider-neutral remote proof request" + )] + RemoteProve { + #[arg(long)] + input: Option<PathBuf>, + #[arg(long)] + output: Option<PathBuf>, + #[arg(long, default_value = "cpu", value_parser = ["cpu", "cuda"])] + proof_engine: String, + }, } diff --git a/src/features/trade_validation_receipt.rs b/src/features/trade_validation_receipt.rs @@ -37,11 +37,11 @@ use radroots_trade::validation_receipt::{ }; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] use std::time::Duration; use thiserror::Error; -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] use radroots_sp1_host_trade::{ RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE, RadrootsSp1TradeRemoteProverRequest, RadrootsSp1TradeRemoteProverResponse, @@ -205,7 +205,7 @@ impl TradeValidationReceiptProverPolicy { .ok_or(TradeValidationReceiptJobError::RemoteHttpConfigRequired)?; remote_http.validate()?; remote_http_auth_token(remote_http)?; - if !cfg!(feature = "sp1_proving") { + if !cfg!(feature = "sp1_verify") { return Err(TradeValidationReceiptJobError::ProverBackendUnavailable( self.backend.as_str(), )); @@ -835,7 +835,7 @@ async fn run_local_cpu_prove_backend( }) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn run_remote_http_prove_backend( witness: &RadrootsSp1TradeOrderAcceptanceWitness, policy: &TradeValidationReceiptProverPolicy, @@ -862,6 +862,7 @@ async fn run_remote_http_prove_backend( witness: witness.clone(), expected_sp1_program_hash: expected_sp1_program_hash.to_owned(), expected_sp1_verifying_key_hash: expected_sp1_verifying_key_hash.to_owned(), + expected_public_values_hash: execution.public_values_hash.clone(), expected_reducer_program_hash: RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH.to_string(), expected_protocol_version: RADROOTS_SP1_TRADE_PROTOCOL_VERSION.to_string(), expected_witness_version: RADROOTS_SP1_TRADE_WITNESS_VERSION, @@ -888,7 +889,7 @@ async fn run_remote_http_prove_backend( }) } -#[cfg(not(feature = "sp1_proving"))] +#[cfg(not(feature = "sp1_verify"))] async fn run_remote_http_prove_backend( _witness: &RadrootsSp1TradeOrderAcceptanceWitness, _policy: &TradeValidationReceiptProverPolicy, @@ -898,7 +899,7 @@ async fn run_remote_http_prove_backend( )) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] fn remote_http_request_id( witness: &RadrootsSp1TradeOrderAcceptanceWitness, ) -> Result<String, TradeValidationReceiptJobError> { @@ -906,7 +907,7 @@ fn remote_http_request_id( Ok(hash_bytes("radroots:rhi-remote-proof-request:v1", &bytes)) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn remote_http_completed_response( config: &TradeValidationReceiptRemoteHttpProverConfig, request: &RadrootsSp1TradeRemoteProverRequest, @@ -945,7 +946,7 @@ async fn remote_http_completed_response( Err(TradeValidationReceiptJobError::RemoteHttpTimeout) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] fn remote_http_validate_response_identity( response: &RadrootsSp1TradeRemoteProverResponse, request: &RadrootsSp1TradeRemoteProverRequest, @@ -963,7 +964,7 @@ fn remote_http_validate_response_identity( Ok(()) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] fn remote_http_terminal_error( status: &'static str, response: RadrootsSp1TradeRemoteProverResponse, @@ -979,7 +980,7 @@ fn remote_http_terminal_error( } } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn remote_http_verified_artifact( execution: &radroots_sp1_guest_trade::RadrootsSp1TradePublicValuesExecution, policy: &TradeValidationReceiptProverPolicy, @@ -1035,7 +1036,7 @@ async fn remote_http_verified_artifact( Ok(resolved.artifact) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn verify_remote_proof_artifact_io( execution: &radroots_sp1_guest_trade::RadrootsSp1TradePublicValuesExecution, resolved: &RadrootsSp1TradeResolvedProofArtifact, @@ -1052,7 +1053,7 @@ async fn verify_remote_proof_artifact_io( Ok(()) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn remote_http_post_json_io( config: &TradeValidationReceiptRemoteHttpProverConfig, url: &str, @@ -1071,7 +1072,7 @@ async fn remote_http_post_json_io( remote_http_response_json(config, builder.send().await).await } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn remote_http_get_json_io( config: &TradeValidationReceiptRemoteHttpProverConfig, url: &str, @@ -1090,7 +1091,7 @@ async fn remote_http_get_json_io( remote_http_response_json(config, builder.send().await).await } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] fn remote_http_client( config: &TradeValidationReceiptRemoteHttpProverConfig, ) -> Result<reqwest::Client, TradeValidationReceiptJobError> { @@ -1100,7 +1101,7 @@ fn remote_http_client( .map_err(|error| TradeValidationReceiptJobError::RemoteHttpTransport(error.to_string())) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] async fn remote_http_response_json( config: &TradeValidationReceiptRemoteHttpProverConfig, response: Result<reqwest::Response, reqwest::Error>, @@ -1133,7 +1134,7 @@ async fn remote_http_response_json( .map_err(TradeValidationReceiptJobError::Serde) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] fn remote_http_status_url( config: &TradeValidationReceiptRemoteHttpProverConfig, response: &RadrootsSp1TradeRemoteProverResponse, @@ -1176,7 +1177,7 @@ fn remote_http_status_url( )) } -#[cfg(feature = "sp1_proving")] +#[cfg(feature = "sp1_verify")] fn remote_http_same_origin(base: &reqwest::Url, candidate: &reqwest::Url) -> bool { base.scheme() == candidate.scheme() && base.host_str() == candidate.host_str() @@ -1361,11 +1362,13 @@ struct TradeValidationReceiptTestHooks { std::collections::VecDeque<Result<RadrootsNostrEvent, TradeValidationReceiptJobError>>, publish_event_results: std::collections::VecDeque<Result<String, TradeValidationReceiptJobError>>, - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] remote_http_results: std::collections::VecDeque< Result<RadrootsSp1TradeRemoteProverResponse, TradeValidationReceiptJobError>, >, - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] + remote_http_requests: Vec<RadrootsSp1TradeRemoteProverRequest>, + #[cfg(feature = "sp1_verify")] remote_proof_verification_results: std::collections::VecDeque<Result<(), TradeValidationReceiptJobError>>, published_events: Vec<PublishedEventParts>, @@ -1410,16 +1413,21 @@ fn pop_publish_event_hook( hooks.publish_event_results.pop_front() } -#[cfg(all(test, feature = "sp1_proving"))] +#[cfg(all(test, feature = "sp1_verify"))] fn pop_remote_http_response_hook( request: &RadrootsSp1TradeRemoteProverRequest, ) -> Option<Result<RadrootsSp1TradeRemoteProverResponse, TradeValidationReceiptJobError>> { + trade_validation_receipt_test_hooks() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .remote_http_requests + .push(request.clone()); pop_remote_http_response_hook_without_request().map(|result| { result.and_then(|response| remote_http_test_response_for_request(request, response)) }) } -#[cfg(all(test, feature = "sp1_proving"))] +#[cfg(all(test, feature = "sp1_verify"))] fn pop_remote_http_response_hook_without_request() -> Option<Result<RadrootsSp1TradeRemoteProverResponse, TradeValidationReceiptJobError>> { trade_validation_receipt_test_hooks() @@ -1429,7 +1437,7 @@ fn pop_remote_http_response_hook_without_request() .pop_front() } -#[cfg(all(test, feature = "sp1_proving"))] +#[cfg(all(test, feature = "sp1_verify"))] fn remote_http_test_response_for_request( request: &RadrootsSp1TradeRemoteProverRequest, mut response: RadrootsSp1TradeRemoteProverResponse, @@ -1470,7 +1478,7 @@ fn remote_http_test_response_for_request( Ok(response) } -#[cfg(all(test, feature = "sp1_proving"))] +#[cfg(all(test, feature = "sp1_verify"))] fn pop_remote_proof_verification_hook() -> Option<Result<(), TradeValidationReceiptJobError>> { trade_validation_receipt_test_hooks() .lock() @@ -1515,13 +1523,13 @@ mod tests { RADROOTS_SP1_TRADE_PROTOCOL_VERSION, RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RadrootsSp1TradeInventoryBinWitness, }; - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] use radroots_sp1_host_trade::RadrootsSp1TradeHostError; use radroots_sp1_host_trade::RadrootsSp1TradeProofMode; - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] use radroots_sp1_host_trade::{ - RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RadrootsSp1TradeRemoteProverResponse, - RadrootsSp1TradeRemoteProverStatus, + RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RadrootsSp1TradeRemoteProverRequest, + RadrootsSp1TradeRemoteProverResponse, RadrootsSp1TradeRemoteProverStatus, }; use radroots_trade::validation_receipt::{ RadrootsValidationReceiptExpectedBinding, RadrootsValidationReceiptProofSystem, @@ -1748,7 +1756,7 @@ mod tests { } } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] fn remote_response( status: RadrootsSp1TradeRemoteProverStatus, ) -> RadrootsSp1TradeRemoteProverResponse { @@ -1771,7 +1779,7 @@ mod tests { } } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] fn remote_http_local_response_url(response: &'static str) -> String { let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("test listener"); let addr = listener.local_addr().expect("test listener address"); @@ -1784,7 +1792,7 @@ mod tests { format!("http://{addr}/prove") } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] async fn run_remote_http_job( remote_http_results: Vec< Result<RadrootsSp1TradeRemoteProverResponse, TradeValidationReceiptJobError>, @@ -1801,7 +1809,7 @@ mod tests { .await } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] async fn run_remote_http_job_with_policy( policy: TradeValidationReceiptProverPolicy, remote_http_results: Vec< @@ -1810,6 +1818,31 @@ mod tests { remote_proof_verification_results: Vec<Result<(), TradeValidationReceiptJobError>>, publish_results: Vec<Result<String, TradeValidationReceiptJobError>>, ) -> Result<Vec<super::PublishedEventParts>, TradeValidationReceiptJobError> { + run_remote_http_job_with_policy_and_requests( + policy, + remote_http_results, + remote_proof_verification_results, + publish_results, + ) + .await + .map(|(published, _)| published) + } + + #[cfg(feature = "sp1_verify")] + async fn run_remote_http_job_with_policy_and_requests( + policy: TradeValidationReceiptProverPolicy, + remote_http_results: Vec< + Result<RadrootsSp1TradeRemoteProverResponse, TradeValidationReceiptJobError>, + >, + remote_proof_verification_results: Vec<Result<(), TradeValidationReceiptJobError>>, + publish_results: Vec<Result<String, TradeValidationReceiptJobError>>, + ) -> Result< + ( + Vec<super::PublishedEventParts>, + Vec<RadrootsSp1TradeRemoteProverRequest>, + ), + TradeValidationReceiptJobError, + > { let _guard = test_guard(); let worker = RadrootsNostrKeys::generate(); let requester = RadrootsNostrKeys::generate(); @@ -1847,11 +1880,13 @@ mod tests { handle_trade_validation_receipt_job_request(&job, &worker, &client_for(&worker), &policy) .await?; - Ok(trade_validation_receipt_test_hooks() + let hooks = trade_validation_receipt_test_hooks() .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .published_events - .clone()) + .unwrap_or_else(std::sync::PoisonError::into_inner); + Ok(( + hooks.published_events.clone(), + hooks.remote_http_requests.clone(), + )) } #[test] @@ -1965,10 +2000,11 @@ mod tests { )); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_publishes_only_after_remote_artifact_verification() { - let published = run_remote_http_job( + let (published, requests) = run_remote_http_job_with_policy_and_requests( + remote_http_policy(), vec![Ok(remote_response( RadrootsSp1TradeRemoteProverStatus::Completed, ))], @@ -1995,10 +2031,15 @@ mod tests { result.sp1_execute_public_values_hash.as_deref(), Some(result.public_values_hash.as_str()) ); + assert_eq!(requests.len(), 1); + assert_eq!( + requests[0].expected_public_values_hash, + result.public_values_hash + ); assert!(result.cryptographic_proof_verified); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_polls_running_until_completed() { let mut policy = remote_http_policy(); @@ -2036,7 +2077,7 @@ mod tests { assert!(result.cryptographic_proof_verified); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_accepts_same_origin_status_url() { let mut policy = remote_http_policy(); @@ -2064,7 +2105,7 @@ mod tests { assert_eq!(published.len(), 2); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_rejects_cross_origin_status_url() { let mut accepted = remote_response(RadrootsSp1TradeRemoteProverStatus::Accepted); @@ -2090,7 +2131,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_rejects_absolute_or_scheme_relative_status_path() { let mut absolute = remote_response(RadrootsSp1TradeRemoteProverStatus::Accepted); @@ -2131,7 +2172,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_rejects_polling_request_id_mismatch_before_next_poll() { let mut accepted = remote_response(RadrootsSp1TradeRemoteProverStatus::Accepted); @@ -2163,7 +2204,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_does_not_publish_when_verification_fails() { let error = run_remote_http_job( @@ -2193,7 +2234,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_does_not_publish_when_reference_digest_mismatches() { let mut response = remote_response(RadrootsSp1TradeRemoteProverStatus::Completed); @@ -2221,7 +2262,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_does_not_publish_when_sp1_identity_mismatches() { let mut response = remote_response(RadrootsSp1TradeRemoteProverStatus::Completed); @@ -2247,7 +2288,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_does_not_publish_when_public_values_mismatch() { let mut response = remote_response(RadrootsSp1TradeRemoteProverStatus::Completed); @@ -2273,7 +2314,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_does_not_publish_terminal_failed_or_rejected() { let mut failed = remote_response(RadrootsSp1TradeRemoteProverStatus::Failed); @@ -2327,7 +2368,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_does_not_publish_timeout_or_oversized_response() { let mut accepted = remote_response(RadrootsSp1TradeRemoteProverStatus::Accepted); @@ -2375,7 +2416,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_rejects_oversized_content_length_before_publish() { let endpoint = remote_http_local_response_url( @@ -2409,7 +2450,7 @@ mod tests { ); } - #[cfg(feature = "sp1_proving")] + #[cfg(feature = "sp1_verify")] #[tokio::test] async fn remote_http_prove_rejects_oversized_http_response_before_publish() { let endpoint = remote_http_local_response_url( diff --git a/src/lib.rs b/src/lib.rs @@ -7,6 +7,7 @@ pub mod features; pub mod identity_storage; pub mod paths; pub mod proof_smoke; +pub mod remote_prove; pub mod rhi; pub use cli::Args as cli_args; diff --git a/src/main.rs b/src/main.rs @@ -6,8 +6,10 @@ use anyhow::Result; #[cfg(not(test))] use clap::Parser; #[cfg(not(test))] -use rhi::proof_smoke; +use rhi::cli::Command; use rhi::{cli_args, config, paths, run_rhi}; +#[cfg(not(test))] +use rhi::{proof_smoke, remote_prove}; use std::path::PathBuf; use std::process::ExitCode; use tracing::info; @@ -210,7 +212,10 @@ async fn run() -> Result<()> { { let args = cli_args::try_parse().map_err(radroots_runtime::RuntimeCliError::from)?; if let Some(command) = args.command { - return proof_smoke::run_cli_command(command).await; + return match command { + Command::ProofSmoke { .. } => proof_smoke::run_cli_command(command).await, + Command::RemoteProve { .. } => remote_prove::run_cli_command(command).await, + }; } } diff --git a/src/proof_smoke.rs b/src/proof_smoke.rs @@ -86,7 +86,9 @@ pub enum RhiProofSmokeError { } pub async fn run_cli_command(command: Command) -> anyhow::Result<()> { - let Command::ProofSmoke { input, output } = command; + let Command::ProofSmoke { input, output } = command else { + return Err(anyhow::anyhow!("proof-smoke command expected")); + }; let request_bytes = read_input(input.as_deref())?; let response = handle_request_bytes(&request_bytes).await; let response_bytes = serde_json::to_vec_pretty(&response)?; @@ -288,7 +290,7 @@ fn capabilities() -> Vec<String> { values } -fn order_acceptance_tiny_witness() -> RadrootsSp1TradeOrderAcceptanceWitness { +pub(crate) fn order_acceptance_tiny_witness() -> RadrootsSp1TradeOrderAcceptanceWitness { RadrootsSp1TradeOrderAcceptanceWitness { witness_version: RADROOTS_SP1_TRADE_WITNESS_VERSION, proof_target: RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET.to_string(), diff --git a/src/remote_prove.rs b/src/remote_prove.rs @@ -0,0 +1,476 @@ +#![forbid(unsafe_code)] +#![cfg_attr(coverage_nightly, coverage(off))] + +use crate::cli::Command; +use radroots_sp1_guest_trade::{ + RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET, RADROOTS_SP1_TRADE_PROTOCOL_VERSION, + RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RADROOTS_SP1_TRADE_WITNESS_VERSION, +}; +use radroots_sp1_host_trade::{ + RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE, + RadrootsSp1TradeProofEngine, RadrootsSp1TradeProofMode, RadrootsSp1TradeRemoteProverRequest, + RadrootsSp1TradeRemoteProverResponse, RadrootsSp1TradeRemoteProverStatus, +}; +use std::path::Path; + +pub async fn run_cli_command(command: Command) -> anyhow::Result<()> { + let Command::RemoteProve { + input, + output, + proof_engine, + } = command + else { + return Err(anyhow::anyhow!("remote-prove command expected")); + }; + let engine = RadrootsSp1TradeProofEngine::from_label(proof_engine.as_str()) + .ok_or_else(|| anyhow::anyhow!("invalid proof engine"))?; + let request_bytes = read_input(input.as_deref())?; + let response = handle_request_bytes(&request_bytes, engine).await; + let response_bytes = serde_json::to_vec_pretty(&response)?; + write_output(output.as_deref(), &response_bytes)?; + if response.status == RadrootsSp1TradeRemoteProverStatus::Completed { + Ok(()) + } else { + Err(anyhow::anyhow!( + "{}", + response + .message + .as_deref() + .unwrap_or("remote proof request did not complete") + )) + } +} + +pub async fn handle_request_bytes( + bytes: &[u8], + engine: RadrootsSp1TradeProofEngine, +) -> RadrootsSp1TradeRemoteProverResponse { + let request_id = request_id_from_bytes(bytes); + let request = match serde_json::from_slice::<RadrootsSp1TradeRemoteProverRequest>(bytes) { + Ok(request) => request, + Err(error) => { + return rejected_response(request_id, "invalid_json", error.to_string()); + } + }; + match validate_request(&request) { + Ok(()) => prove_request(request, engine).await, + Err(rejection) => { + rejected_response(request.request_id, rejection.reason, rejection.message) + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct RemoteProveRejection { + reason: &'static str, + message: &'static str, +} + +fn validate_request( + request: &RadrootsSp1TradeRemoteProverRequest, +) -> Result<(), RemoteProveRejection> { + if request.schema_version != RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION { + return Err(rejection( + "invalid_schema_version", + "invalid schema_version", + )); + } + if request.request_id.trim().is_empty() { + return Err(rejection("invalid_request_id", "invalid request_id")); + } + if request.proof_target != RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET { + return Err(rejection( + "unsupported_proof_target", + "unsupported proof_target", + )); + } + if request.proof_mode != RadrootsSp1TradeProofMode::Core { + return Err(rejection( + "unsupported_proof_mode", + "unsupported proof_mode", + )); + } + if request.sp1_version_line != RADROOTS_SP1_TRADE_SP1_VERSION_LINE { + return Err(rejection( + "unsupported_sp1_version", + "unsupported sp1 version", + )); + } + if request.expected_reducer_program_hash != RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH { + return Err(rejection( + "reducer_program_hash_mismatch", + "expected reducer program hash mismatch", + )); + } + if request.expected_protocol_version != RADROOTS_SP1_TRADE_PROTOCOL_VERSION { + return Err(rejection( + "protocol_version_mismatch", + "expected protocol version mismatch", + )); + } + if request.expected_witness_version != RADROOTS_SP1_TRADE_WITNESS_VERSION { + return Err(rejection( + "witness_version_mismatch", + "expected witness version mismatch", + )); + } + if request.witness.witness_version != request.expected_witness_version { + return Err(rejection( + "witness_version_mismatch", + "witness version mismatch", + )); + } + if request.witness.proof_target != request.proof_target { + return Err(rejection( + "proof_target_mismatch", + "witness proof target mismatch", + )); + } + if request.witness.reducer_program_hash != request.expected_reducer_program_hash { + return Err(rejection( + "reducer_program_hash_mismatch", + "witness reducer program hash mismatch", + )); + } + if request.witness.radroots_protocol_version != request.expected_protocol_version { + return Err(rejection( + "protocol_version_mismatch", + "witness protocol version mismatch", + )); + } + if request.witness.sp1_program_hash.as_deref() + != Some(request.expected_sp1_program_hash.as_str()) + { + return Err(rejection( + "sp1_program_hash_mismatch", + "witness SP1 program hash mismatch", + )); + } + if request.witness.sp1_verifying_key_hash.as_deref() + != Some(request.expected_sp1_verifying_key_hash.as_str()) + { + return Err(rejection( + "sp1_verifying_key_hash_mismatch", + "witness SP1 verifying key hash mismatch", + )); + } + for value in [ + request.expected_sp1_program_hash.as_str(), + request.expected_sp1_verifying_key_hash.as_str(), + request.expected_public_values_hash.as_str(), + ] { + if !is_hash32(value) { + return Err(rejection("invalid_hash", "expected hash field is invalid")); + } + } + let execution = radroots_sp1_host_trade::execute_order_acceptance_public_values( + &request.witness, + ) + .map_err(|_| { + rejection( + "public_values_execution_failed", + "public values execution failed", + ) + })?; + if execution.public_values_hash != request.expected_public_values_hash { + return Err(rejection( + "public_values_hash_mismatch", + "expected public values hash mismatch", + )); + } + Ok(()) +} + +fn rejection(reason: &'static str, message: &'static str) -> RemoteProveRejection { + RemoteProveRejection { reason, message } +} + +#[cfg(feature = "sp1_proving")] +async fn prove_request( + request: RadrootsSp1TradeRemoteProverRequest, + engine: RadrootsSp1TradeProofEngine, +) -> RadrootsSp1TradeRemoteProverResponse { + match radroots_sp1_host_trade::generate_order_acceptance_sp1_proof_with_engine( + &request.witness, + request.proof_mode, + engine, + ) + .await + { + Ok(bundle) => completed_response(request, bundle), + Err(error) => failed_response( + request.request_id, + "proof_generation_failed", + error.to_string(), + ), + } +} + +#[cfg(not(feature = "sp1_proving"))] +async fn prove_request( + request: RadrootsSp1TradeRemoteProverRequest, + _engine: RadrootsSp1TradeProofEngine, +) -> RadrootsSp1TradeRemoteProverResponse { + failed_response( + request.request_id, + "proof_generation_unavailable", + "remote-prove requires the sp1_proving feature".to_string(), + ) +} + +#[cfg(feature = "sp1_proving")] +fn completed_response( + request: RadrootsSp1TradeRemoteProverRequest, + bundle: radroots_sp1_host_trade::RadrootsSp1TradeProofBundle, +) -> RadrootsSp1TradeRemoteProverResponse { + if bundle.execution.public_values_hash != request.expected_public_values_hash { + return failed_response( + request.request_id, + "public_values_hash_mismatch", + "generated public values hash mismatch".to_string(), + ); + } + if bundle.proof.program_hash.as_deref() != Some(request.expected_sp1_program_hash.as_str()) { + return failed_response( + request.request_id, + "sp1_program_hash_mismatch", + "generated SP1 program hash mismatch".to_string(), + ); + } + if bundle.proof.verifying_key_hash.as_deref() + != Some(request.expected_sp1_verifying_key_hash.as_str()) + { + return failed_response( + request.request_id, + "sp1_verifying_key_hash_mismatch", + "generated SP1 verifying key hash mismatch".to_string(), + ); + } + RadrootsSp1TradeRemoteProverResponse { + schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, + request_id: request.request_id, + status: RadrootsSp1TradeRemoteProverStatus::Completed, + status_url: None, + status_path: None, + proof_system: Some(request.proof_mode.proof_system()), + proof_mode: Some(request.proof_mode), + public_values_hash: Some(bundle.execution.public_values_hash), + sp1_program_hash: bundle.proof.program_hash.clone(), + sp1_verifying_key_hash: bundle.proof.verifying_key_hash.clone(), + proof_artifact: Some(bundle.proof), + resolved_proof_envelope_base64: None, + reason_code: None, + message: None, + detail: None, + } +} + +fn rejected_response( + request_id: String, + reason: impl Into<String>, + message: impl Into<String>, +) -> RadrootsSp1TradeRemoteProverResponse { + terminal_response( + request_id, + RadrootsSp1TradeRemoteProverStatus::Rejected, + reason, + message, + ) +} + +fn failed_response( + request_id: String, + reason: impl Into<String>, + message: impl Into<String>, +) -> RadrootsSp1TradeRemoteProverResponse { + terminal_response( + request_id, + RadrootsSp1TradeRemoteProverStatus::Failed, + reason, + message, + ) +} + +fn terminal_response( + request_id: String, + status: RadrootsSp1TradeRemoteProverStatus, + reason: impl Into<String>, + message: impl Into<String>, +) -> RadrootsSp1TradeRemoteProverResponse { + RadrootsSp1TradeRemoteProverResponse { + schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, + request_id, + status, + status_url: None, + status_path: None, + proof_system: None, + proof_mode: None, + public_values_hash: None, + sp1_program_hash: None, + sp1_verifying_key_hash: None, + proof_artifact: None, + resolved_proof_envelope_base64: None, + reason_code: Some(reason.into()), + message: Some(message.into()), + detail: None, + } +} + +fn request_id_from_bytes(bytes: &[u8]) -> String { + serde_json::from_slice::<serde_json::Value>(bytes) + .ok() + .and_then(|value| { + let request_id = value + .get("request_id") + .and_then(serde_json::Value::as_str) + .map(str::trim)?; + if request_id.is_empty() { + None + } else { + Some(request_id.to_owned()) + } + }) + .unwrap_or_else(|| "invalid-request".to_string()) +} + +fn is_hash32(value: &str) -> bool { + value.len() == 66 + && value.starts_with("0x") + && value[2..] + .bytes() + .all(|byte| byte.is_ascii_digit() || (b'a'..=b'f').contains(&byte)) +} + +fn read_input(input: Option<&Path>) -> anyhow::Result<Vec<u8>> { + match input { + Some(path) => Ok(std::fs::read(path)?), + None => { + use std::io::Read; + let mut bytes = Vec::new(); + std::io::stdin().read_to_end(&mut bytes)?; + Ok(bytes) + } + } +} + +fn write_output(output: Option<&Path>, bytes: &[u8]) -> anyhow::Result<()> { + match output { + Some(path) => { + std::fs::write(path, bytes)?; + Ok(()) + } + None => { + println!("{}", String::from_utf8_lossy(bytes)); + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::handle_request_bytes; + use radroots_sp1_guest_trade::{ + RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET, RADROOTS_SP1_TRADE_PROTOCOL_VERSION, + RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH, RADROOTS_SP1_TRADE_WITNESS_VERSION, + }; + use radroots_sp1_host_trade::{ + RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, RADROOTS_SP1_TRADE_SP1_VERSION_LINE, + RadrootsSp1TradeProofEngine, RadrootsSp1TradeProofMode, + RadrootsSp1TradeRemoteProverRequest, RadrootsSp1TradeRemoteProverStatus, + }; + + fn hash32(ch: char) -> String { + format!("0x{}", ch.to_string().repeat(64)) + } + + fn request() -> RadrootsSp1TradeRemoteProverRequest { + let mut witness = crate::proof_smoke::order_acceptance_tiny_witness(); + witness.sp1_program_hash = Some(hash32('a')); + witness.sp1_verifying_key_hash = Some(hash32('b')); + let execution = radroots_sp1_host_trade::execute_order_acceptance_public_values(&witness) + .expect("public values"); + RadrootsSp1TradeRemoteProverRequest { + schema_version: RADROOTS_SP1_TRADE_REMOTE_PROVER_SCHEMA_VERSION, + request_id: "request-1".to_string(), + proof_target: RADROOTS_SP1_TRADE_ORDER_ACCEPTANCE_PROOF_TARGET.to_string(), + proof_mode: RadrootsSp1TradeProofMode::Core, + sp1_version_line: RADROOTS_SP1_TRADE_SP1_VERSION_LINE.to_string(), + witness, + expected_sp1_program_hash: hash32('a'), + expected_sp1_verifying_key_hash: hash32('b'), + expected_public_values_hash: execution.public_values_hash, + expected_reducer_program_hash: RADROOTS_SP1_TRADE_REDUCER_PROGRAM_HASH.to_string(), + expected_protocol_version: RADROOTS_SP1_TRADE_PROTOCOL_VERSION.to_string(), + expected_witness_version: RADROOTS_SP1_TRADE_WITNESS_VERSION, + } + } + + #[tokio::test] + async fn remote_prove_rejects_unknown_provider_fields() { + let response = handle_request_bytes( + br#"{"schema_version":1,"request_id":"request-1","proof_target":"order_acceptance_v1","proof_mode":"core","sp1_version_line":"sp1-sdk-6.2.1","expected_sp1_program_hash":"0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa","expected_sp1_verifying_key_hash":"0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb","expected_public_values_hash":"0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc","expected_reducer_program_hash":"0xdddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd","expected_protocol_version":"radroots.sp1.trade.v1","expected_witness_version":1,"provider":"runpod"}"#, + RadrootsSp1TradeProofEngine::Cpu, + ) + .await; + assert_eq!(response.request_id, "request-1"); + assert_eq!( + response.status, + RadrootsSp1TradeRemoteProverStatus::Rejected + ); + assert_eq!(response.reason_code.as_deref(), Some("invalid_json")); + } + + #[tokio::test] + async fn remote_prove_rejects_expected_public_values_hash_mismatch() { + let mut request = request(); + request.expected_public_values_hash = hash32('c'); + let response = handle_request_bytes( + &serde_json::to_vec(&request).expect("request json"), + RadrootsSp1TradeProofEngine::Cpu, + ) + .await; + assert_eq!( + response.status, + RadrootsSp1TradeRemoteProverStatus::Rejected + ); + assert_eq!( + response.reason_code.as_deref(), + Some("public_values_hash_mismatch") + ); + } + + #[tokio::test] + async fn remote_prove_rejects_unsupported_modes_before_generation() { + let mut request = request(); + request.proof_mode = RadrootsSp1TradeProofMode::Compressed; + let response = handle_request_bytes( + &serde_json::to_vec(&request).expect("request json"), + RadrootsSp1TradeProofEngine::Cpu, + ) + .await; + assert_eq!( + response.status, + RadrootsSp1TradeRemoteProverStatus::Rejected + ); + assert_eq!( + response.reason_code.as_deref(), + Some("unsupported_proof_mode") + ); + } + + #[cfg(not(feature = "sp1_proving"))] + #[tokio::test] + async fn remote_prove_reports_generation_unavailable_without_proving_feature() { + let request = request(); + let response = handle_request_bytes( + &serde_json::to_vec(&request).expect("request json"), + RadrootsSp1TradeProofEngine::Cpu, + ) + .await; + assert_eq!(response.status, RadrootsSp1TradeRemoteProverStatus::Failed); + assert_eq!( + response.reason_code.as_deref(), + Some("proof_generation_unavailable") + ); + } +}