1use std::collections::HashSet;
11use std::str::FromStr;
12use std::sync::Arc;
13#[cfg(feature = "reliable_report")]
14use std::time::Duration;
15
16#[cfg(feature = "reliable_report")]
17use chrono::TimeDelta;
18use serde_json::json;
19use sqlx::postgres::PgRow;
20use sqlx::{PgPool, Row};
21
22use async_trait::async_trait;
23
24fn sqlx_err_to_string(e: &sqlx::Error) -> String {
26 match e {
27 sqlx::Error::Database(db_err) => db_err.message().to_owned(),
28 _ => e.to_string(),
29 }
30}
31use cadence::StatsdClient;
32use serde::{Deserialize, Serialize};
33use uuid::Uuid;
34
35use crate::db::client::DbClient;
36use crate::db::error::{DbError, DbResult};
37use crate::db::{DbSettings, User};
38use crate::notification::Notification;
39use crate::{util, MAX_ROUTER_TTL_SECS};
40
41use super::client::FetchMessageResponse;
42
/// Offset (60 days) that `log_report` adds to the current time when it
/// computes the timestamp written to the reliability table.
#[cfg(feature = "reliable_report")]
const RELIABLE_LOG_TTL: TimeDelta = TimeDelta::days(60);
45
/// Postgres-specific settings, deserialized from the JSON `db_settings` string.
///
/// `#[serde(default)]` lets any key missing from the JSON fall back to the
/// value produced by the `Default` implementation.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct PostgresDbSettings {
    /// Optional schema name. When set, the client prefixes it (`schema.table`)
    /// to every table name it uses.
    pub schema: Option<String>,
    /// Name of the router table (user records).
    pub router_table: String,
    /// Name of the message table (stored notifications).
    pub message_table: String,
    /// Name of the meta table (`uaid` / `channel_id` pairs).
    pub meta_table: String,
    /// Name of the reliability table (only used with the `reliable_report` feature).
    pub reliability_table: String,
    /// Seconds added to the current time when computing a router row's `expiry`.
    max_router_ttl: u64,
}
58
59impl Default for PostgresDbSettings {
60 fn default() -> Self {
61 Self {
62 schema: None,
63 router_table: "router".to_owned(),
64 message_table: "message".to_owned(),
65 meta_table: "meta".to_owned(),
66 reliability_table: "reliability".to_owned(),
67 max_router_ttl: MAX_ROUTER_TTL_SECS,
68 }
70 }
71}
72
73impl TryFrom<&str> for PostgresDbSettings {
74 type Error = DbError;
75 fn try_from(setting_string: &str) -> Result<Self, Self::Error> {
76 if setting_string.trim().is_empty() {
77 return Ok(PostgresDbSettings::default());
78 }
79 serde_json::from_str(setting_string).map_err(|e| {
80 DbError::General(format!(
81 "Could not parse configuration db_settings: {:?}",
82 e
83 ))
84 })
85 }
86}
87
/// Postgres implementation of the `DbClient` trait, backed by an sqlx pool.
#[derive(Clone)]
pub struct PgClientImpl {
    /// Metrics client handed in at construction; not read in this file.
    _metrics: Arc<StatsdClient>,
    /// Parsed Postgres-specific settings (table names, schema, router TTL).
    db_settings: PostgresDbSettings,
    /// Connection pool; created lazily in `new`.
    pool: PgPool,
    /// Router table name, schema-qualified when a schema is configured.
    cached_router_table: String,
    /// Message table name, schema-qualified when a schema is configured.
    cached_message_table: String,
    /// Meta table name, schema-qualified when a schema is configured.
    cached_meta_table: String,
    /// Reliability table name, schema-qualified when a schema is configured.
    #[cfg(feature = "reliable_report")]
    cached_reliability_table: String,
}
100
101impl PgClientImpl {
102 pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
110 let db_settings = PostgresDbSettings::try_from(settings.db_settings.as_ref())?;
111 info!(
112 "📮 Initializing Postgres DB Client with settings: {:?} from {:?}",
113 db_settings, &settings.db_settings
114 );
115 if let Some(dsn) = settings.dsn.clone() {
116 trace!("📮 Postgres Connect {}", &dsn);
117
118 let connect_options = sqlx::postgres::PgConnectOptions::from_str(&dsn)
120 .map_err(|e| DbError::General(format!("Invalid Postgres DSN: {e}")))?;
121 let pool = sqlx::postgres::PgPoolOptions::new().connect_lazy_with(connect_options);
122 let cached_router_table = if let Some(schema) = &db_settings.schema {
123 format!("{}.{}", schema, db_settings.router_table)
124 } else {
125 db_settings.router_table.clone()
126 };
127 let cached_message_table = if let Some(schema) = &db_settings.schema {
128 format!("{}.{}", schema, db_settings.message_table)
129 } else {
130 db_settings.message_table.clone()
131 };
132 let cached_meta_table = if let Some(schema) = &db_settings.schema {
133 format!("{}.{}", schema, db_settings.meta_table)
134 } else {
135 db_settings.meta_table.clone()
136 };
137 #[cfg(feature = "reliable_report")]
138 let cached_reliability_table = if let Some(schema) = &db_settings.schema {
139 format!("{}.{}", schema, db_settings.reliability_table)
140 } else {
141 db_settings.reliability_table.clone()
142 };
143 return Ok(Self {
144 _metrics: metrics,
145 db_settings,
146 pool,
147 cached_router_table,
148 cached_message_table,
149 cached_meta_table,
150 #[cfg(feature = "reliable_report")]
151 cached_reliability_table,
152 });
153 };
154 Err(DbError::ConnectionError("No DSN specified".to_owned()))
155 }
156
157 async fn table_exists(&self, table_name: &str) -> DbResult<bool> {
159 let (schema, table_name) = if table_name.contains('.') {
160 let mut parts = table_name.splitn(2, '.');
161 (
162 parts.next().unwrap_or("public").to_owned(),
163 parts.next().unwrap().to_owned(),
166 )
167 } else {
168 ("public".to_owned(), table_name.to_owned())
169 };
170 let row: (bool,) = sqlx::query_as(
171 "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname=$1 AND tablename=$2);",
172 )
173 .bind(&schema)
174 .bind(&table_name)
175 .fetch_one(&self.pool)
176 .await
177 .map_err(DbError::PgError)?;
178 Ok(row.0)
179 }
180
181 fn router_expiry(&self) -> u64 {
183 util::sec_since_epoch() + self.db_settings.max_router_ttl
184 }
185
186 pub(crate) fn router_table(&self) -> &str {
188 &self.cached_router_table
189 }
190
191 pub(crate) fn message_table(&self) -> &str {
193 &self.cached_message_table
194 }
195
196 pub(crate) fn meta_table(&self) -> &str {
200 &self.cached_meta_table
201 }
202
203 #[cfg(feature = "reliable_report")]
207 pub(crate) fn reliability_table(&self) -> &str {
208 &self.cached_reliability_table
209 }
210}
211
212#[async_trait]
213impl DbClient for PgClientImpl {
214 async fn add_user(&self, user: &User) -> DbResult<()> {
216 let sql = format!(
217 "INSERT INTO {tablename} (uaid, connected_at, router_type, router_data, node_id, record_version, version, last_update, priv_channels, expiry)
218 VALUES($1, $2::BIGINT, $3, $4, $5, $6::BIGINT, $7, $8::BIGINT, $9, $10::BIGINT)
219 ON CONFLICT (uaid) DO
220 UPDATE SET connected_at=EXCLUDED.connected_at,
221 router_type=EXCLUDED.router_type,
222 router_data=EXCLUDED.router_data,
223 node_id=EXCLUDED.node_id,
224 record_version=EXCLUDED.record_version,
225 version=EXCLUDED.version,
226 last_update=EXCLUDED.last_update,
227 priv_channels=EXCLUDED.priv_channels,
228 expiry=EXCLUDED.expiry
229 ;
230 ",
231 tablename = self.router_table()
232 );
233 sqlx::query(&sql)
234 .bind(user.uaid.simple().to_string()) .bind(user.connected_at as i64) .bind(&user.router_type) .bind(json!(user.router_data).to_string()) .bind(&user.node_id) .bind(user.record_version.map(|i| i as i64)) .bind(user.version.map(|v| v.simple().to_string())) .bind(user.current_timestamp.map(|i| i as i64)) .bind(
243 user.priv_channels
244 .iter()
245 .map(|v| v.to_string())
246 .collect::<Vec<String>>(),
247 )
248 .bind(self.router_expiry() as i64) .execute(&self.pool)
250 .await
251 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
252 Ok(())
253 }
254
255 async fn update_user(&self, user: &mut User) -> DbResult<bool> {
257 let cmd = format!(
258 "UPDATE {tablename} SET connected_at=$2::BIGINT,
259 router_type=$3,
260 router_data=$4,
261 node_id=$5,
262 record_version=$6::BIGINT,
263 version=$7,
264 last_update=$8::BIGINT,
265 priv_channels=$9,
266 expiry=$10::BIGINT
267 WHERE
268 uaid = $1 AND connected_at < $2::BIGINT;
269 ",
270 tablename = self.router_table()
271 );
272 let result = sqlx::query(&cmd)
273 .bind(user.uaid.simple().to_string()) .bind(user.connected_at as i64) .bind(&user.router_type) .bind(json!(user.router_data).to_string()) .bind(&user.node_id) .bind(user.record_version.map(|i| i as i64)) .bind(user.version.map(|v| v.simple().to_string())) .bind(user.current_timestamp.map(|i| i as i64)) .bind(
282 user.priv_channels
283 .iter()
284 .map(|v| v.to_string())
285 .collect::<Vec<String>>(),
286 )
287 .bind(self.router_expiry() as i64) .execute(&self.pool)
289 .await
290 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
291 Ok(result.rows_affected() > 0)
292 }
293
294 async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
296 let row: Option<PgRow> = sqlx::query(&format!(
297 "SELECT connected_at, router_type, router_data, node_id, record_version, last_update, version, priv_channels
298 FROM {tablename}
299 WHERE uaid = $1",
300 tablename = self.router_table()
301 ))
302 .bind(uaid.simple().to_string())
303 .fetch_optional(&self.pool)
304 .await
305 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
306
307 let Some(row) = row else {
308 return Ok(None);
309 };
310
311 let priv_channels = if let Ok(Some(channels)) =
314 row.try_get::<Option<Vec<String>>, _>("priv_channels")
315 {
316 let mut priv_channels = HashSet::new();
317 for channel in channels.iter() {
318 let uuid = Uuid::from_str(channel).map_err(|e| DbError::General(e.to_string()))?;
319 priv_channels.insert(uuid);
320 }
321 priv_channels
322 } else {
323 HashSet::new()
324 };
325 let resp = User {
326 uaid: *uaid,
327 connected_at: row
328 .try_get::<i64, _>("connected_at")
329 .map_err(DbError::PgError)? as u64,
330 router_type: row
331 .try_get::<String, _>("router_type")
332 .map_err(DbError::PgError)?,
333 router_data: serde_json::from_str(
334 row.try_get::<&str, _>("router_data")
335 .map_err(DbError::PgError)?,
336 )
337 .map_err(|e| DbError::General(e.to_string()))?,
338 node_id: row
339 .try_get::<Option<String>, _>("node_id")
340 .map_err(DbError::PgError)?,
341 record_version: row
342 .try_get::<Option<i64>, _>("record_version")
343 .map_err(DbError::PgError)?
344 .map(|v| v as u64),
345 current_timestamp: row
346 .try_get::<Option<i64>, _>("last_update")
347 .map_err(DbError::PgError)?
348 .map(|v| v as u64),
349 version: row
350 .try_get::<Option<String>, _>("version")
351 .map_err(DbError::PgError)?
352 .map(|v| {
354 Uuid::from_str(&v).map_err(|e| {
355 DbError::Integrity("Invalid UUID found".to_owned(), Some(e.to_string()))
356 })
357 })
358 .transpose()?,
359 priv_channels,
360 };
361 Ok(Some(resp))
362 }
363
364 async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
366 sqlx::query(&format!(
367 "DELETE FROM {tablename}
368 WHERE uaid = $1",
369 tablename = self.router_table()
370 ))
371 .bind(uaid.simple().to_string())
372 .execute(&self.pool)
373 .await
374 .map_err(DbError::PgError)?;
375 Ok(())
376 }
377
378 async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
387 sqlx::query(&format!(
388 "INSERT
389 INTO {tablename} (uaid, channel_id) VALUES ($1, $2)
390 ON CONFLICT DO NOTHING",
391 tablename = self.meta_table()
392 ))
393 .bind(uaid.simple().to_string())
394 .bind(channel_id.simple().to_string())
395 .execute(&self.pool)
396 .await
397 .map_err(DbError::PgError)?;
398 Ok(())
399 }
400
401 async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
403 if channels.is_empty() {
404 trace!("📮 No channels to save.");
405 return Ok(());
406 };
407 let uaid_str = uaid.simple().to_string();
408
409 let mut query_builder: sqlx::QueryBuilder<sqlx::Postgres> =
410 sqlx::QueryBuilder::new(format!(
411 "INSERT INTO {tablename} (uaid, channel_id) ",
412 tablename = self.meta_table()
413 ));
414 query_builder.push_values(channels.iter(), |mut b, channel| {
415 b.push_bind(uaid_str.clone());
416 b.push_bind(channel.simple().to_string());
417 });
418 query_builder.push(" ON CONFLICT DO NOTHING");
419
420 query_builder
421 .build()
422 .execute(&self.pool)
423 .await
424 .map_err(DbError::PgError)?;
425 Ok(())
426 }
427
428 async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
430 let mut result = HashSet::new();
431 let rows: Vec<PgRow> = sqlx::query(&format!(
432 "SELECT distinct channel_id FROM {tablename} WHERE uaid = $1;",
433 tablename = self.meta_table()
434 ))
435 .bind(uaid.simple().to_string())
436 .fetch_all(&self.pool)
437 .await
438 .map_err(DbError::PgError)?;
439 for row in rows.iter() {
440 let s = row
441 .try_get::<&str, _>("channel_id")
442 .map_err(DbError::PgError)?;
443 result.insert(Uuid::from_str(s).map_err(|e| DbError::General(e.to_string()))?);
444 }
445 Ok(result)
446 }
447
448 async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
450 let cmd = format!(
451 "DELETE FROM {tablename}
452 WHERE uaid = $1 AND channel_id = $2;",
453 tablename = self.meta_table()
454 );
455 let result = sqlx::query(&cmd)
456 .bind(uaid.simple().to_string())
457 .bind(channel_id.simple().to_string())
458 .execute(&self.pool)
459 .await
460 .map_err(DbError::PgError)?;
461 Ok(result.rows_affected() > 0)
463 }
464
465 async fn remove_node_id(
467 &self,
468 uaid: &Uuid,
469 node_id: &str,
470 connected_at: u64,
471 version: &Option<Uuid>,
472 ) -> DbResult<bool> {
473 let Some(version) = version else {
474 return Err(DbError::General("Expected a user version field".to_owned()));
475 };
476 sqlx::query(&format!(
477 "UPDATE {tablename}
478 SET node_id = null
479 WHERE uaid=$1 AND node_id = $2 AND connected_at = $3 AND version= $4;",
480 tablename = self.router_table()
481 ))
482 .bind(uaid.simple().to_string())
483 .bind(node_id)
484 .bind(connected_at as i64)
485 .bind(version.simple().to_string())
486 .execute(&self.pool)
487 .await
488 .map_err(DbError::PgError)?;
489 Ok(true)
490 }
491
492 async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
494 #[allow(unused_mut)]
497 let mut fields = vec![
498 "uaid",
499 "channel_id",
500 "chid_message_id",
501 "version",
502 "ttl",
503 "expiry",
504 "topic",
505 "timestamp",
506 "data",
507 "sortkey_timestamp",
508 "headers",
509 ];
510 #[allow(unused_mut)]
512 let mut inputs = vec![
513 "$1", "$2", "$3", "$4", "$5", "$6", "$7", "$8", "$9", "$10", "$11",
514 ];
515 #[cfg(feature = "reliable_report")]
516 {
517 fields.append(&mut ["reliability_id"].to_vec());
518 inputs.append(&mut ["$12"].to_vec());
519 }
520 let cmd = format!(
521 "INSERT INTO {tablename}
522 ({fields})
523 VALUES
524 ({inputs}) ON CONFLICT (chid_message_id) DO UPDATE SET
525 uaid=EXCLUDED.uaid,
526 channel_id=EXCLUDED.channel_id,
527 version=EXCLUDED.version,
528 ttl=EXCLUDED.ttl,
529 expiry=EXCLUDED.expiry,
530 topic=EXCLUDED.topic,
531 timestamp=EXCLUDED.timestamp,
532 data=EXCLUDED.data,
533 sortkey_timestamp=EXCLUDED.sortkey_timestamp,
534 headers=EXCLUDED.headers",
535 tablename = &self.message_table(),
536 fields = fields.join(","),
537 inputs = inputs.join(",")
538 );
539 #[allow(unused_mut)]
540 let mut q = sqlx::query(&cmd)
541 .bind(uaid.simple().to_string())
542 .bind(message.channel_id.simple().to_string())
543 .bind(message.chidmessageid())
544 .bind(&message.version)
545 .bind(message.ttl as i64)
546 .bind(util::sec_since_epoch() as i64 + message.ttl as i64)
547 .bind(message.topic.as_ref().filter(|v| !v.is_empty()).cloned())
548 .bind(message.timestamp as i64)
549 .bind(message.data.as_deref().unwrap_or_default().to_owned())
550 .bind(message.sortkey_timestamp.map(|v| v as i64))
551 .bind(json!(message.headers).to_string());
552 #[cfg(feature = "reliable_report")]
553 {
554 q = q.bind(message.reliability_id.clone());
555 }
556 q.execute(&self.pool)
557 .await
558 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
559 Ok(())
560 }
561
562 async fn remove_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<()> {
564 debug!(
565 "📮 Removing message for user {} with chid_message_id {}",
566 uaid.simple(),
567 chidmessageid
568 );
569 let result = sqlx::query(&format!(
570 "DELETE FROM {tablename}
571 WHERE uaid=$1 AND chid_message_id = $2;",
572 tablename = self.message_table()
573 ))
574 .bind(uaid.simple().to_string())
575 .bind(chidmessageid)
576 .execute(&self.pool)
577 .await
578 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
579 debug!(
580 "📮 Deleted {} rows for user {} with chid_message_id {}",
581 result.rows_affected(),
582 uaid.simple(),
583 chidmessageid
584 );
585 Ok(())
586 }
587
588 async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
589 if messages.is_empty() {
590 return Ok(());
591 }
592 if messages.len() == 1 {
593 return self
595 .save_message(uaid, messages.into_iter().next().unwrap())
596 .await;
597 }
598
599 let field_names = {
600 #[allow(unused_mut)]
601 let mut fields = vec![
602 "uaid",
603 "channel_id",
604 "chid_message_id",
605 "version",
606 "ttl",
607 "expiry",
608 "topic",
609 "timestamp",
610 "data",
611 "sortkey_timestamp",
612 "headers",
613 ];
614 #[cfg(feature = "reliable_report")]
615 fields.push("reliability_id");
616 fields.join(",")
617 };
618
619 let uaid_str = uaid.simple().to_string();
620 let now = util::sec_since_epoch() as i64;
621
622 struct MessageParams {
624 channel_id: String,
625 chidmessageid: String,
626 version: String,
627 ttl: i64,
628 expiry: i64,
629 topic: Option<String>,
630 timestamp: i64,
631 data: String,
632 sortkey_timestamp: Option<i64>,
633 headers: String,
634 #[cfg(feature = "reliable_report")]
635 reliability_id: Option<String>,
636 }
637
638 let msg_params: Vec<MessageParams> = messages
639 .into_iter()
640 .map(|m| {
641 let topic = m.topic.as_ref().filter(|v| !v.is_empty()).cloned();
642 MessageParams {
643 channel_id: m.channel_id.simple().to_string(),
644 chidmessageid: m.chidmessageid(),
645 version: m.version,
646 ttl: m.ttl as i64,
647 expiry: now + m.ttl as i64,
648 topic,
649 timestamp: m.timestamp as i64,
650 data: m.data.as_deref().unwrap_or_default().to_owned(),
651 sortkey_timestamp: m.sortkey_timestamp.map(|v| v as i64),
652 headers: json!(m.headers).to_string(),
653 #[cfg(feature = "reliable_report")]
654 reliability_id: m.reliability_id,
655 }
656 })
657 .collect();
658
659 let mut query_builder: sqlx::QueryBuilder<sqlx::Postgres> =
660 sqlx::QueryBuilder::new(format!(
661 "INSERT INTO {tablename} ({field_names}) ",
662 tablename = &self.message_table(),
663 ));
664 query_builder.push_values(msg_params.iter(), |mut query, mp| {
666 query.push_bind(uaid_str.clone());
667 query.push_bind(mp.channel_id.clone());
668 query.push_bind(mp.chidmessageid.clone());
669 query.push_bind(mp.version.clone());
670 query.push_bind(mp.ttl);
671 query.push_bind(mp.expiry);
672 query.push_bind(mp.topic.clone());
673 query.push_bind(mp.timestamp);
674 query.push_bind(mp.data.clone());
675 query.push_bind(mp.sortkey_timestamp);
676 query.push_bind(mp.headers.clone());
677 #[cfg(feature = "reliable_report")]
678 query.push_bind(mp.reliability_id.clone());
679 });
680 query_builder.push(
681 " ON CONFLICT (chid_message_id) DO UPDATE SET
682 uaid=EXCLUDED.uaid,
683 channel_id=EXCLUDED.channel_id,
684 version=EXCLUDED.version,
685 ttl=EXCLUDED.ttl,
686 expiry=EXCLUDED.expiry,
687 topic=EXCLUDED.topic,
688 timestamp=EXCLUDED.timestamp,
689 data=EXCLUDED.data,
690 sortkey_timestamp=EXCLUDED.sortkey_timestamp,
691 headers=EXCLUDED.headers",
692 );
693
694 query_builder
695 .build()
696 .execute(&self.pool)
697 .await
698 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
699 Ok(())
700 }
701
702 async fn fetch_topic_messages(
705 &self,
706 uaid: &Uuid,
707 limit: usize,
708 ) -> DbResult<FetchMessageResponse> {
709 let messages: Vec<Notification> = sqlx::query(&format!(
710 "SELECT channel_id, version, ttl, topic, timestamp, data, sortkey_timestamp, headers
711 FROM {tablename}
712 WHERE uaid=$1 AND expiry >= $2 AND (topic IS NOT NULL AND topic != '')
713 ORDER BY timestamp DESC
714 LIMIT $3",
715 tablename = &self.message_table(),
716 ))
717 .bind(uaid.simple().to_string())
718 .bind(util::sec_since_epoch() as i64)
719 .bind(limit as i64)
720 .fetch_all(&self.pool)
721 .await
722 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?
723 .iter()
724 .map(|row: &PgRow| row.try_into())
725 .collect::<Result<Vec<Notification>, DbError>>()?;
726
727 if messages.is_empty() {
728 Ok(Default::default())
729 } else {
730 Ok(FetchMessageResponse {
731 timestamp: Some(messages[0].timestamp),
732 messages,
733 })
734 }
735 }
736
737 async fn fetch_timestamp_messages(
739 &self,
740 uaid: &Uuid,
741 timestamp: Option<u64>,
742 limit: usize,
743 ) -> DbResult<FetchMessageResponse> {
744 let uaid = uaid.simple().to_string();
745 let response: Vec<PgRow> = if let Some(ts) = timestamp {
746 trace!("📮 Fetching messages for user {} since {}", &uaid, ts);
747 sqlx::query(&format!(
748 "SELECT * FROM {}
749 WHERE uaid = $1 AND timestamp > $2 AND expiry >= $3
750 ORDER BY timestamp
751 LIMIT $4",
752 self.message_table()
753 ))
754 .bind(&uaid)
755 .bind(ts as i64)
756 .bind(util::sec_since_epoch() as i64)
757 .bind(limit as i64)
758 .fetch_all(&self.pool)
759 .await
760 } else {
761 trace!("📮 Fetching messages for user {}", &uaid);
762 sqlx::query(&format!(
763 "SELECT *
764 FROM {}
765 WHERE uaid = $1
766 AND expiry >= $2
767 LIMIT $3",
768 self.message_table()
769 ))
770 .bind(&uaid)
771 .bind(util::sec_since_epoch() as i64)
772 .bind(limit as i64)
773 .fetch_all(&self.pool)
774 .await
775 }
776 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
777 let messages: Vec<Notification> = response
778 .iter()
779 .map(|row: &PgRow| row.try_into())
780 .collect::<Result<Vec<Notification>, DbError>>()?;
781 let timestamp = if !messages.is_empty() {
782 Some(messages[0].timestamp)
783 } else {
784 None
785 };
786
787 Ok(FetchMessageResponse {
788 timestamp,
789 messages,
790 })
791 }
792
793 async fn router_table_exists(&self) -> DbResult<bool> {
795 self.table_exists(self.router_table()).await
796 }
797
798 async fn message_table_exists(&self) -> DbResult<bool> {
800 self.table_exists(self.message_table()).await
801 }
802
803 #[cfg(feature = "reliable_report")]
804 async fn log_report(
805 &self,
806 reliability_id: &str,
807 new_state: crate::reliability::ReliabilityState,
808 ) -> DbResult<()> {
809 let timestamp_epoch = (std::time::SystemTime::now()
810 + Duration::from_secs(RELIABLE_LOG_TTL.num_seconds() as u64))
811 .duration_since(std::time::UNIX_EPOCH)
812 .map_err(|e| DbError::General(format!("System time before UNIX epoch: {e}")))?
813 .as_secs() as i64;
814 debug!("📮 Logging report for {reliability_id} as {new_state}");
815
816 let tablename = &self.reliability_table();
817 let state = new_state.to_string();
818 sqlx::query(&format!(
819 "INSERT INTO {tablename} (id, states, last_update_timestamp) VALUES ($1, json_build_object($2, $3), $3)
820 ON CONFLICT (id) DO
821 UPDATE SET states = EXCLUDED.states,
822 last_update_timestamp = EXCLUDED.last_update_timestamp;",
823 tablename = tablename
824 ))
825 .bind(reliability_id)
826 .bind(&state)
827 .bind(timestamp_epoch)
828 .execute(&self.pool)
829 .await
830 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
831 Ok(())
832 }
833
834 async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
835 debug!("📮 Updating {uaid} current_timestamp:{timestamp}");
836 let tablename = &self.router_table();
837
838 trace!("📮 Purging git{uaid} for < {timestamp}");
839 let mut tx = self.pool.begin().await.map_err(DbError::PgError)?;
840 sqlx::query(&format!(
842 "DELETE FROM {} WHERE uaid = $1 and expiry < $2",
843 &self.message_table()
844 ))
845 .bind(uaid.simple().to_string())
846 .bind(util::sec_since_epoch() as i64)
847 .execute(&mut *tx)
848 .await
849 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
850 sqlx::query(&format!(
852 "DELETE FROM {} WHERE uaid = $1 AND timestamp IS NOT NULL AND timestamp < $2",
853 &self.message_table()
854 ))
855 .bind(uaid.simple().to_string())
856 .bind(timestamp as i64)
857 .execute(&mut *tx)
858 .await
859 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
860 sqlx::query(&format!(
861 "UPDATE {tablename} SET last_update = $2::BIGINT, expiry= $3::BIGINT WHERE uaid = $1"
862 ))
863 .bind(uaid.simple().to_string())
864 .bind(timestamp as i64)
865 .bind(self.router_expiry() as i64)
866 .execute(&mut *tx)
867 .await
868 .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
869 tx.commit().await.map_err(DbError::PgError)?;
870 Ok(())
871 }
872
873 fn name(&self) -> String {
874 "Postgres".to_owned()
875 }
876
877 async fn health_check(&self) -> DbResult<bool> {
878 let row: (bool,) = sqlx::query_as("select true")
879 .fetch_one(&self.pool)
880 .await
881 .map_err(DbError::PgError)?;
882 if !row.0 {
883 error!("📮 Failed to fetch from database");
884 return Ok(false);
885 }
886 if !self.router_table_exists().await? {
887 error!("📮 Router table does not exist");
888 return Ok(false);
889 }
890 if !self.message_table_exists().await? {
891 error!("📮 Message table does not exist");
892 return Ok(false);
893 }
894 Ok(true)
895 }
896
897 fn box_clone(&self) -> Box<dyn DbClient> {
899 Box::new(self.clone())
900 }
901}
902
903#[cfg(test)]
916mod tests {
917 use crate::util::sec_since_epoch;
918 use crate::{logging::init_test_logging, util::ms_since_epoch};
919 use rand::prelude::*;
920 use serde_json::json;
921 use std::env;
922
923 use super::*;
    /// Channel id used by the tests for plain (non-topic) messages.
    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
    /// Channel id used by `run_gauntlet` for messages that carry a topic.
    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";
926
927 fn new_client() -> DbResult<PgClientImpl> {
928 let host = env::var("POSTGRES_HOST").unwrap_or("localhost".into());
930 let env_dsn = format!("postgres://{host}");
931 debug!("📮 Connecting to {env_dsn}");
932 let settings = DbSettings {
933 dsn: Some(env_dsn),
934 db_settings: json!(PostgresDbSettings {
935 schema: Some("autopush".to_owned()),
936 ..Default::default()
937 })
938 .to_string(),
939 };
940 let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
941 PgClientImpl::new(metrics, &settings)
942 }
943
944 fn gen_test_user() -> String {
945 let mut rng = rand::rng();
947 let test_num = rng.random::<u8>();
948 format!(
949 "DEADBEEF-0000-0000-{:04}-{:012}",
950 test_num,
951 sec_since_epoch()
952 )
953 }
954
955 #[actix_rt::test]
956 async fn health_check() {
957 let client = new_client().unwrap();
958
959 let result = client.health_check().await;
960 assert!(result.is_ok());
961 assert!(result.unwrap());
962 }
963
    /// A message saved at `timestamp` must be gone after `increment_storage`
    /// is called with `timestamp + 1`.
    #[actix_rt::test]
    async fn wipe_expired() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // Best-effort cleanup of any leftover record; the result is ignored.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // Second best-effort cleanup (repeats the call above).
        let _ = client.remove_user(&uaid).await;

        let timestamp = sec_since_epoch();
        client.add_user(&test_user).await?;
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 1,
            timestamp,
            data: Some("Encrypted".into()),
            sortkey_timestamp: Some(timestamp),
            ..Default::default()
        };
        client.save_message(&uaid, test_notification).await?;
        // Purges messages whose timestamp is before `timestamp + 1`, which
        // includes the one just saved.
        client.increment_storage(&uaid, timestamp + 1).await?;
        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(msgs.messages.len(), 0);
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }
1012
    /// End-to-end walk through the client API against a live database: user
    /// add/get/update, channel add/remove, message save/fetch/remove (with and
    /// without a topic), node id removal and finally user removal.
    #[actix_rt::test]
    async fn run_gauntlet() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let user_id = &gen_test_user();
        let uaid = Uuid::parse_str(user_id).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // Best-effort cleanup of any leftover record; the result is ignored.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // Second best-effort cleanup (repeats the call above).
        let _ = client.remove_user(&uaid).await;

        // --- user: add and read back ---
        trace!("📮 Adding user {}", &user_id);
        client.add_user(&test_user).await?;
        let fetched = client.get_user(&uaid).await?;
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.router_type, "webpush".to_owned());

        // --- channels: single add ---
        trace!("📮 Adding channel {} to user {}", &chid, &user_id);
        client.add_channel(&uaid, &chid).await?;
        let channels = client.get_channels(&uaid).await?;
        assert!(channels.contains(&chid));

        // --- channels: bulk add (includes the already-present `chid`) ---
        let mut new_channels: HashSet<Uuid> = HashSet::new();
        trace!("📮 Adding multiple channels to user {}", &user_id);
        new_channels.insert(chid);
        for _ in 1..10 {
            new_channels.insert(uuid::Uuid::new_v4());
        }
        let chid_to_remove = uuid::Uuid::new_v4();
        trace!(
            "📮 Adding removable channel {} to user {}",
            &chid_to_remove,
            &user_id
        );
        new_channels.insert(chid_to_remove);
        client.add_channels(&uaid, new_channels.clone()).await?;
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // --- channels: the first removal reports true, the repeat false ---
        trace!(
            "📮 Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
        trace!(
            "📮 retrying Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
        new_channels.remove(&chid_to_remove);
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // --- user: an update with the same `connected_at` must not apply ---
        let mut updated = User {
            connected_at,
            ..test_user.clone()
        };
        trace!(
            "📮 Attempting to update user {} with old connected_at: {}",
            &user_id,
            &updated.connected_at
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());

        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
        assert_eq!(fetched.connected_at, fetched2.connected_at);

        // --- user: an update with a newer `connected_at` must apply ---
        let mut updated = User {
            connected_at: fetched.connected_at + 300,
            ..fetched2
        };
        trace!(
            "📮 Attempting to update user {} with new connected_at",
            &user_id
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_ne!(
            fetched2.connected_at,
            client.get_user(&uaid).await?.unwrap().connected_at
        );
        trace!("📮 Incrementing storage timestamp for user {}", &user_id);
        client
            .increment_storage(&fetched.uaid, sec_since_epoch())
            .await?;

        // --- messages: save and fetch a message without a topic ---
        let test_data = "An_encrypted_pile_of_crap".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();
        let fetch_timestamp = timestamp;
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 300,
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        trace!("📮 Saving message for user {}", &user_id);
        let res = client.save_message(&uaid, test_notification.clone()).await;
        assert!(res.is_ok());

        trace!("📮 Fetching all messages for user {}", &user_id);
        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        // A cursor 10 seconds in the past still sees the message...
        trace!(
            "📮 Fetching messages for user {} within the past 10 seconds",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
            .await?;
        assert_ne!(fetched.messages.len(), 0);

        // ...while a cursor 10 seconds in the future does not.
        trace!(
            "📮 Fetching messages for user {} 10 seconds in the future",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
            .await?;
        assert_eq!(fetched.messages.len(), 0);

        trace!(
            "📮 Removing message for user {} :: {}",
            &user_id,
            &test_notification.chidmessageid()
        );
        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        trace!("📮 Removing channel for user {}", &user_id);
        assert!(client.remove_channel(&uaid, &chid).await.is_ok());

        trace!("📮 Making sure no messages remain for user {}", &user_id);
        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        // --- messages: topic messages ---
        client.add_channel(&uaid, &topic_chid).await?;
        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();

        let test_notification_0 = crate::db::Notification {
            channel_id: topic_chid,
            version: "version0".to_owned(),
            ttl: 300,
            topic: Some("topic".to_owned()),
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        assert!(client
            .save_message(&uaid, test_notification_0.clone())
            .await
            .is_ok());

        // Same channel and topic as above, different version and timestamps.
        let test_notification = crate::db::Notification {
            timestamp: sec_since_epoch(),
            version: "version1".to_owned(),
            sortkey_timestamp: Some(sort_key + 10),
            ..test_notification_0
        };

        assert!(client
            .save_message(&uaid, test_notification.clone())
            .await
            .is_ok());

        // Exactly one row is expected; presumably both notifications share a
        // `chid_message_id`, so the second save replaced the first — TODO
        // confirm against `Notification::chidmessageid`.
        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(fetched.messages.len(), 1);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        // NOTE(review): this repeats the previous `fetch_timestamp_messages`
        // call; it may have been meant to exercise `fetch_topic_messages` —
        // confirm.
        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);

        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());

        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        // --- node id: clear it and verify ---
        let fetched = client.get_user(&uaid).await?.unwrap();
        assert!(client
            .remove_node_id(&uaid, &node_id, fetched.connected_at, &fetched.version)
            .await
            .is_ok());
        let fetched = client.get_user(&uaid).await?.unwrap();
        assert_eq!(fetched.node_id, None);

        // --- user: remove and verify it is gone ---
        assert!(client.remove_user(&uaid).await.is_ok());

        assert!(client.get_user(&uaid).await?.is_none());
        Ok(())
    }
1276
    /// A saved message is visible right after saving and is gone after
    /// `increment_storage` is called with a later timestamp.
    #[actix_rt::test]
    async fn test_expiry() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let now = sec_since_epoch();

        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 2,
            timestamp: now,
            data: Some("SomeData".into()),
            sortkey_timestamp: Some(now),
            ..Default::default()
        };
        client
            .add_user(&User {
                uaid,
                router_type: "test".to_owned(),
                connected_at: ms_since_epoch(),
                ..Default::default()
            })
            .await?;
        client.add_channel(&uaid, &chid).await?;
        debug!("🧪Writing test notif");
        client
            .save_message(&uaid, test_notification.clone())
            .await?;
        let key = uaid.simple().to_string();
        debug!("🧪Checking {}...", &key);
        // The message must be retrievable right after it was saved.
        let msg = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(msg.is_some());
        debug!("🧪Purging...");
        // Deletes messages whose timestamp is before `now + 2`, which includes
        // the one saved at `now`.
        client.increment_storage(&uaid, now + 2).await?;
        debug!("🧪Checking for empty {}...", &key);
        let cc = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(cc.is_none());
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }
1330}