1use std::collections::{HashSet, VecDeque};
4
5use chrono::{DateTime, Local, NaiveDate, NaiveTime, Utc};
6use dashmap::DashMap;
7use parking_lot::Mutex;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12pub struct Limits {
13 pub allowed_markets: Option<HashSet<String>>,
14 pub allowed_symbols: Option<HashSet<String>>,
15 pub max_order_value: Option<f64>,
16 pub max_daily_value: Option<f64>,
17 pub hours_window: Option<String>,
18 #[serde(default, skip_serializing_if = "Option::is_none")]
21 pub max_orders_per_minute: Option<u32>,
22 #[serde(default, skip_serializing_if = "Option::is_none")]
25 pub allowed_trd_sides: Option<HashSet<String>>,
26}
27
/// Attributes of a single request that the limit checks inspect.
#[derive(Clone, Debug)]
pub struct CheckCtx {
    /// Target market; an empty string skips the market whitelist.
    pub market: String,
    /// Target symbol; an empty string skips the symbol whitelist.
    pub symbol: String,
    /// Value of the order, if known; `None` skips every value-based cap.
    pub order_value: Option<f64>,
    /// Trade side; `None` skips the side whitelist.
    pub trd_side: Option<String>,
}
42
/// Result of a limit check.
#[derive(Clone, Debug, PartialEq)]
pub enum LimitOutcome {
    /// Every configured limit passed.
    Allow,
    /// A limit was violated; the payload is a human-readable reason.
    Reject(String),
}
49
50#[derive(Debug)]
52struct DailyCounter {
53 day: Mutex<NaiveDate>,
54 total: Mutex<f64>,
55}
56
57impl DailyCounter {
58 fn new(day: NaiveDate) -> Self {
59 Self {
60 day: Mutex::new(day),
61 total: Mutex::new(0.0),
62 }
63 }
64
65 fn try_add(&self, amount: f64, max: Option<f64>, today: NaiveDate) -> Result<f64, String> {
67 let mut day = self.day.lock();
68 let mut total = self.total.lock();
69 if *day != today {
70 *day = today;
71 *total = 0.0;
72 }
73 let next = *total + amount;
74 if let Some(cap) = max {
75 if next > cap + f64::EPSILON {
76 return Err(format!(
77 "daily value cap exceeded: {next:.2} > {cap:.2} (current={:.2} + order={:.2})",
78 *total, amount
79 ));
80 }
81 }
82 *total = next;
83 Ok(next)
84 }
85
86 #[cfg(test)]
87 fn peek_total(&self) -> f64 {
88 *self.total.lock()
89 }
90}
91
92#[derive(Debug, Default)]
97struct RateWindow {
98 recent: Mutex<VecDeque<DateTime<Utc>>>,
99}
100
101impl RateWindow {
102 fn try_record(&self, now: DateTime<Utc>, max: u32) -> Result<u32, String> {
104 let mut recent = self.recent.lock();
105 let cutoff = now - chrono::Duration::seconds(60);
106 while let Some(front) = recent.front() {
107 if *front < cutoff {
108 recent.pop_front();
109 } else {
110 break;
111 }
112 }
113 if recent.len() as u32 >= max {
114 return Err(format!(
115 "rate limit exceeded: {} orders in the last 60s (cap {})",
116 recent.len(),
117 max
118 ));
119 }
120 recent.push_back(now);
121 Ok(recent.len() as u32)
122 }
123}
124
125#[derive(Debug, Default)]
127pub struct RuntimeCounters {
128 counters: DashMap<String, DailyCounter>,
129 rates: DashMap<String, RateWindow>,
130}
131
132impl RuntimeCounters {
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn check_and_commit(
143 &self,
144 key_id: &str,
145 limits: &Limits,
146 ctx: &CheckCtx,
147 now: DateTime<Utc>,
148 ) -> LimitOutcome {
149 if let Some(markets) = &limits.allowed_markets {
153 if !markets.is_empty() && !ctx.market.is_empty() && !markets.contains(&ctx.market) {
154 return LimitOutcome::Reject(format!(
155 "market {:?} not in allowed list {:?}",
156 ctx.market, markets
157 ));
158 }
159 }
160
161 if let Some(symbols) = &limits.allowed_symbols {
165 if !symbols.is_empty() && !ctx.symbol.is_empty() && !symbols.contains(&ctx.symbol) {
166 return LimitOutcome::Reject(format!(
167 "symbol {:?} not in allowed list",
168 ctx.symbol
169 ));
170 }
171 }
172
173 if let (Some(allowed), Some(side)) = (&limits.allowed_trd_sides, &ctx.trd_side) {
175 if !allowed.is_empty() && !allowed.contains(side) {
176 return LimitOutcome::Reject(format!(
177 "trd_side {side:?} not in allowed list {allowed:?}"
178 ));
179 }
180 }
181
182 if let Some(spec) = &limits.hours_window {
184 match parse_window(spec) {
185 Ok((start, end)) => {
186 let now_local = now.with_timezone(&Local).time();
187 if !in_window(now_local, start, end) {
188 return LimitOutcome::Reject(format!(
189 "outside hours window {spec} (now={})",
190 now_local.format("%H:%M")
191 ));
192 }
193 }
194 Err(e) => {
195 return LimitOutcome::Reject(format!("invalid hours_window {spec:?}: {e}"));
196 }
197 }
198 }
199
200 if let Some(value) = ctx.order_value {
202 if let Some(cap) = limits.max_order_value {
203 if value > cap + f64::EPSILON {
204 return LimitOutcome::Reject(format!(
205 "order value {value:.2} exceeds per-order cap {cap:.2}"
206 ));
207 }
208 }
209 }
210
211 if let Some(max) = limits.max_orders_per_minute {
215 let window = self.rates.entry(key_id.to_string()).or_default();
216 if let Err(e) = window.try_record(now, max) {
217 return LimitOutcome::Reject(e);
218 }
219 }
220
221 if let (Some(value), Some(_)) = (ctx.order_value, limits.max_daily_value) {
223 let today = now.date_naive();
224 let counter = self
225 .counters
226 .entry(key_id.to_string())
227 .or_insert_with(|| DailyCounter::new(today));
228 match counter.try_add(value, limits.max_daily_value, today) {
229 Ok(_) => {}
230 Err(e) => return LimitOutcome::Reject(e),
231 }
232 }
233
234 LimitOutcome::Allow
235 }
236
237 pub fn check_full_skip_rate(
249 &self,
250 key_id: &str,
251 limits: &Limits,
252 ctx: &CheckCtx,
253 now: DateTime<Utc>,
254 ) -> LimitOutcome {
255 if let Some(markets) = &limits.allowed_markets {
257 if !markets.is_empty() && !ctx.market.is_empty() && !markets.contains(&ctx.market) {
258 return LimitOutcome::Reject(format!(
259 "market {:?} not in allowed list {:?}",
260 ctx.market, markets
261 ));
262 }
263 }
264
265 if let Some(symbols) = &limits.allowed_symbols {
267 if !symbols.is_empty() && !ctx.symbol.is_empty() && !symbols.contains(&ctx.symbol) {
268 return LimitOutcome::Reject(format!(
269 "symbol {:?} not in allowed list",
270 ctx.symbol
271 ));
272 }
273 }
274
275 if let (Some(allowed), Some(side)) = (&limits.allowed_trd_sides, &ctx.trd_side) {
277 if !allowed.is_empty() && !allowed.contains(side) {
278 return LimitOutcome::Reject(format!(
279 "trd_side {side:?} not in allowed list {allowed:?}"
280 ));
281 }
282 }
283
284 if let Some(spec) = &limits.hours_window {
286 match parse_window(spec) {
287 Ok((start, end)) => {
288 let now_local = now.with_timezone(&Local).time();
289 if !in_window(now_local, start, end) {
290 return LimitOutcome::Reject(format!(
291 "outside hours window {spec} (now={})",
292 now_local.format("%H:%M")
293 ));
294 }
295 }
296 Err(e) => {
297 return LimitOutcome::Reject(format!("invalid hours_window {spec:?}: {e}"));
298 }
299 }
300 }
301
302 if let Some(value) = ctx.order_value {
304 if let Some(cap) = limits.max_order_value {
305 if value > cap + f64::EPSILON {
306 return LimitOutcome::Reject(format!(
307 "order value {value:.2} exceeds per-order cap {cap:.2}"
308 ));
309 }
310 }
311 }
312
313 if let (Some(value), Some(_)) = (ctx.order_value, limits.max_daily_value) {
317 let today = now.date_naive();
318 let counter = self
319 .counters
320 .entry(key_id.to_string())
321 .or_insert_with(|| DailyCounter::new(today));
322 match counter.try_add(value, limits.max_daily_value, today) {
323 Ok(_) => {}
324 Err(e) => return LimitOutcome::Reject(e),
325 }
326 }
327
328 LimitOutcome::Allow
329 }
330
331 #[cfg(test)]
332 fn peek_total(&self, key_id: &str) -> f64 {
333 self.counters
334 .get(key_id)
335 .map(|c| c.peek_total())
336 .unwrap_or(0.0)
337 }
338}
339
340fn parse_window(s: &str) -> Result<(NaiveTime, NaiveTime), String> {
341 let (l, r) = s
342 .split_once('-')
343 .ok_or_else(|| format!("expect HH:MM-HH:MM, got {s:?}"))?;
344 let parse = |p: &str| {
345 NaiveTime::parse_from_str(p.trim(), "%H:%M").map_err(|e| format!("bad time {p:?}: {e}"))
346 };
347 Ok((parse(l)?, parse(r)?))
348}
349
350fn in_window(t: NaiveTime, start: NaiveTime, end: NaiveTime) -> bool {
352 if start <= end {
353 t >= start && t < end
354 } else {
355 t >= start || t < end
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    fn mk_limits() -> Limits {
        Limits {
            allowed_markets: Some(["HK".to_string()].into_iter().collect()),
            allowed_symbols: Some(["HK.00700".to_string()].into_iter().collect()),
            max_order_value: Some(10_000.0),
            max_daily_value: Some(25_000.0),
            hours_window: None,
            max_orders_per_minute: None,
            allowed_trd_sides: None,
        }
    }

    fn mk_ctx(market: &str, symbol: &str, value: Option<f64>) -> CheckCtx {
        CheckCtx {
            market: market.into(),
            symbol: symbol.into(),
            order_value: value,
            trd_side: None,
        }
    }

    /// Formats a time as `HH:MM`, the format `hours_window` specs use.
    fn hhmm(t: NaiveTime) -> String {
        t.format("%H:%M").to_string()
    }

    #[test]
    fn market_whitelist() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let ctx = mk_ctx("US", "HK.00700", Some(100.0));
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx, Utc::now()),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn symbol_whitelist() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let ctx = mk_ctx("HK", "HK.09988", Some(100.0));
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx, Utc::now()),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn per_order_cap() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let ctx = mk_ctx("HK", "HK.00700", Some(20_000.0));
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx, Utc::now()),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn daily_cap() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let mk = |v: f64| mk_ctx("HK", "HK.00700", Some(v));
        // Use one timestamp for every call. Otherwise a UTC-midnight rollover
        // mid-test would reset the daily total and make the test flaky.
        let now = Utc::now();
        assert_eq!(
            rc.check_and_commit("k", &lim, &mk(9_000.0), now),
            LimitOutcome::Allow
        );
        assert_eq!(
            rc.check_and_commit("k", &lim, &mk(9_000.0), now),
            LimitOutcome::Allow
        );
        // 27_000 would exceed the 25_000 daily cap; the rejected order is not counted.
        assert!(matches!(
            rc.check_and_commit("k", &lim, &mk(9_000.0), now),
            LimitOutcome::Reject(_)
        ));
        assert_eq!(rc.peek_total("k"), 18_000.0);
    }

    #[test]
    fn side_whitelist_blocks_wrong_side() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.allowed_trd_sides = Some(["SELL".to_string()].into_iter().collect());
        let ctx = CheckCtx {
            market: "HK".into(),
            symbol: "HK.00700".into(),
            order_value: Some(100.0),
            trd_side: Some("BUY".into()),
        };
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx, Utc::now()),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn side_whitelist_passes_right_side() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.allowed_trd_sides = Some(["SELL".to_string()].into_iter().collect());
        let ctx = CheckCtx {
            market: "HK".into(),
            symbol: "HK.00700".into(),
            order_value: Some(100.0),
            trd_side: Some("SELL".into()),
        };
        assert_eq!(
            rc.check_and_commit("k", &lim, &ctx, Utc::now()),
            LimitOutcome::Allow
        );
    }

    #[test]
    fn side_whitelist_skipped_when_ctx_has_no_side() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.allowed_trd_sides = Some(["SELL".to_string()].into_iter().collect());
        let ctx = mk_ctx("HK", "HK.00700", Some(100.0));
        assert_eq!(
            rc.check_and_commit("k", &lim, &ctx, Utc::now()),
            LimitOutcome::Allow
        );
    }

    #[test]
    fn market_whitelist_skipped_when_ctx_market_is_empty() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let now = Utc::now();
        // Empty market and symbol skip both whitelists.
        let ctx = CheckCtx {
            market: "".into(),
            symbol: "".into(),
            order_value: None,
            trd_side: None,
        };
        assert_eq!(rc.check_and_commit("k", &lim, &ctx, now), LimitOutcome::Allow);
        // A non-empty market is still checked.
        let ctx_us = CheckCtx {
            market: "US".into(),
            symbol: "".into(),
            order_value: None,
            trd_side: None,
        };
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx_us, now),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn symbol_whitelist_skipped_when_ctx_symbol_is_empty() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let now = Utc::now();
        // An empty symbol skips the symbol whitelist.
        let ctx = CheckCtx {
            market: "HK".into(),
            symbol: "".into(),
            order_value: None,
            trd_side: None,
        };
        assert_eq!(rc.check_and_commit("k", &lim, &ctx, now), LimitOutcome::Allow);
        // The market whitelist still applies.
        let ctx_wrong_market = CheckCtx {
            market: "US".into(),
            symbol: "".into(),
            order_value: None,
            trd_side: None,
        };
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx_wrong_market, now),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn check_full_skip_rate_does_not_double_count_rate() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.max_orders_per_minute = Some(2);
        let ctx = mk_ctx("HK", "HK.00700", Some(100.0));
        let now = Utc::now();
        // Uses the first rate slot.
        assert_eq!(rc.check_and_commit("k", &lim, &ctx, now), LimitOutcome::Allow);
        // Skip-rate checks never touch the rate window.
        for _ in 0..100 {
            assert_eq!(
                rc.check_full_skip_rate("k", &lim, &ctx, now),
                LimitOutcome::Allow
            );
        }
        // The second slot is still free...
        assert_eq!(rc.check_and_commit("k", &lim, &ctx, now), LimitOutcome::Allow);
        // ...and a third request in the same minute is rejected.
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx, now),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn check_full_skip_rate_still_enforces_market_symbol_side_value_daily() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.allowed_trd_sides = Some(["SELL".to_string()].into_iter().collect());
        // One timestamp so the daily total cannot roll over mid-test.
        let now = Utc::now();

        // Market.
        assert!(matches!(
            rc.check_full_skip_rate("k", &lim, &mk_ctx("US", "HK.00700", Some(100.0)), now),
            LimitOutcome::Reject(_)
        ));
        // Symbol.
        assert!(matches!(
            rc.check_full_skip_rate("k", &lim, &mk_ctx("HK", "HK.09988", Some(100.0)), now),
            LimitOutcome::Reject(_)
        ));
        // Side.
        let ctx_buy = CheckCtx {
            market: "HK".into(),
            symbol: "HK.00700".into(),
            order_value: Some(100.0),
            trd_side: Some("BUY".into()),
        };
        assert!(matches!(
            rc.check_full_skip_rate("k", &lim, &ctx_buy, now),
            LimitOutcome::Reject(_)
        ));
        // Per-order value.
        assert!(matches!(
            rc.check_full_skip_rate("k", &lim, &mk_ctx("HK", "HK.00700", Some(20_000.0)), now),
            LimitOutcome::Reject(_)
        ));
        // Daily cap: two 9_000 orders fit under 25_000, a third does not.
        let ctx_ok = mk_ctx("HK", "HK.00700", Some(9_000.0));
        for _ in 0..2 {
            assert_eq!(
                rc.check_full_skip_rate("k", &lim, &ctx_ok, now),
                LimitOutcome::Allow
            );
        }
        assert!(matches!(
            rc.check_full_skip_rate("k", &lim, &ctx_ok, now),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn rate_limit_counts_mutations_with_empty_symbol() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.max_orders_per_minute = Some(2);
        let mutation_ctx = CheckCtx {
            market: "HK".into(),
            symbol: "".into(),
            order_value: None,
            trd_side: None,
        };
        let t0 = Utc::now();
        assert_eq!(
            rc.check_and_commit("k", &lim, &mutation_ctx, t0),
            LimitOutcome::Allow
        );
        assert_eq!(
            rc.check_and_commit("k", &lim, &mutation_ctx, t0 + chrono::Duration::seconds(1)),
            LimitOutcome::Allow
        );
        assert!(matches!(
            rc.check_and_commit("k", &lim, &mutation_ctx, t0 + chrono::Duration::seconds(2)),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn rate_limit_sliding_60s() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.max_orders_per_minute = Some(3);
        let ctx = mk_ctx("HK", "HK.00700", None);
        let t0 = Utc::now();
        for i in 0..3 {
            assert_eq!(
                rc.check_and_commit("k", &lim, &ctx, t0 + chrono::Duration::seconds(i)),
                LimitOutcome::Allow
            );
        }
        assert!(matches!(
            rc.check_and_commit("k", &lim, &ctx, t0 + chrono::Duration::seconds(10)),
            LimitOutcome::Reject(_)
        ));
        // By t0+70s the three admitted requests have left the window.
        assert_eq!(
            rc.check_and_commit("k", &lim, &ctx, t0 + chrono::Duration::seconds(70)),
            LimitOutcome::Allow
        );
    }

    #[test]
    fn rate_limit_not_set_means_unlimited() {
        let rc = RuntimeCounters::new();
        let lim = mk_limits();
        let ctx = mk_ctx("HK", "HK.00700", None);
        let t0 = Utc::now();
        for i in 0..100 {
            assert_eq!(
                rc.check_and_commit("k", &lim, &ctx, t0 + chrono::Duration::milliseconds(i)),
                LimitOutcome::Allow
            );
        }
    }

    #[test]
    fn invalid_hours_window_rejects() {
        let rc = RuntimeCounters::new();
        let mut lim = mk_limits();
        lim.hours_window = Some("not-a-window".into());
        let ctx = mk_ctx("HK", "HK.00700", Some(100.0));
        match rc.check_and_commit("k", &lim, &ctx, Utc::now()) {
            LimitOutcome::Reject(reason) => assert!(reason.starts_with("invalid hours_window")),
            other => panic!("expected rejection, got {other:?}"),
        }
    }

    #[test]
    fn hours_window_uses_local_time() {
        let rc = RuntimeCounters::new();
        let now = Utc::now();
        let local = now.with_timezone(&Local).time();
        let one_hour = chrono::Duration::hours(1);
        let two_hours = chrono::Duration::hours(2);
        let ctx = mk_ctx("HK", "HK.00700", None);

        // A window from one hour before to one hour after the local time is open.
        // NaiveTime arithmetic wraps, so this also covers windows across midnight.
        let mut open = mk_limits();
        open.hours_window = Some(format!("{}-{}", hhmm(local - one_hour), hhmm(local + one_hour)));
        assert_eq!(rc.check_and_commit("k", &open, &ctx, now), LimitOutcome::Allow);

        // A window starting one hour from now is closed.
        let mut closed = mk_limits();
        closed.hours_window =
            Some(format!("{}-{}", hhmm(local + one_hour), hhmm(local + two_hours)));
        assert!(matches!(
            rc.check_and_commit("k", &closed, &ctx, now),
            LimitOutcome::Reject(_)
        ));
    }

    #[test]
    fn parse_window_trims_and_validates() {
        let t = |h, m| NaiveTime::from_hms_opt(h, m, 0).unwrap();
        assert_eq!(parse_window(" 09:30 - 16:00 "), Ok((t(9, 30), t(16, 0))));
        assert!(parse_window("0930").is_err());
        assert!(parse_window("09:30-25:00").is_err());
    }

    #[test]
    fn window_same_day() {
        let t = |h, m| NaiveTime::from_hms_opt(h, m, 0).unwrap();
        assert!(in_window(t(10, 0), t(9, 0), t(16, 0)));
        assert!(!in_window(t(8, 0), t(9, 0), t(16, 0)));
        assert!(!in_window(t(16, 0), t(9, 0), t(16, 0)));
    }

    #[test]
    fn window_cross_midnight() {
        let t = |h, m| NaiveTime::from_hms_opt(h, m, 0).unwrap();
        assert!(in_window(t(23, 0), t(22, 0), t(4, 0)));
        assert!(in_window(t(2, 0), t(22, 0), t(4, 0)));
        assert!(!in_window(t(10, 0), t(22, 0), t(4, 0)));
    }

    #[test]
    fn window_empty_when_start_equals_end() {
        let t = |h, m| NaiveTime::from_hms_opt(h, m, 0).unwrap();
        assert!(!in_window(t(9, 0), t(9, 0), t(9, 0)));
        assert!(!in_window(t(12, 0), t(9, 0), t(9, 0)));
    }
}