rhi

Coordinated trade for connected markets
git clone https://radroots.dev/git/rhi.git
Log | Files | Refs | README | LICENSE

state.rs (18803B)


      1 #![forbid(unsafe_code)]
      2 
      3 use std::collections::{HashMap, HashSet};
      4 use std::path::{Path, PathBuf};
      5 use std::sync::Arc;
      6 
      7 use radroots_nostr::prelude::{RadrootsNostrFilter, RadrootsNostrKind, RadrootsNostrTimestamp};
      8 use serde::{Deserialize, Serialize};
      9 use thiserror::Error;
     10 use tokio::sync::Mutex;
     11 
/// Reference-counted handle to a [`TradeListingState`] guarded by a tokio async mutex.
pub type SharedTradeListingState = Arc<Mutex<TradeListingState>>;

/// Snapshot format version written into every persisted snapshot and required
/// when loading one; a snapshot carrying any other version is rejected.
const TRADE_LISTING_STATE_VERSION: u32 = 1;
     15 
/// Status values a tracked trade order can hold.
///
/// Stored in `TradeOrderState::status` and reported by
/// `TradeListingStateError::InvalidTransition`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TradeOrderStatus {
    Requested,
    Accepted,
    Declined,
    Cancelled,
    Completed,
    Disputed,
    Invalid,
}
     26 
/// Serializable record kept for one trade order.
///
/// Records are stored in `TradeListingState` keyed by `order_id`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TradeOrderState {
    /// Order identifier; also the key under which the record is stored.
    pub order_id: String,
    // NOTE(review): presumably the address of the listing the order targets — confirm with callers.
    pub listing_addr: String,
    pub buyer_pubkey: String,
    pub seller_pubkey: String,
    /// Current status of the order.
    pub status: TradeOrderStatus,
    /// Optional; deserializes as `None` when absent from the snapshot.
    #[serde(default)]
    pub listing_snapshot_event_id: Option<String>,
    /// Optional; deserializes as `None` when absent from the snapshot.
    #[serde(default)]
    pub root_event_id: Option<String>,
    /// Optional; deserializes as `None` when absent from the snapshot.
    #[serde(default)]
    pub last_event_id: Option<String>,
    /// Event ids recorded for this order by `TradeListingState::mark_event_seen`.
    pub seen_event_ids: HashSet<String>,
}
     42 
/// Validation record stored per listing address: the id of the event that
/// was passed to `TradeListingState::mark_listing_validated`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ValidatedListingState {
    pub event_id: String,
}
     47 
/// Event id and numeric kind stored per listing address by
/// `TradeListingState::upsert_listing_event`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ListingEventState {
    pub event_id: String,
    pub kind: u32,
}
     53 
/// Serializable bookkeeping for listings, orders, seen events and the replay anchor.
///
/// Fields marked `#[serde(default)]` fall back to an empty collection when
/// they are missing from a persisted snapshot.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TradeListingState {
    /// Listing addresses, updated together with `validated_listing_events`.
    /// `is_listing_validated` does not consult this set; only an entry in
    /// `validated_listing_events` makes a listing count as validated.
    #[serde(default)]
    validated_listings: HashSet<String>,
    /// Listing address -> validation record.
    #[serde(default)]
    validated_listing_events: HashMap<String, ValidatedListingState>,
    /// Listing address -> most recently upserted listing event.
    #[serde(default)]
    listing_events: HashMap<String, ListingEventState>,
    /// Ids of events recorded through `mark_non_order_event_seen`.
    #[serde(default)]
    seen_non_order_event_ids: HashSet<String>,
    /// Order id -> order record.
    orders: HashMap<String, TradeOrderState>,
    /// Largest `created_at` value observed so far, if any.
    last_event_created_at: Option<u32>,
}
     67 
