1use std::collections::{HashMap, VecDeque};
4use std::time::{Duration, Instant};
5
6use parking_lot::Mutex;
7
/// Length of the sliding window used when counting requests for rate limiting.
///
/// Timestamps older than this (relative to "now") are discarded before the
/// per-protocol request count is compared against its limit.
const FREQ_WINDOW: Duration = Duration::from_secs(30);
10
11fn default_proto_limits() -> HashMap<u32, u32> {
13 use futu_core::proto_id::*;
14 let mut m = HashMap::new();
15 m.insert(TRD_UNLOCK_TRADE, 10);
16 m.insert(TRD_PLACE_ORDER, 15);
17 m.insert(TRD_MODIFY_ORDER, 20);
18 m.insert(TRD_GET_HISTORY_ORDER_FILL_LIST, 10);
19 m.insert(TRD_GET_HISTORY_ORDER_LIST, 10);
20 m.insert(QOT_GET_SECURITY_SNAPSHOT, 10);
21 m.insert(QOT_GET_PLATE_SET, 10);
22 m.insert(QOT_GET_PLATE_SECURITY, 10);
23 m.insert(QOT_GET_OWNER_PLATE, 10);
24 m.insert(QOT_GET_HOLDING_CHANGE_LIST, 10);
25 m.insert(QOT_GET_OPTION_CHAIN, 10);
26 m.insert(QOT_REQUEST_HISTORY_KL, 10);
27 m
28}
29
/// Per-connection bookkeeping used by the rate limiter and the replay check.
struct ConnFreqRecord {
    /// For each protocol ID, the timestamps of requests that were accepted,
    /// oldest at the front.
    proto_times: HashMap<u32, VecDeque<Instant>>,
    /// Highest serial number accepted so far on this connection (0 initially).
    last_serial: u32,
}
37
38impl ConnFreqRecord {
39 fn new() -> Self {
40 Self {
41 proto_times: HashMap::new(),
42 last_serial: 0,
43 }
44 }
45}
46
47pub struct ProtectionManager {
49 records: Mutex<HashMap<u64, ConnFreqRecord>>,
50 proto_limits: HashMap<u32, u32>,
51}
52
53impl ProtectionManager {
54 pub fn new() -> Self {
55 Self {
56 records: Mutex::new(HashMap::new()),
57 proto_limits: default_proto_limits(),
58 }
59 }
60
61 pub fn check_freq_limit(&self, conn_id: u64, proto_id: u32) -> bool {
65 let limit = match self.proto_limits.get(&proto_id) {
66 Some(&limit) => limit,
67 None => return false, };
69
70 let mut records = self.records.lock();
71 let record = records.entry(conn_id).or_insert_with(ConnFreqRecord::new);
72
73 let times = record.proto_times.entry(proto_id).or_default();
74 let now = Instant::now();
75
76 while times
78 .front()
79 .is_some_and(|t| now.duration_since(*t) > FREQ_WINDOW)
80 {
81 times.pop_front();
82 }
83
84 if times.len() as u32 >= limit {
85 true } else {
87 times.push_back(now);
88 false
89 }
90 }
91
92 pub fn check_replay(&self, conn_id: u64, serial_no: u32) -> bool {
96 let mut records = self.records.lock();
97 let record = records.entry(conn_id).or_insert_with(ConnFreqRecord::new);
98
99 if serial_no <= record.last_serial {
100 true } else {
102 record.last_serial = serial_no;
103 false
104 }
105 }
106
107 pub fn on_disconnect(&self, conn_id: u64) {
109 self.records.lock().remove(&conn_id);
110 }
111}
112
113impl Default for ProtectionManager {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use futu_core::proto_id::{INIT_CONNECT, TRD_PLACE_ORDER, TRD_UNLOCK_TRADE};

    /// The first ten unlock requests pass; the eleventh is throttled.
    #[test]
    fn test_freq_limit() {
        let mgr = ProtectionManager::new();
        let (conn_id, proto_id) = (1, TRD_UNLOCK_TRADE);

        assert!((0..10).all(|_| !mgr.check_freq_limit(conn_id, proto_id)));

        assert!(mgr.check_freq_limit(conn_id, proto_id));
    }

    /// A protocol without a configured limit is never throttled.
    #[test]
    fn test_no_limit_proto() {
        let mgr = ProtectionManager::new();
        assert!((0..100).all(|_| !mgr.check_freq_limit(1, INIT_CONNECT)));
    }

    /// Serials must strictly increase; equal or lower values are replays.
    #[test]
    fn test_replay_detection() {
        let mgr = ProtectionManager::new();
        // (serial sent, expected replay verdict), applied in order.
        let steps = [(1, false), (2, false), (2, true), (1, true), (3, false)];
        for (serial, is_replay) in steps {
            assert_eq!(mgr.check_replay(1, serial), is_replay);
        }
    }

    /// After a disconnect the connection's serial history starts over.
    #[test]
    fn test_disconnect_cleanup() {
        let mgr = ProtectionManager::new();
        let _ = mgr.check_freq_limit(1, TRD_PLACE_ORDER);
        let _ = mgr.check_replay(1, 5);
        mgr.on_disconnect(1);

        // Serial 1 would be a replay if the earlier serial 5 were still remembered.
        assert!(!mgr.check_replay(1, 1));
    }
}