resource_limits.rs (4061B)
1 #![forbid(unsafe_code)] 2 3 use crate::errors::BaseRelayError; 4 use std::sync::{ 5 Arc, 6 atomic::{AtomicUsize, Ordering}, 7 }; 8 9 #[derive(Debug, Clone)] 10 pub struct RelayResourceLimiter { 11 inner: Arc<RelayResourceLimiterInner>, 12 } 13 14 #[derive(Debug)] 15 struct RelayResourceLimiterInner { 16 max_connections: usize, 17 max_subscriptions: usize, 18 active_connections: AtomicUsize, 19 active_subscriptions: AtomicUsize, 20 } 21 22 impl RelayResourceLimiter { 23 pub fn new(max_connections: usize, max_subscriptions: usize) -> Self { 24 Self { 25 inner: Arc::new(RelayResourceLimiterInner { 26 max_connections, 27 max_subscriptions, 28 active_connections: AtomicUsize::new(0), 29 active_subscriptions: AtomicUsize::new(0), 30 }), 31 } 32 } 33 34 pub fn try_open_connection(&self) -> Result<RelayConnectionPermit, BaseRelayError> { 35 increment_with_limit( 36 &self.inner.active_connections, 37 1, 38 self.inner.max_connections, 39 "host total connection limit exceeded", 40 )?; 41 Ok(RelayConnectionPermit { 42 resources: self.inner.clone(), 43 released: false, 44 }) 45 } 46 47 pub fn try_open_subscriptions( 48 &self, 49 count: usize, 50 ) -> Result<RelaySubscriptionPermit, BaseRelayError> { 51 if count == 0 { 52 return Err(BaseRelayError::invalid( 53 "subscription reservation count must be greater than zero", 54 )); 55 } 56 increment_with_limit( 57 &self.inner.active_subscriptions, 58 count, 59 self.inner.max_subscriptions, 60 "host total subscription limit exceeded", 61 )?; 62 Ok(RelaySubscriptionPermit { 63 resources: self.inner.clone(), 64 count, 65 released: false, 66 }) 67 } 68 69 pub fn active_connections(&self) -> usize { 70 self.inner.active_connections.load(Ordering::Relaxed) 71 } 72 73 pub fn active_subscriptions(&self) -> usize { 74 self.inner.active_subscriptions.load(Ordering::Relaxed) 75 } 76 77 pub fn max_connections(&self) -> usize { 78 self.inner.max_connections 79 } 80 81 pub fn max_subscriptions(&self) -> usize { 82 self.inner.max_subscriptions 83 } 84 } 85 86 #[derive(Debug)] 87 pub struct RelayConnectionPermit { 88 resources: Arc<RelayResourceLimiterInner>, 89 released: bool, 90 } 91 92 impl RelayConnectionPermit { 93 pub fn release(mut self) { 94 self.release_inner(); 95 } 96 97 fn release_inner(&mut self) { 98 if !self.released { 99 self.resources 100 .active_connections 101 .fetch_sub(1, Ordering::Relaxed); 102 self.released = true; 103 } 104 } 105 } 106 107 impl Drop for RelayConnectionPermit { 108 fn drop(&mut self) { 109 self.release_inner(); 110 } 111 } 112 113 #[derive(Debug)] 114 pub struct RelaySubscriptionPermit { 115 resources: Arc<RelayResourceLimiterInner>, 116 count: usize, 117 released: bool, 118 } 119 120 impl RelaySubscriptionPermit { 121 pub fn release(mut self) { 122 self.release_inner(); 123 } 124 125 fn release_inner(&mut self) { 126 if !self.released { 127 self.resources 128 .active_subscriptions 129 .fetch_sub(self.count, Ordering::Relaxed); 130 self.released = true; 131 } 132 } 133 } 134 135 impl Drop for RelaySubscriptionPermit { 136 fn drop(&mut self) { 137 self.release_inner(); 138 } 139 } 140 141 fn increment_with_limit( 142 counter: &AtomicUsize, 143 amount: usize, 144 limit: usize, 145 message: &'static str, 146 ) -> Result<(), BaseRelayError> { 147 let mut current = counter.load(Ordering::Relaxed); 148 loop { 149 let Some(next) = current.checked_add(amount) else { 150 return Err(BaseRelayError::restricted(message)); 151 }; 152 if next > limit { 153 return Err(BaseRelayError::restricted(message)); 154 } 155 match counter.compare_exchange(current, next, Ordering::Relaxed, Ordering::Relaxed) { 156 Ok(_) => return Ok(()), 157 Err(actual) => current = actual, 158 } 159 } 160 }