/// Bundles the shared trade-listing state with its configuration and an
/// optional persistence backend.
#[derive(Clone, Debug)]
pub struct TradeListingRuntime {
    /// Shared, mutex-guarded state; handed out by `state()`.
    state: SharedTradeListingState,
    /// Replay settings and the snapshot path.
    config: TradeListingRuntimeConfig,
    /// `None` for a purely in-memory runtime, in which case `persist()` does nothing.
    persistence: Option<Arc<TradeListingStatePersistence>>,
}
     74 
/// Configuration for a [`TradeListingRuntime`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TradeListingRuntimeConfig {
    /// File the state snapshot is loaded from and persisted to.
    pub state_path: PathBuf,
    /// How far before "now" (in seconds) recovery starts when no event has been observed yet.
    pub replay_window_secs: u64,
    /// Seconds subtracted from the last observed `created_at` when recovering.
    pub replay_overlap_secs: u64,
}
     81 
/// File-backed storage for state snapshots.
#[derive(Clone, Debug)]
struct TradeListingStatePersistence {
    /// Location of the snapshot file.
    path: PathBuf,
}
     86 
/// On-disk envelope: the state plus the format version it was written with.
#[derive(Debug, Serialize, Deserialize)]
struct PersistedTradeListingState {
    /// Compared against `TRADE_LISTING_STATE_VERSION` when loading.
    version: u32,
    state: TradeListingState,
}
     92 
     93 impl Default for TradeListingRuntimeConfig {
     94     fn default() -> Self {
     95         Self {
     96             state_path: crate::paths::default_subscriber_state_path_for_process()
     97                 .expect("resolve canonical rhi trade-listing state path"),
     98             replay_window_secs: 24 * 60 * 60,
     99             replay_overlap_secs: 5 * 60,
    100         }
    101     }
    102 }
    103 
    104 impl Default for TradeListingRuntime {
    105     fn default() -> Self {
    106         Self {
    107             state: Arc::new(Mutex::new(TradeListingState::default())),
    108             config: TradeListingRuntimeConfig::default(),
    109             persistence: None,
    110         }
    111     }
    112 }
    113 
    114 impl TradeListingRuntime {
    115     pub fn new() -> Self {
    116         Self::default()
    117     }
    118 
    119     pub async fn load(config: TradeListingRuntimeConfig) -> Result<Self, TradeListingRuntimeError> {
    120         let persistence = Arc::new(TradeListingStatePersistence::new(config.state_path.clone()));
    121         let state = persistence.load().await?;
    122         Ok(Self {
    123             state: Arc::new(Mutex::new(state)),
    124             config,
    125             persistence: Some(persistence),
    126         })
    127     }
    128 
    129     pub fn state(&self) -> SharedTradeListingState {
    130         Arc::clone(&self.state)
    131     }
    132 
    133     pub async fn persist(&self) -> Result<(), TradeListingRuntimeError> {
    134         let Some(persistence) = &self.persistence else {
    135             return Ok(());
    136         };
    137         let snapshot = self.state.lock().await.clone();
    138         persistence.persist(&snapshot).await
    139     }
    140 
    141     pub async fn mark_processed_event(
    142         &self,
    143         created_at: u32,
    144     ) -> Result<(), TradeListingRuntimeError> {
    145         {
    146             let mut state = self.state.lock().await;
    147             state.observe_event_created_at(created_at);
    148         }
    149         self.persist().await
    150     }
    151 
    152     pub async fn recovery_filter(&self, kinds: Vec<RadrootsNostrKind>) -> RadrootsNostrFilter {
    153         let since = {
    154             let state = self.state.lock().await;
    155             state.replay_since(
    156                 RadrootsNostrTimestamp::now().as_secs(),
    157                 self.config.replay_window_secs,
    158                 self.config.replay_overlap_secs,
    159             )
    160         };
    161         RadrootsNostrFilter::new()
    162             .kinds(kinds)
    163             .since(RadrootsNostrTimestamp::from(since))
    164     }
    165 }
    166 
    167 impl TradeListingState {
    168     pub fn upsert_listing_event(&mut self, listing_addr: &str, event_id: &str, kind: u32) {
    169         self.listing_events.insert(
    170             listing_addr.to_string(),
    171             ListingEventState {
    172                 event_id: event_id.to_string(),
    173                 kind,
    174             },
    175         );
    176     }
    177 
    178     pub fn listing_event_id(&self, listing_addr: &str) -> Option<&str> {
    179         self.listing_events
    180             .get(listing_addr)
    181             .map(|listing| listing.event_id.as_str())
    182     }
    183 
    184     pub fn mark_listing_validated(&mut self, listing_addr: &str, event_id: &str) {
    185         self.validated_listings.insert(listing_addr.to_string());
    186         self.validated_listing_events.insert(
    187             listing_addr.to_string(),
    188             ValidatedListingState {
    189                 event_id: event_id.to_string(),
    190             },
    191         );
    192     }
    193 
    194     pub fn clear_listing_validation(&mut self, listing_addr: &str) {
    195         self.validated_listings.remove(listing_addr);
    196         self.validated_listing_events.remove(listing_addr);
    197     }
    198 
    199     pub fn validated_listing_event_id(&self, listing_addr: &str) -> Option<&str> {
    200         self.validated_listing_events
    201             .get(listing_addr)
    202             .map(|validated| validated.event_id.as_str())
    203     }
    204 
    205     pub fn is_listing_validated(&self, listing_addr: &str) -> bool {
    206         self.validated_listing_event_id(listing_addr).is_some()
    207     }
    208 
    209     pub fn order_exists(&self, order_id: &str) -> bool {
    210         self.orders.contains_key(order_id)
    211     }
    212 
    213     pub fn get_order_mut(&mut self, order_id: &str) -> Option<&mut TradeOrderState> {
    214         self.orders.get_mut(order_id)
    215     }
    216 
    217     pub fn insert_order(&mut self, order: TradeOrderState) {
    218         self.orders.insert(order.order_id.clone(), order);
    219     }
    220 
    221     pub fn mark_event_seen(&mut self, order_id: &str, event_id: &str) -> bool {
    222         if let Some(state) = self.orders.get_mut(order_id) {
    223             state.seen_event_ids.insert(event_id.to_string())
    224         } else {
    225             false
    226         }
    227     }
    228 
    229     pub fn is_event_seen(&self, order_id: &str, event_id: &str) -> bool {
    230         self.orders
    231             .get(order_id)
    232             .map(|state| state.seen_event_ids.contains(event_id))
    233             .unwrap_or(false)
    234     }
    235 
    236     pub fn mark_non_order_event_seen(&mut self, event_id: &str) -> bool {
    237         self.seen_non_order_event_ids.insert(event_id.to_string())
    238     }
    239 
    240     pub fn is_non_order_event_seen(&self, event_id: &str) -> bool {
    241         self.seen_non_order_event_ids.contains(event_id)
    242     }
    243 
    244     pub fn observe_event_created_at(&mut self, created_at: u32) {
    245         self.last_event_created_at = Some(
    246             self.last_event_created_at
    247                 .map_or(created_at, |current| current.max(created_at)),
    248         );
    249     }
    250 
    251     pub fn last_event_created_at(&self) -> Option<u32> {
    252         self.last_event_created_at
    253     }
    254 
    255     pub fn replay_since(
    256         &self,
    257         now_secs: u64,
    258         replay_window_secs: u64,
    259         replay_overlap_secs: u64,
    260     ) -> u64 {
    261         match self.last_event_created_at {
    262             Some(last) => u64::from(last).saturating_sub(replay_overlap_secs),
    263             None => now_secs.saturating_sub(replay_window_secs),
    264         }
    265     }
    266 }
    267 
    268 impl TradeListingStatePersistence {
    269     fn new(path: PathBuf) -> Self {
    270         Self { path }
    271     }
    272 
    273     async fn load(&self) -> Result<TradeListingState, TradeListingRuntimeError> {
    274         if !tokio::fs::try_exists(&self.path).await? {
    275             return Ok(TradeListingState::default());
    276         }
    277 
    278         let payload = tokio::fs::read_to_string(&self.path).await?;
    279         let snapshot: PersistedTradeListingState = serde_json::from_str(&payload)?;
    280         if snapshot.version != TRADE_LISTING_STATE_VERSION {
    281             return Err(TradeListingRuntimeError::UnsupportedStateVersion(
    282                 snapshot.version,
    283             ));
    284         }
    285         Ok(snapshot.state)
    286     }
    287 
    288     async fn persist(&self, state: &TradeListingState) -> Result<(), TradeListingRuntimeError> {
    289         if let Some(parent) = self.path.parent() {
    290             if !parent.as_os_str().is_empty() {
    291                 tokio::fs::create_dir_all(parent).await?;
    292             }
    293         }
    294 
    295         let snapshot = PersistedTradeListingState {
    296             version: TRADE_LISTING_STATE_VERSION,
    297             state: state.clone(),
    298         };
    299         let payload = serde_json::to_vec_pretty(&snapshot)?;
    300         let temp_path = temp_state_path(&self.path)?;
    301         tokio::fs::write(&temp_path, payload).await?;
    302         tokio::fs::rename(&temp_path, &self.path).await?;
    303         Ok(())
    304     }
    305 }
    306 
    307 fn temp_state_path(path: &Path) -> Result<PathBuf, TradeListingRuntimeError> {
    308     let file_name = path
    309         .file_name()
    310         .ok_or_else(|| TradeListingRuntimeError::InvalidStatePath(path.to_path_buf()))?;
    311     Ok(path.with_file_name(format!("{}.tmp", file_name.to_string_lossy())))
    312 }
    313 
