1use std::collections::HashSet;
4
5use chrono::{DateTime, NaiveTime, Utc};
6use rand::RngCore;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::limits::Limits;
11use crate::scope::Scope;
12
/// A stored key record: identity, hashed secret, granted scopes and optional
/// restrictions.
///
/// Only the hash of the secret is stored; the plaintext is handed back once by
/// [`KeyRecord::generate`] / [`KeyRecord::generate_with_machines`] and is not a
/// field of this struct.
///
/// Every `Option` field marked `skip_serializing_if = "Option::is_none"` is
/// omitted from serialized output when `None`, and `default` lets it be absent
/// on input. The limit-style fields mirror [`Limits`] (see
/// [`KeyRecord::limits`]); how they are enforced is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRecord {
    /// Identifier of this key.
    pub id: String,
    /// Hash of the key's plaintext. Records built by `generate*` store it as
    /// `sha256:<hex digest>`, which is also the format `matches` compares against.
    pub hash: String,
    /// Scopes granted to this key.
    pub scopes: HashSet<Scope>,

    /// Optional set of markets. NOTE(review): presumably `None` means
    /// "unrestricted" — confirm against the enforcement code.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_markets: Option<HashSet<String>>,
    /// Optional set of symbols (same `None` caveat as `allowed_markets`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_symbols: Option<HashSet<String>>,
    /// Optional per-order value cap. Units/currency are not defined here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_order_value: Option<f64>,
    /// Optional daily value cap. Units/currency are not defined here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_daily_value: Option<f64>,
    /// Optional time window as text, parsed by [`KeyRecord::hours_range`] in the
    /// form `HH:MM-HH:MM`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hours_window: Option<String>,
    /// Optional cap on orders per minute.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_orders_per_minute: Option<u32>,
    /// Optional set of trade sides. Valid values are not defined in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_trd_sides: Option<HashSet<String>>,
    /// Optional set of account ids.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_acc_ids: Option<HashSet<u64>>,
    /// Optional list of card numbers (a `Vec`, unlike the set-typed fields above).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_card_nums: Option<Vec<String>>,
    /// Optional expiry instant; see [`KeyRecord::is_expired`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<DateTime<Utc>>,
    /// Creation time; `generate*` sets it to `Utc::now()`.
    pub created_at: DateTime<Utc>,
    /// Optional free-form note.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub note: Option<String>,
    /// Optional list of machine identifiers, passed to `crate::machine::check`
    /// by [`KeyRecord::check_machine`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_machines: Option<Vec<String>>,

    /// Never serialized or deserialized (`serde(skip)`), so it is `None` after
    /// loading a record. `generate_with_machines` sets it to a copy of the
    /// limits' `allowed_acc_ids` at creation time.
    /// NOTE(review): presumably keeps the explicitly configured account ids
    /// apart from `allowed_acc_ids`, which may be rewritten elsewhere — confirm.
    #[serde(skip)]
    pub raw_explicit_acc_ids: Option<HashSet<u64>>,
}
94
95impl KeyRecord {
96 #[must_use = "丢弃生成结果会丢失 plaintext; 调用方必须立即展示给用户"]
100 pub fn generate(
101 id: impl Into<String>,
102 scopes: HashSet<Scope>,
103 limits: Option<Limits>,
104 expires_at: Option<DateTime<Utc>>,
105 note: Option<String>,
106 ) -> (String, KeyRecord) {
107 Self::generate_with_machines(id, scopes, limits, expires_at, note, None)
108 }
109
110 #[must_use = "丢弃生成结果会丢失 plaintext; 调用方必须立即展示给用户"]
112 pub fn generate_with_machines(
113 id: impl Into<String>,
114 scopes: HashSet<Scope>,
115 limits: Option<Limits>,
116 expires_at: Option<DateTime<Utc>>,
117 note: Option<String>,
118 allowed_machines: Option<Vec<String>>,
119 ) -> (String, KeyRecord) {
120 let mut bytes = [0u8; 32];
121 rand::thread_rng().fill_bytes(&mut bytes);
122 let plaintext = hex::encode(bytes);
123 let hash = format!(
124 "sha256:{}",
125 hex::encode(Sha256::digest(plaintext.as_bytes()))
126 );
127 let limits = limits.unwrap_or_default();
128 let raw_explicit_acc_ids = limits.allowed_acc_ids.clone();
130 let record = KeyRecord {
131 id: id.into(),
132 hash,
133 scopes,
134 allowed_markets: limits.allowed_markets,
135 allowed_symbols: limits.allowed_symbols,
136 max_order_value: limits.max_order_value,
137 max_daily_value: limits.max_daily_value,
138 hours_window: limits.hours_window,
139 max_orders_per_minute: limits.max_orders_per_minute,
140 allowed_trd_sides: limits.allowed_trd_sides,
141 allowed_acc_ids: limits.allowed_acc_ids,
142 allowed_card_nums: limits.allowed_card_nums,
143 expires_at,
144 created_at: Utc::now(),
145 note,
146 allowed_machines,
147 raw_explicit_acc_ids,
148 };
149 (plaintext, record)
150 }
151
152 pub fn check_machine(&self) -> Result<(), crate::machine::MachineError> {
154 crate::machine::check(&self.id, self.allowed_machines.as_deref())
155 }
156
157 #[must_use]
159 pub fn matches(&self, plaintext: &str) -> bool {
160 let computed = hash_plaintext(plaintext);
161 self.matches_hash(&computed)
162 }
163
164 pub(crate) fn matches_hash(&self, computed: &str) -> bool {
165 constant_time_eq_str(&self.hash, computed)
166 }
167
168 pub(crate) fn is_generated_plaintext_shape(plaintext: &str) -> bool {
172 plaintext.len() == 64
173 && plaintext
174 .bytes()
175 .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b))
176 }
177
178 #[must_use]
180 pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
181 self.expires_at.map(|t| now >= t).unwrap_or(false)
182 }
183
184 pub fn hours_range(&self) -> Result<Option<(NaiveTime, NaiveTime)>, String> {
186 let Some(s) = &self.hours_window else {
187 return Ok(None);
188 };
189 let (l, r) = s
190 .split_once('-')
191 .ok_or_else(|| format!("invalid hours_window {s:?}: expect HH:MM-HH:MM"))?;
192 let parse = |p: &str| {
193 NaiveTime::parse_from_str(p.trim(), "%H:%M")
194 .map_err(|e| format!("invalid time {p:?}: {e}"))
195 };
196 Ok(Some((parse(l)?, parse(r)?)))
197 }
198
199 #[must_use]
201 pub fn limits(&self) -> Limits {
202 Limits {
203 allowed_markets: self.allowed_markets.clone(),
204 allowed_symbols: self.allowed_symbols.clone(),
205 max_order_value: self.max_order_value,
206 max_daily_value: self.max_daily_value,
207 hours_window: self.hours_window.clone(),
208 max_orders_per_minute: self.max_orders_per_minute,
209 allowed_trd_sides: self.allowed_trd_sides.clone(),
210 allowed_acc_ids: self.allowed_acc_ids.clone(),
211 allowed_card_nums: self.allowed_card_nums.clone(),
212 }
213 }
214}
215
/// Byte-wise string equality that inspects every byte instead of stopping at
/// the first mismatch.
///
/// Inputs of different lengths are rejected up front, so the length itself is
/// not hidden; for equal lengths the XOR of each byte pair is OR-ed into a
/// single accumulator, which is zero only when all bytes match.
fn constant_time_eq_str(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    let diff = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |seen, (x, y)| seen | (x ^ y));
    diff == 0
}
228
229#[must_use]
231pub fn hash_plaintext(plaintext: &str) -> String {
232 format!(
233 "sha256:{}",
234 hex::encode(Sha256::digest(plaintext.as_bytes()))
235 )
236}
237
// Unit tests are declared as a separate `tests` module (defined outside this
// file) and are compiled only in test builds.
#[cfg(test)]
mod tests;