state.rs (18803B)
1 #![forbid(unsafe_code)] 2 3 use std::collections::{HashMap, HashSet}; 4 use std::path::{Path, PathBuf}; 5 use std::sync::Arc; 6 7 use radroots_nostr::prelude::{RadrootsNostrFilter, RadrootsNostrKind, RadrootsNostrTimestamp}; 8 use serde::{Deserialize, Serialize}; 9 use thiserror::Error; 10 use tokio::sync::Mutex; 11 12 pub type SharedTradeListingState = Arc<Mutex<TradeListingState>>; 13 14 const TRADE_LISTING_STATE_VERSION: u32 = 1; 15 16 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] 17 pub enum TradeOrderStatus { 18 Requested, 19 Accepted, 20 Declined, 21 Cancelled, 22 Completed, 23 Disputed, 24 Invalid, 25 } 26 27 #[derive(Clone, Debug, Serialize, Deserialize)] 28 pub struct TradeOrderState { 29 pub order_id: String, 30 pub listing_addr: String, 31 pub buyer_pubkey: String, 32 pub seller_pubkey: String, 33 pub status: TradeOrderStatus, 34 #[serde(default)] 35 pub listing_snapshot_event_id: Option<String>, 36 #[serde(default)] 37 pub root_event_id: Option<String>, 38 #[serde(default)] 39 pub last_event_id: Option<String>, 40 pub seen_event_ids: HashSet<String>, 41 } 42 43 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 44 pub struct ValidatedListingState { 45 pub event_id: String, 46 } 47 48 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 49 pub struct ListingEventState { 50 pub event_id: String, 51 pub kind: u32, 52 } 53 54 #[derive(Debug, Default, Clone, Serialize, Deserialize)] 55 pub struct TradeListingState { 56 #[serde(default)] 57 validated_listings: HashSet<String>, 58 #[serde(default)] 59 validated_listing_events: HashMap<String, ValidatedListingState>, 60 #[serde(default)] 61 listing_events: HashMap<String, ListingEventState>, 62 #[serde(default)] 63 seen_non_order_event_ids: HashSet<String>, 64 orders: HashMap<String, TradeOrderState>, 65 last_event_created_at: Option<u32>, 66 } 67 68 #[derive(Clone, Debug)] 69 pub struct TradeListingRuntime { 70 state: SharedTradeListingState, 71 config: TradeListingRuntimeConfig, 72 persistence: Option<Arc<TradeListingStatePersistence>>, 73 } 74 75 #[derive(Clone, Debug, Serialize, Deserialize)] 76 pub struct TradeListingRuntimeConfig { 77 pub state_path: PathBuf, 78 pub replay_window_secs: u64, 79 pub replay_overlap_secs: u64, 80 } 81 82 #[derive(Clone, Debug)] 83 struct TradeListingStatePersistence { 84 path: PathBuf, 85 } 86 87 #[derive(Debug, Serialize, Deserialize)] 88 struct PersistedTradeListingState { 89 version: u32, 90 state: TradeListingState, 91 } 92 93 impl Default for TradeListingRuntimeConfig { 94 fn default() -> Self { 95 Self { 96 state_path: crate::paths::default_subscriber_state_path_for_process() 97 .expect("resolve canonical rhi trade-listing state path"), 98 replay_window_secs: 24 * 60 * 60, 99 replay_overlap_secs: 5 * 60, 100 } 101 } 102 } 103 104 impl Default for TradeListingRuntime { 105 fn default() -> Self { 106 Self { 107 state: Arc::new(Mutex::new(TradeListingState::default())), 108 config: TradeListingRuntimeConfig::default(), 109 persistence: None, 110 } 111 } 112 } 113 114 impl TradeListingRuntime { 115 pub fn new() -> Self { 116 Self::default() 117 } 118 119 pub async fn load(config: TradeListingRuntimeConfig) -> Result<Self, TradeListingRuntimeError> { 120 let persistence = Arc::new(TradeListingStatePersistence::new(config.state_path.clone())); 121 let state = persistence.load().await?; 122 Ok(Self { 123 state: Arc::new(Mutex::new(state)), 124 config, 125 persistence: Some(persistence), 126 }) 127 } 128 129 pub fn state(&self) -> SharedTradeListingState { 130 Arc::clone(&self.state) 131 } 132 133 pub async fn persist(&self) -> Result<(), TradeListingRuntimeError> { 134 let Some(persistence) = &self.persistence else { 135 return Ok(()); 136 }; 137 let snapshot = self.state.lock().await.clone(); 138 persistence.persist(&snapshot).await 139 } 140 141 pub async fn mark_processed_event( 142 &self, 143 created_at: u32, 144 ) -> Result<(), TradeListingRuntimeError> { 145 { 146 let mut state = self.state.lock().await; 147 state.observe_event_created_at(created_at); 148 } 149 self.persist().await 150 } 151 152 pub async fn recovery_filter(&self, kinds: Vec<RadrootsNostrKind>) -> RadrootsNostrFilter { 153 let since = { 154 let state = self.state.lock().await; 155 state.replay_since( 156 RadrootsNostrTimestamp::now().as_secs(), 157 self.config.replay_window_secs, 158 self.config.replay_overlap_secs, 159 ) 160 }; 161 RadrootsNostrFilter::new() 162 .kinds(kinds) 163 .since(RadrootsNostrTimestamp::from(since)) 164 } 165 } 166 167 impl TradeListingState { 168 pub fn upsert_listing_event(&mut self, listing_addr: &str, event_id: &str, kind: u32) { 169 self.listing_events.insert( 170 listing_addr.to_string(), 171 ListingEventState { 172 event_id: event_id.to_string(), 173 kind, 174 }, 175 ); 176 } 177 178 pub fn listing_event_id(&self, listing_addr: &str) -> Option<&str> { 179 self.listing_events 180 .get(listing_addr) 181 .map(|listing| listing.event_id.as_str()) 182 } 183 184 pub fn mark_listing_validated(&mut self, listing_addr: &str, event_id: &str) { 185 self.validated_listings.insert(listing_addr.to_string()); 186 self.validated_listing_events.insert( 187 listing_addr.to_string(), 188 ValidatedListingState { 189 event_id: event_id.to_string(), 190 }, 191 ); 192 } 193 194 pub fn clear_listing_validation(&mut self, listing_addr: &str) { 195 self.validated_listings.remove(listing_addr); 196 self.validated_listing_events.remove(listing_addr); 197 } 198 199 pub fn validated_listing_event_id(&self, listing_addr: &str) -> Option<&str> { 200 self.validated_listing_events 201 .get(listing_addr) 202 .map(|validated| validated.event_id.as_str()) 203 } 204 205 pub fn is_listing_validated(&self, listing_addr: &str) -> bool { 206 self.validated_listing_event_id(listing_addr).is_some() 207 } 208 209 pub fn order_exists(&self, order_id: &str) -> bool { 210 self.orders.contains_key(order_id) 211 } 212 213 pub fn get_order_mut(&mut self, order_id: &str) -> Option<&mut TradeOrderState> { 214 self.orders.get_mut(order_id) 215 } 216 217 pub fn insert_order(&mut self, order: TradeOrderState) { 218 self.orders.insert(order.order_id.clone(), order); 219 } 220 221 pub fn mark_event_seen(&mut self, order_id: &str, event_id: &str) -> bool { 222 if let Some(state) = self.orders.get_mut(order_id) { 223 state.seen_event_ids.insert(event_id.to_string()) 224 } else { 225 false 226 } 227 } 228 229 pub fn is_event_seen(&self, order_id: &str, event_id: &str) -> bool { 230 self.orders 231 .get(order_id) 232 .map(|state| state.seen_event_ids.contains(event_id)) 233 .unwrap_or(false) 234 } 235 236 pub fn mark_non_order_event_seen(&mut self, event_id: &str) -> bool { 237 self.seen_non_order_event_ids.insert(event_id.to_string()) 238 } 239 240 pub fn is_non_order_event_seen(&self, event_id: &str) -> bool { 241 self.seen_non_order_event_ids.contains(event_id) 242 } 243 244 pub fn observe_event_created_at(&mut self, created_at: u32) { 245 self.last_event_created_at = Some( 246 self.last_event_created_at 247 .map_or(created_at, |current| current.max(created_at)), 248 ); 249 } 250 251 pub fn last_event_created_at(&self) -> Option<u32> { 252 self.last_event_created_at 253 } 254 255 pub fn replay_since( 256 &self, 257 now_secs: u64, 258 replay_window_secs: u64, 259 replay_overlap_secs: u64, 260 ) -> u64 { 261 match self.last_event_created_at { 262 Some(last) => u64::from(last).saturating_sub(replay_overlap_secs), 263 None => now_secs.saturating_sub(replay_window_secs), 264 } 265 } 266 } 267 268 impl TradeListingStatePersistence { 269 fn new(path: PathBuf) -> Self { 270 Self { path } 271 } 272 273 async fn load(&self) -> Result<TradeListingState, TradeListingRuntimeError> { 274 if !tokio::fs::try_exists(&self.path).await? { 275 return Ok(TradeListingState::default()); 276 } 277 278 let payload = tokio::fs::read_to_string(&self.path).await?; 279 let snapshot: PersistedTradeListingState = serde_json::from_str(&payload)?; 280 if snapshot.version != TRADE_LISTING_STATE_VERSION { 281 return Err(TradeListingRuntimeError::UnsupportedStateVersion( 282 snapshot.version, 283 )); 284 } 285 Ok(snapshot.state) 286 } 287 288 async fn persist(&self, state: &TradeListingState) -> Result<(), TradeListingRuntimeError> { 289 if let Some(parent) = self.path.parent() { 290 if !parent.as_os_str().is_empty() { 291 tokio::fs::create_dir_all(parent).await?; 292 } 293 } 294 295 let snapshot = PersistedTradeListingState { 296 version: TRADE_LISTING_STATE_VERSION, 297 state: state.clone(), 298 }; 299 let payload = serde_json::to_vec_pretty(&snapshot)?; 300 let temp_path = temp_state_path(&self.path)?; 301 tokio::fs::write(&temp_path, payload).await?; 302 tokio::fs::rename(&temp_path, &self.path).await?; 303 Ok(()) 304 } 305 } 306 307 fn temp_state_path(path: &Path) -> Result<PathBuf, TradeListingRuntimeError> { 308 let file_name = path 309 .file_name() 310 .ok_or_else(|| TradeListingRuntimeError::InvalidStatePath(path.to_path_buf()))?; 311 Ok(path.with_file_name(format!("{}.tmp", file_name.to_string_lossy()))) 312 } 313 314 #[derive(Debug, Clone, PartialEq, Eq)] 315 pub enum TradeListingStateError { 316 MissingOrder, 317 InvalidTransition { 318 from: TradeOrderStatus, 319 to: TradeOrderStatus, 320 }, 321 } 322 323 impl core::fmt::Display for TradeListingStateError { 324 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 325 match self { 326 TradeListingStateError::MissingOrder => write!(f, "missing order state"), 327 TradeListingStateError::InvalidTransition { from, to } => { 328 write!(f, "invalid order transition: {from:?} -> {to:?}") 329 } 330 } 331 } 332 } 333 334 impl std::error::Error for TradeListingStateError {} 335 336 #[derive(Debug, Error)] 337 pub enum TradeListingRuntimeError { 338 #[error("invalid trade listing state path: {0}")] 339 InvalidStatePath(PathBuf), 340 #[error("unsupported trade listing state version: {0}")] 341 UnsupportedStateVersion(u32), 342 #[error("trade listing state io error: {0}")] 343 Io(#[from] std::io::Error), 344 #[error("trade listing state json error: {0}")] 345 Json(#[from] serde_json::Error), 346 } 347 348 #[cfg(test)] 349 #[cfg_attr(coverage_nightly, coverage(off))] 350 mod tests { 351 use super::{ 352 ListingEventState, PersistedTradeListingState, TradeListingRuntime, 353 TradeListingRuntimeConfig, TradeListingRuntimeError, TradeListingState, 354 TradeListingStateError, TradeOrderState, TradeOrderStatus, ValidatedListingState, 355 }; 356 use std::collections::{HashMap, HashSet}; 357 358 fn unique_state_path(suffix: &str) -> std::path::PathBuf { 359 let nanos = std::time::SystemTime::now() 360 .duration_since(std::time::UNIX_EPOCH) 361 .expect("time") 362 .as_nanos(); 363 std::env::temp_dir().join(format!("rhi-trade-state-{suffix}-{nanos}.json")) 364 } 365 366 #[test] 367 fn state_tracks_listings_events_and_replay_anchor() { 368 let mut state = TradeListingState::default(); 369 assert!(!state.is_listing_validated("addr")); 370 state.mark_listing_validated("addr", "evt-listing-1"); 371 assert!(state.is_listing_validated("addr")); 372 assert_eq!( 373 state.validated_listing_event_id("addr"), 374 Some("evt-listing-1") 375 ); 376 377 let order = TradeOrderState { 378 order_id: "order-1".into(), 379 listing_addr: "addr".into(), 380 buyer_pubkey: "buyer".into(), 381 seller_pubkey: "seller".into(), 382 status: TradeOrderStatus::Requested, 383 listing_snapshot_event_id: Some("evt-listing-1".into()), 384 root_event_id: Some("evt-root-1".into()), 385 last_event_id: Some("evt-root-1".into()), 386 seen_event_ids: Default::default(), 387 }; 388 state.insert_order(order); 389 assert!(!state.is_event_seen("order-1", "evt")); 390 assert!(state.mark_event_seen("order-1", "evt")); 391 assert!(state.is_event_seen("order-1", "evt")); 392 assert!(!state.is_non_order_event_seen("evt-non-order")); 393 assert!(state.mark_non_order_event_seen("evt-non-order")); 394 assert!(state.is_non_order_event_seen("evt-non-order")); 395 state.upsert_listing_event("addr", "evt-listing-1", 30402); 396 assert_eq!(state.listing_event_id("addr"), Some("evt-listing-1")); 397 assert_eq!(state.replay_since(1_000, 300, 60), 700); 398 399 state.observe_event_created_at(900); 400 assert_eq!(state.last_event_created_at(), Some(900)); 401 assert_eq!(state.replay_since(1_000, 300, 60), 840); 402 } 403 404 #[test] 405 fn state_covers_missing_order_paths_and_error_display() { 406 let mut state = TradeListingState::default(); 407 assert!(!state.order_exists("missing")); 408 assert!(state.get_order_mut("missing").is_none()); 409 assert!(!state.mark_event_seen("missing", "evt-1")); 410 assert!(!state.is_event_seen("missing", "evt-1")); 411 assert!(!state.is_non_order_event_seen("evt-2")); 412 413 assert_eq!( 414 TradeListingStateError::MissingOrder.to_string(), 415 "missing order state" 416 ); 417 418 let invalid = TradeListingStateError::InvalidTransition { 419 from: TradeOrderStatus::Requested, 420 to: TradeOrderStatus::Completed, 421 }; 422 assert_eq!( 423 invalid.to_string(), 424 "invalid order transition: Requested -> Completed" 425 ); 426 } 427 428 #[tokio::test] 429 async fn runtime_reuses_shared_trade_listing_state() { 430 let runtime = TradeListingRuntime::new(); 431 let state = runtime.state(); 432 state 433 .lock() 434 .await 435 .mark_listing_validated("addr", "evt-listing-1"); 436 437 assert!(runtime.state().lock().await.is_listing_validated("addr")); 438 } 439 440 #[tokio::test] 441 async fn runtime_persists_and_loads_trade_listing_state() { 442 let path = unique_state_path("roundtrip"); 443 let config = TradeListingRuntimeConfig { 444 state_path: path.clone(), 445 replay_window_secs: 600, 446 replay_overlap_secs: 30, 447 }; 448 let runtime = TradeListingRuntime::load(config.clone()) 449 .await 450 .expect("runtime"); 451 452 { 453 let state_handle = runtime.state(); 454 let mut state = state_handle.lock().await; 455 state.mark_listing_validated("addr", "evt-listing-1"); 456 state.mark_non_order_event_seen("evt-validate-1"); 457 state.observe_event_created_at(456); 458 } 459 runtime.persist().await.expect("persist"); 460 461 let loaded = TradeListingRuntime::load(config).await.expect("load"); 462 let loaded_state_handle = loaded.state(); 463 let loaded_state = loaded_state_handle.lock().await; 464 assert!(loaded_state.is_listing_validated("addr")); 465 assert_eq!( 466 loaded_state.validated_listing_event_id("addr"), 467 Some("evt-listing-1") 468 ); 469 assert!(loaded_state.is_non_order_event_seen("evt-validate-1")); 470 assert_eq!(loaded_state.last_event_created_at(), Some(456)); 471 472 let _ = tokio::fs::remove_file(path).await; 473 } 474 475 #[tokio::test] 476 async fn runtime_load_rejects_unsupported_snapshot_version() { 477 let path = unique_state_path("version"); 478 let payload = PersistedTradeListingState { 479 version: 99, 480 state: TradeListingState::default(), 481 }; 482 tokio::fs::write(&path, serde_json::to_vec(&payload).expect("payload")) 483 .await 484 .expect("write"); 485 486 let err = TradeListingRuntime::load(TradeListingRuntimeConfig { 487 state_path: path.clone(), 488 replay_window_secs: 600, 489 replay_overlap_secs: 30, 490 }) 491 .await 492 .expect_err("unsupported snapshot should fail"); 493 assert!(matches!( 494 err, 495 TradeListingRuntimeError::UnsupportedStateVersion(99) 496 )); 497 498 let _ = tokio::fs::remove_file(path).await; 499 } 500 501 #[tokio::test] 502 async fn runtime_loads_legacy_validation_state_without_trusting_it() { 503 let path = unique_state_path("legacy-validation"); 504 let payload = PersistedTradeListingState { 505 version: 1, 506 state: TradeListingState { 507 validated_listings: ["addr".to_string()].into_iter().collect(), 508 validated_listing_events: HashMap::new(), 509 listing_events: HashMap::new(), 510 seen_non_order_event_ids: HashSet::new(), 511 orders: HashMap::new(), 512 last_event_created_at: Some(321), 513 }, 514 }; 515 tokio::fs::write(&path, serde_json::to_vec(&payload).expect("payload")) 516 .await 517 .expect("write"); 518 519 let loaded = TradeListingRuntime::load(TradeListingRuntimeConfig { 520 state_path: path.clone(), 521 replay_window_secs: 600, 522 replay_overlap_secs: 30, 523 }) 524 .await 525 .expect("load"); 526 let loaded_state_handle = loaded.state(); 527 let loaded_state = loaded_state_handle.lock().await; 528 assert!(!loaded_state.is_listing_validated("addr")); 529 assert_eq!(loaded_state.validated_listing_event_id("addr"), None); 530 assert_eq!(loaded_state.last_event_created_at(), Some(321)); 531 532 let _ = tokio::fs::remove_file(path).await; 533 } 534 535 #[test] 536 fn state_can_clear_listing_validation() { 537 let mut state = TradeListingState { 538 validated_listings: ["addr".to_string()].into_iter().collect(), 539 validated_listing_events: HashMap::from([( 540 "addr".to_string(), 541 ValidatedListingState { 542 event_id: "evt-listing-1".to_string(), 543 }, 544 )]), 545 listing_events: HashMap::from([( 546 "addr".to_string(), 547 ListingEventState { 548 event_id: "evt-listing-1".to_string(), 549 kind: 30402, 550 }, 551 )]), 552 seen_non_order_event_ids: HashSet::new(), 553 orders: HashMap::new(), 554 last_event_created_at: None, 555 }; 556 assert!(state.is_listing_validated("addr")); 557 state.clear_listing_validation("addr"); 558 assert!(!state.is_listing_validated("addr")); 559 assert_eq!(state.validated_listing_event_id("addr"), None); 560 } 561 }