/// Errors describing invalid operations on order state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TradeListingStateError {
    /// Displayed as "missing order state".
    MissingOrder,
    /// A status change from `from` to `to` that is not permitted.
    InvalidTransition {
        from: TradeOrderStatus,
        to: TradeOrderStatus,
    },
}
    322 
    323 impl core::fmt::Display for TradeListingStateError {
    324     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
    325         match self {
    326             TradeListingStateError::MissingOrder => write!(f, "missing order state"),
    327             TradeListingStateError::InvalidTransition { from, to } => {
    328                 write!(f, "invalid order transition: {from:?} -> {to:?}")
    329             }
    330         }
    331     }
    332 }
    333 
// Marker impl: the default `std::error::Error` methods are used, so no source error is reported.
impl std::error::Error for TradeListingStateError {}
    335 
/// Errors raised while loading or persisting the trade-listing state.
#[derive(Debug, Error)]
pub enum TradeListingRuntimeError {
    /// The state path has no file-name component, so no temporary path can be derived from it.
    #[error("invalid trade listing state path: {0}")]
    InvalidStatePath(PathBuf),
    /// The persisted snapshot carries a version other than the supported one.
    #[error("unsupported trade listing state version: {0}")]
    UnsupportedStateVersion(u32),
    /// Filesystem failure; converted automatically from `std::io::Error`.
    #[error("trade listing state io error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON (de)serialization failure; converted automatically from `serde_json::Error`.
    #[error("trade listing state json error: {0}")]
    Json(#[from] serde_json::Error),
}
    347 
// Unit tests for the in-memory state bookkeeping and the file-backed runtime.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::{
        ListingEventState, PersistedTradeListingState, TradeListingRuntime,
        TradeListingRuntimeConfig, TradeListingRuntimeError, TradeListingState,
        TradeListingStateError, TradeOrderState, TradeOrderStatus, ValidatedListingState,
    };
    use std::collections::{HashMap, HashSet};

    // Builds a path in the system temp directory made unique by the given
    // suffix plus the current time in nanoseconds.
    fn unique_state_path(suffix: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time")
            .as_nanos();
        std::env::temp_dir().join(format!("rhi-trade-state-{suffix}-{nanos}.json"))
    }

    // Exercises validation, order/event tracking, listing events and the replay anchor.
    #[test]
    fn state_tracks_listings_events_and_replay_anchor() {
        let mut state = TradeListingState::default();
        assert!(!state.is_listing_validated("addr"));
        state.mark_listing_validated("addr", "evt-listing-1");
        assert!(state.is_listing_validated("addr"));
        assert_eq!(
            state.validated_listing_event_id("addr"),
            Some("evt-listing-1")
        );

        let order = TradeOrderState {
            order_id: "order-1".into(),
            listing_addr: "addr".into(),
            buyer_pubkey: "buyer".into(),
            seller_pubkey: "seller".into(),
            status: TradeOrderStatus::Requested,
            listing_snapshot_event_id: Some("evt-listing-1".into()),
            root_event_id: Some("evt-root-1".into()),
            last_event_id: Some("evt-root-1".into()),
            seen_event_ids: Default::default(),
        };
        state.insert_order(order);
        assert!(!state.is_event_seen("order-1", "evt"));
        assert!(state.mark_event_seen("order-1", "evt"));
        assert!(state.is_event_seen("order-1", "evt"));
        assert!(!state.is_non_order_event_seen("evt-non-order"));
        assert!(state.mark_non_order_event_seen("evt-non-order"));
        assert!(state.is_non_order_event_seen("evt-non-order"));
        state.upsert_listing_event("addr", "evt-listing-1", 30402);
        assert_eq!(state.listing_event_id("addr"), Some("evt-listing-1"));
        // No event observed yet: now (1000) minus the replay window (300).
        assert_eq!(state.replay_since(1_000, 300, 60), 700);

        state.observe_event_created_at(900);
        assert_eq!(state.last_event_created_at(), Some(900));
        // Anchor present: last created_at (900) minus the overlap (60).
        assert_eq!(state.replay_since(1_000, 300, 60), 840);
    }

    // Lookups for an unknown order report "absent", and error messages render as expected.
    #[test]
    fn state_covers_missing_order_paths_and_error_display() {
        let mut state = TradeListingState::default();
        assert!(!state.order_exists("missing"));
        assert!(state.get_order_mut("missing").is_none());
        assert!(!state.mark_event_seen("missing", "evt-1"));
        assert!(!state.is_event_seen("missing", "evt-1"));
        assert!(!state.is_non_order_event_seen("evt-2"));

        assert_eq!(
            TradeListingStateError::MissingOrder.to_string(),
            "missing order state"
        );

        let invalid = TradeListingStateError::InvalidTransition {
            from: TradeOrderStatus::Requested,
            to: TradeOrderStatus::Completed,
        };
        assert_eq!(
            invalid.to_string(),
            "invalid order transition: Requested -> Completed"
        );
    }

    // Handles returned by `state()` refer to the same underlying state.
    #[tokio::test]
    async fn runtime_reuses_shared_trade_listing_state() {
        let runtime = TradeListingRuntime::new();
        let state = runtime.state();
        state
            .lock()
            .await
            .mark_listing_validated("addr", "evt-listing-1");

        assert!(runtime.state().lock().await.is_listing_validated("addr"));
    }

    // State written by `persist()` is read back by a fresh `load()` of the same path.
    #[tokio::test]
    async fn runtime_persists_and_loads_trade_listing_state() {
        let path = unique_state_path("roundtrip");
        let config = TradeListingRuntimeConfig {
            state_path: path.clone(),
            replay_window_secs: 600,
            replay_overlap_secs: 30,
        };
        let runtime = TradeListingRuntime::load(config.clone())
            .await
            .expect("runtime");

        {
            let state_handle = runtime.state();
            let mut state = state_handle.lock().await;
            state.mark_listing_validated("addr", "evt-listing-1");
            state.mark_non_order_event_seen("evt-validate-1");
            state.observe_event_created_at(456);
        }
        runtime.persist().await.expect("persist");

        let loaded = TradeListingRuntime::load(config).await.expect("load");
        let loaded_state_handle = loaded.state();
        let loaded_state = loaded_state_handle.lock().await;
        assert!(loaded_state.is_listing_validated("addr"));
        assert_eq!(
            loaded_state.validated_listing_event_id("addr"),
            Some("evt-listing-1")
        );
        assert!(loaded_state.is_non_order_event_seen("evt-validate-1"));
        assert_eq!(loaded_state.last_event_created_at(), Some(456));

        // Best-effort cleanup of the temp snapshot.
        let _ = tokio::fs::remove_file(path).await;
    }

    // A snapshot written with an unknown version must be rejected on load.
    #[tokio::test]
    async fn runtime_load_rejects_unsupported_snapshot_version() {
        let path = unique_state_path("version");
        let payload = PersistedTradeListingState {
            version: 99,
            state: TradeListingState::default(),
        };
        tokio::fs::write(&path, serde_json::to_vec(&payload).expect("payload"))
            .await
            .expect("write");

        let err = TradeListingRuntime::load(TradeListingRuntimeConfig {
            state_path: path.clone(),
            replay_window_secs: 600,
            replay_overlap_secs: 30,
        })
        .await
        .expect_err("unsupported snapshot should fail");
        assert!(matches!(
            err,
            TradeListingRuntimeError::UnsupportedStateVersion(99)
        ));

        // Best-effort cleanup of the temp snapshot.
        let _ = tokio::fs::remove_file(path).await;
    }

    // Only the `validated_listings` set is populated here, with no matching
    // `validated_listing_events` entry, so the address must not count as validated.
    #[tokio::test]
    async fn runtime_loads_legacy_validation_state_without_trusting_it() {
        let path = unique_state_path("legacy-validation");
        let payload = PersistedTradeListingState {
            version: 1,
            state: TradeListingState {
                validated_listings: ["addr".to_string()].into_iter().collect(),
                validated_listing_events: HashMap::new(),
                listing_events: HashMap::new(),
                seen_non_order_event_ids: HashSet::new(),
                orders: HashMap::new(),
                last_event_created_at: Some(321),
            },
        };
        tokio::fs::write(&path, serde_json::to_vec(&payload).expect("payload"))
            .await
            .expect("write");

        let loaded = TradeListingRuntime::load(TradeListingRuntimeConfig {
            state_path: path.clone(),
            replay_window_secs: 600,
            replay_overlap_secs: 30,
        })
        .await
        .expect("load");
        let loaded_state_handle = loaded.state();
        let loaded_state = loaded_state_handle.lock().await;
        assert!(!loaded_state.is_listing_validated("addr"));
        assert_eq!(loaded_state.validated_listing_event_id("addr"), None);
        assert_eq!(loaded_state.last_event_created_at(), Some(321));

        // Best-effort cleanup of the temp snapshot.
        let _ = tokio::fs::remove_file(path).await;
    }

    // Clearing validation removes the record so the listing no longer counts as validated.
    #[test]
    fn state_can_clear_listing_validation() {
        let mut state = TradeListingState {
            validated_listings: ["addr".to_string()].into_iter().collect(),
            validated_listing_events: HashMap::from([(
                "addr".to_string(),
                ValidatedListingState {
                    event_id: "evt-listing-1".to_string(),
                },
            )]),
            listing_events: HashMap::from([(
                "addr".to_string(),
                ListingEventState {
                    event_id: "evt-listing-1".to_string(),
                    kind: 30402,
                },
            )]),
            seen_non_order_event_ids: HashSet::new(),
            orders: HashMap::new(),
            last_event_created_at: None,
        };
        assert!(state.is_listing_validated("addr"));
        state.clear_listing_validation("addr");
        assert!(!state.is_listing_validated("addr"));
        assert_eq!(state.validated_listing_event_id("addr"), None);
    }
}
    561 }