// autopush_common/db/postgres/mod.rs

/* Postgres DbClient implementation.
 * As noted elsewhere, autopush was originally designed to work with NoSql type databases.
 * This implementation was done partially as an experiment. Postgres allows for limited
 * NoSql-like functionality. The author, however, has VERY limited knowledge of postgres,
 * and there are likely many inefficiencies in this implementation.
 *
 * PRs are always welcome.
 */
9
10use std::collections::HashSet;
11use std::str::FromStr;
12use std::sync::Arc;
13#[cfg(feature = "reliable_report")]
14use std::time::Duration;
15
16#[cfg(feature = "reliable_report")]
17use chrono::TimeDelta;
18use serde_json::json;
19use sqlx::postgres::PgRow;
20use sqlx::{PgPool, Row};
21
22use async_trait::async_trait;
23
24/// Extract a human-readable error message from an sqlx::Error.
25fn sqlx_err_to_string(e: &sqlx::Error) -> String {
26    match e {
27        sqlx::Error::Database(db_err) => db_err.message().to_owned(),
28        _ => e.to_string(),
29    }
30}
31use cadence::StatsdClient;
32use serde::{Deserialize, Serialize};
33use uuid::Uuid;
34
35use crate::db::client::DbClient;
36use crate::db::error::{DbError, DbResult};
37use crate::db::{DbSettings, User};
38use crate::notification::Notification;
39use crate::{util, MAX_ROUTER_TTL_SECS};
40
41use super::client::FetchMessageResponse;
42
/// Retention window for reliability log rows; added to "now" when computing
/// the `last_update_timestamp` written by `log_report`.
#[cfg(feature = "reliable_report")]
const RELIABLE_LOG_TTL: TimeDelta = TimeDelta::days(60);
45
/// Postgres-specific settings, parsed from the JSON `DbSettings::db_settings`
/// string (missing keys fall back to `Default` via `#[serde(default)]`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct PostgresDbSettings {
    pub schema: Option<String>,    // Optional DB Schema
    pub router_table: String,      // Routing info
    pub message_table: String,     // Message storage info
    pub meta_table: String,        // Channels and meta info
    pub reliability_table: String, // Message delivery reliability states
    max_router_ttl: u64,           // Max time for router records to live.
                                   // #[serde(default)]
                                   // pub use_tls: bool // Should you use a TLS connection to the db.
}
58
59impl Default for PostgresDbSettings {
60    fn default() -> Self {
61        Self {
62            schema: None,
63            router_table: "router".to_owned(),
64            message_table: "message".to_owned(),
65            meta_table: "meta".to_owned(),
66            reliability_table: "reliability".to_owned(),
67            max_router_ttl: MAX_ROUTER_TTL_SECS,
68            // use_tls: false,
69        }
70    }
71}
72
73impl TryFrom<&str> for PostgresDbSettings {
74    type Error = DbError;
75    fn try_from(setting_string: &str) -> Result<Self, Self::Error> {
76        if setting_string.trim().is_empty() {
77            return Ok(PostgresDbSettings::default());
78        }
79        serde_json::from_str(setting_string).map_err(|e| {
80            DbError::General(format!(
81                "Could not parse configuration db_settings: {:?}",
82                e
83            ))
84        })
85    }
86}
87
/// Postgres implementation of the `DbClient` trait, backed by an sqlx
/// connection pool.
#[derive(Clone)]
pub struct PgClientImpl {
    // Held for parity with other backends; not read in this module.
    _metrics: Arc<StatsdClient>,
    db_settings: PostgresDbSettings,
    pool: PgPool,
    /// Cached fully-qualified table names to avoid repeated format!/clone per query
    cached_router_table: String,
    cached_message_table: String,
    cached_meta_table: String,
    #[cfg(feature = "reliable_report")]
    cached_reliability_table: String,
}
100
101impl PgClientImpl {
102    /// Create a new Postgres Client.
103    ///
104    /// This uses the `settings.db_dsn`. to try and connect to the postgres database.
105    /// See https://docs.rs/sqlx/latest/sqlx/postgres/struct.PgConnectOptions.html
106    /// for parameter details and requirements.
107    /// Example DSN: postgresql://user:password@host/database?option=val
108    /// e.g. (postgresql://scott:tiger@dbhost/autopush?connect_timeout=10&keepalives_idle=3600)
109    pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
110        let db_settings = PostgresDbSettings::try_from(settings.db_settings.as_ref())?;
111        info!(
112            "📮 Initializing Postgres DB Client with settings: {:?} from {:?}",
113            db_settings, &settings.db_settings
114        );
115        if let Some(dsn) = settings.dsn.clone() {
116            trace!("📮 Postgres Connect {}", &dsn);
117
118            // Validate the DSN eagerly so misconfigurations surface at startup
119            let connect_options = sqlx::postgres::PgConnectOptions::from_str(&dsn)
120                .map_err(|e| DbError::General(format!("Invalid Postgres DSN: {e}")))?;
121            let pool = sqlx::postgres::PgPoolOptions::new().connect_lazy_with(connect_options);
122            let cached_router_table = if let Some(schema) = &db_settings.schema {
123                format!("{}.{}", schema, db_settings.router_table)
124            } else {
125                db_settings.router_table.clone()
126            };
127            let cached_message_table = if let Some(schema) = &db_settings.schema {
128                format!("{}.{}", schema, db_settings.message_table)
129            } else {
130                db_settings.message_table.clone()
131            };
132            let cached_meta_table = if let Some(schema) = &db_settings.schema {
133                format!("{}.{}", schema, db_settings.meta_table)
134            } else {
135                db_settings.meta_table.clone()
136            };
137            #[cfg(feature = "reliable_report")]
138            let cached_reliability_table = if let Some(schema) = &db_settings.schema {
139                format!("{}.{}", schema, db_settings.reliability_table)
140            } else {
141                db_settings.reliability_table.clone()
142            };
143            return Ok(Self {
144                _metrics: metrics,
145                db_settings,
146                pool,
147                cached_router_table,
148                cached_message_table,
149                cached_meta_table,
150                #[cfg(feature = "reliable_report")]
151                cached_reliability_table,
152            });
153        };
154        Err(DbError::ConnectionError("No DSN specified".to_owned()))
155    }
156
157    /// Does the given table exist
158    async fn table_exists(&self, table_name: &str) -> DbResult<bool> {
159        let (schema, table_name) = if table_name.contains('.') {
160            let mut parts = table_name.splitn(2, '.');
161            (
162                parts.next().unwrap_or("public").to_owned(),
163                // If we are in a situation where someone specified a table name as
164                // `whatever.`, then we should absolutely panic here.
165                parts.next().unwrap().to_owned(),
166            )
167        } else {
168            ("public".to_owned(), table_name.to_owned())
169        };
170        let row: (bool,) = sqlx::query_as(
171            "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname=$1 AND tablename=$2);",
172        )
173        .bind(&schema)
174        .bind(&table_name)
175        .fetch_one(&self.pool)
176        .await
177        .map_err(DbError::PgError)?;
178        Ok(row.0)
179    }
180
181    /// Return the router's expiration timestamp
182    fn router_expiry(&self) -> u64 {
183        util::sec_since_epoch() + self.db_settings.max_router_ttl
184    }
185
186    /// The router table contains how to route messages to the recipient UAID.
187    pub(crate) fn router_table(&self) -> &str {
188        &self.cached_router_table
189    }
190
191    /// The message table contains stored messages for UAIDs.
192    pub(crate) fn message_table(&self) -> &str {
193        &self.cached_message_table
194    }
195
196    /// The meta table contains channel and other metadata for UAIDs.
197    /// With traditional "No-Sql" databases, this would be rolled into the
198    /// router table.
199    pub(crate) fn meta_table(&self) -> &str {
200        &self.cached_meta_table
201    }
202
203    /// The reliability table contains message delivery reliability states.
204    /// This is optional and should only be used to track internally generated
205    /// and consumed messages based on the VAPID public key signature.
206    #[cfg(feature = "reliable_report")]
207    pub(crate) fn reliability_table(&self) -> &str {
208        &self.cached_reliability_table
209    }
210}
211
212#[async_trait]
213impl DbClient for PgClientImpl {
    /// add user to router_table if not exists uaid
    ///
    /// NOTE: on a `uaid` conflict this is an *unconditional* upsert — every
    /// column is overwritten with the incoming record (unlike `update_user`,
    /// which only applies when the incoming `connected_at` is newer).
    async fn add_user(&self, user: &User) -> DbResult<()> {
        let sql = format!(
            "INSERT INTO {tablename} (uaid, connected_at, router_type, router_data, node_id, record_version, version, last_update, priv_channels, expiry)
             VALUES($1, $2::BIGINT, $3, $4, $5, $6::BIGINT, $7, $8::BIGINT, $9, $10::BIGINT)
             ON CONFLICT (uaid) DO
                UPDATE SET connected_at=EXCLUDED.connected_at,
                    router_type=EXCLUDED.router_type,
                    router_data=EXCLUDED.router_data,
                    node_id=EXCLUDED.node_id,
                    record_version=EXCLUDED.record_version,
                    version=EXCLUDED.version,
                    last_update=EXCLUDED.last_update,
                    priv_channels=EXCLUDED.priv_channels,
                    expiry=EXCLUDED.expiry
                    ;
            ",
            tablename = self.router_table()
        );
        sqlx::query(&sql)
            .bind(user.uaid.simple().to_string()) // 1
            .bind(user.connected_at as i64) // 2
            .bind(&user.router_type) // 3
            // router_data is stored as a JSON-encoded string.
            .bind(json!(user.router_data).to_string()) // 4
            .bind(&user.node_id) // 5
            .bind(user.record_version.map(|i| i as i64)) // 6
            .bind(user.version.map(|v| v.simple().to_string())) // 7
            .bind(user.current_timestamp.map(|i| i as i64)) // 8
            // priv_channels is stored as a text array of simple-format UUIDs.
            .bind(
                user.priv_channels
                    .iter()
                    .map(|v| v.to_string())
                    .collect::<Vec<String>>(),
            ) // 9
            .bind(self.router_expiry() as i64) // 10
            .execute(&self.pool)
            .await
            .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
        Ok(())
    }
254
    /// update user record in router_table at user.uaid
    ///
    /// The update is conditional: it only applies when the stored
    /// `connected_at` is strictly older than the incoming record's
    /// (`connected_at < $2`), protecting against out-of-order writes from a
    /// stale connection. Returns whether a row was actually updated.
    async fn update_user(&self, user: &mut User) -> DbResult<bool> {
        let cmd = format!(
            "UPDATE {tablename} SET connected_at=$2::BIGINT,
                router_type=$3,
                router_data=$4,
                node_id=$5,
                record_version=$6::BIGINT,
                version=$7,
                last_update=$8::BIGINT,
                priv_channels=$9,
                expiry=$10::BIGINT
            WHERE
                uaid = $1 AND connected_at < $2::BIGINT;
            ",
            tablename = self.router_table()
        );
        let result = sqlx::query(&cmd)
            .bind(user.uaid.simple().to_string()) // 1
            .bind(user.connected_at as i64) // 2
            .bind(&user.router_type) // 3
            // router_data is stored as a JSON-encoded string.
            .bind(json!(user.router_data).to_string()) // 4
            .bind(&user.node_id) // 5
            .bind(user.record_version.map(|i| i as i64)) // 6
            .bind(user.version.map(|v| v.simple().to_string())) // 7
            .bind(user.current_timestamp.map(|i| i as i64)) // 8
            .bind(
                user.priv_channels
                    .iter()
                    .map(|v| v.to_string())
                    .collect::<Vec<String>>(),
            ) // 9
            .bind(self.router_expiry() as i64) // 10
            .execute(&self.pool)
            .await
            .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
        Ok(result.rows_affected() > 0)
    }
293
    /// fetch user information from router_table for uaid.
    ///
    /// Returns `Ok(None)` when no row exists. Column values are converted
    /// back from their SQL representations (BIGINT -> u64, simple-format
    /// UUID strings, JSON-encoded `router_data`).
    async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
        let row: Option<PgRow> = sqlx::query(&format!(
            "SELECT connected_at, router_type, router_data, node_id, record_version, last_update, version, priv_channels
             FROM {tablename}
             WHERE uaid = $1",
            tablename = self.router_table()
        ))
        .bind(uaid.simple().to_string())
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;

        let Some(row) = row else {
            return Ok(None);
        };

        // I was tempted to make this a From impl, but realized that it would mean making autopush-common require a dependency.
        // Maybe make this a deserialize?
        // A NULL/missing priv_channels column (or a column read error) falls
        // back to an empty set, but an unparsable UUID *inside* the array is
        // a hard error.
        let priv_channels = if let Ok(Some(channels)) =
            row.try_get::<Option<Vec<String>>, _>("priv_channels")
        {
            let mut priv_channels = HashSet::new();
            for channel in channels.iter() {
                let uuid = Uuid::from_str(channel).map_err(|e| DbError::General(e.to_string()))?;
                priv_channels.insert(uuid);
            }
            priv_channels
        } else {
            HashSet::new()
        };
        let resp = User {
            uaid: *uaid,
            connected_at: row
                .try_get::<i64, _>("connected_at")
                .map_err(DbError::PgError)? as u64,
            router_type: row
                .try_get::<String, _>("router_type")
                .map_err(DbError::PgError)?,
            // Stored as a JSON string by add_user/update_user.
            router_data: serde_json::from_str(
                row.try_get::<&str, _>("router_data")
                    .map_err(DbError::PgError)?,
            )
            .map_err(|e| DbError::General(e.to_string()))?,
            node_id: row
                .try_get::<Option<String>, _>("node_id")
                .map_err(DbError::PgError)?,
            record_version: row
                .try_get::<Option<i64>, _>("record_version")
                .map_err(DbError::PgError)?
                .map(|v| v as u64),
            current_timestamp: row
                .try_get::<Option<i64>, _>("last_update")
                .map_err(DbError::PgError)?
                .map(|v| v as u64),
            version: row
                .try_get::<Option<String>, _>("version")
                .map_err(DbError::PgError)?
                // An invalid UUID here is a data integrity error.
                .map(|v| {
                    Uuid::from_str(&v).map_err(|e| {
                        DbError::Integrity("Invalid UUID found".to_owned(), Some(e.to_string()))
                    })
                })
                .transpose()?,
            priv_channels,
        };
        Ok(Some(resp))
    }
363
364    /// delete a user at uaid from router_table
365    async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
366        sqlx::query(&format!(
367            "DELETE FROM {tablename}
368             WHERE uaid = $1",
369            tablename = self.router_table()
370        ))
371        .bind(uaid.simple().to_string())
372        .execute(&self.pool)
373        .await
374        .map_err(DbError::PgError)?;
375        Ok(())
376    }
377
378    /// update list of channel_ids for uaid in meta table
379    /// Note: a conflicting channel_id is ignored, since it's already registered.
380    /// This should probably be optimized into the router table as a set value,
381    /// however I'm not familiar enough with Postgres to do so at this time.
382    /// Channels can be somewhat ephemeral, and we also want to limit the potential of
383    /// race conditions when adding or removing channels, particularly for mobile devices.
384    /// For some efficiency (mostly around the mobile "daily refresh" call), I've broken
385    /// the channels out by UAID into this table.
386    async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
387        sqlx::query(&format!(
388            "INSERT
389             INTO {tablename} (uaid, channel_id) VALUES ($1, $2)
390             ON CONFLICT DO NOTHING",
391            tablename = self.meta_table()
392        ))
393        .bind(uaid.simple().to_string())
394        .bind(channel_id.simple().to_string())
395        .execute(&self.pool)
396        .await
397        .map_err(DbError::PgError)?;
398        Ok(())
399    }
400
401    /// Save all channels in a list
402    async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
403        if channels.is_empty() {
404            trace!("📮 No channels to save.");
405            return Ok(());
406        };
407        let uaid_str = uaid.simple().to_string();
408
409        let mut query_builder: sqlx::QueryBuilder<sqlx::Postgres> =
410            sqlx::QueryBuilder::new(format!(
411                "INSERT INTO {tablename} (uaid, channel_id) ",
412                tablename = self.meta_table()
413            ));
414        query_builder.push_values(channels.iter(), |mut b, channel| {
415            b.push_bind(uaid_str.clone());
416            b.push_bind(channel.simple().to_string());
417        });
418        query_builder.push(" ON CONFLICT DO NOTHING");
419
420        query_builder
421            .build()
422            .execute(&self.pool)
423            .await
424            .map_err(DbError::PgError)?;
425        Ok(())
426    }
427
428    /// get all channels for uaid from meta table
429    async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
430        let mut result = HashSet::new();
431        let rows: Vec<PgRow> = sqlx::query(&format!(
432            "SELECT distinct channel_id FROM {tablename} WHERE uaid = $1;",
433            tablename = self.meta_table()
434        ))
435        .bind(uaid.simple().to_string())
436        .fetch_all(&self.pool)
437        .await
438        .map_err(DbError::PgError)?;
439        for row in rows.iter() {
440            let s = row
441                .try_get::<&str, _>("channel_id")
442                .map_err(DbError::PgError)?;
443            result.insert(Uuid::from_str(s).map_err(|e| DbError::General(e.to_string()))?);
444        }
445        Ok(result)
446    }
447
448    /// remove an individual channel for a given uaid from meta table
449    async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
450        let cmd = format!(
451            "DELETE FROM {tablename}
452             WHERE uaid = $1 AND channel_id = $2;",
453            tablename = self.meta_table()
454        );
455        let result = sqlx::query(&cmd)
456            .bind(uaid.simple().to_string())
457            .bind(channel_id.simple().to_string())
458            .execute(&self.pool)
459            .await
460            .map_err(DbError::PgError)?;
461        // We sometimes want to know if the channel existed previously.
462        Ok(result.rows_affected() > 0)
463    }
464
465    /// remove node info for a uaid from router table
466    async fn remove_node_id(
467        &self,
468        uaid: &Uuid,
469        node_id: &str,
470        connected_at: u64,
471        version: &Option<Uuid>,
472    ) -> DbResult<bool> {
473        let Some(version) = version else {
474            return Err(DbError::General("Expected a user version field".to_owned()));
475        };
476        sqlx::query(&format!(
477            "UPDATE {tablename}
478                SET node_id = null
479                WHERE uaid=$1 AND node_id = $2 AND connected_at = $3 AND version= $4;",
480            tablename = self.router_table()
481        ))
482        .bind(uaid.simple().to_string())
483        .bind(node_id)
484        .bind(connected_at as i64)
485        .bind(version.simple().to_string())
486        .execute(&self.pool)
487        .await
488        .map_err(DbError::PgError)?;
489        Ok(true)
490    }
491
    /// write a message to message table
    ///
    /// Performs an upsert keyed on `chid_message_id`, so a newer message
    /// (e.g. a topic message) replaces any pending row with the same id.
    /// The `expiry` column is computed here as now + message TTL.
    ///
    /// NOTE(review): the ON CONFLICT update list omits `reliability_id`, so a
    /// replacement message keeps the original row's reliability id — confirm
    /// this is intentional.
    async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
        // fun fact: serde_postgres exists, but only deserializes (as of 0.2)
        // (This is mutable if `reliable_report` enabled)
        #[allow(unused_mut)]
        let mut fields = vec![
            "uaid",
            "channel_id",
            "chid_message_id",
            "version",
            "ttl",
            "expiry",
            "topic",
            "timestamp",
            "data",
            "sortkey_timestamp",
            "headers",
        ];
        // (This is mutable if `reliable_report` enabled)
        #[allow(unused_mut)]
        let mut inputs = vec![
            "$1", "$2", "$3", "$4", "$5", "$6", "$7", "$8", "$9", "$10", "$11",
        ];
        #[cfg(feature = "reliable_report")]
        {
            fields.append(&mut ["reliability_id"].to_vec());
            inputs.append(&mut ["$12"].to_vec());
        }
        let cmd = format!(
            "INSERT INTO {tablename}
                ({fields})
                VALUES
                ({inputs}) ON CONFLICT (chid_message_id) DO UPDATE SET
                    uaid=EXCLUDED.uaid,
                    channel_id=EXCLUDED.channel_id,
                    version=EXCLUDED.version,
                    ttl=EXCLUDED.ttl,
                    expiry=EXCLUDED.expiry,
                    topic=EXCLUDED.topic,
                    timestamp=EXCLUDED.timestamp,
                    data=EXCLUDED.data,
                    sortkey_timestamp=EXCLUDED.sortkey_timestamp,
                    headers=EXCLUDED.headers",
            tablename = &self.message_table(),
            fields = fields.join(","),
            inputs = inputs.join(",")
        );
        #[allow(unused_mut)]
        let mut q = sqlx::query(&cmd)
            .bind(uaid.simple().to_string()) // 1
            .bind(message.channel_id.simple().to_string()) // 2
            .bind(message.chidmessageid()) // 3
            .bind(&message.version) // 4
            .bind(message.ttl as i64) // 5
            // 6: absolute expiry = now + TTL
            .bind(util::sec_since_epoch() as i64 + message.ttl as i64)
            // 7: empty topics are normalized to NULL
            .bind(message.topic.as_ref().filter(|v| !v.is_empty()).cloned())
            .bind(message.timestamp as i64) // 8
            .bind(message.data.as_deref().unwrap_or_default().to_owned()) // 9
            .bind(message.sortkey_timestamp.map(|v| v as i64)) // 10
            .bind(json!(message.headers).to_string()); // 11
        #[cfg(feature = "reliable_report")]
        {
            q = q.bind(message.reliability_id.clone()); // 12
        }
        q.execute(&self.pool)
            .await
            .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
        Ok(())
    }
561
562    /// remove a given message from the message table
563    async fn remove_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<()> {
564        debug!(
565            "📮 Removing message for user {} with chid_message_id {}",
566            uaid.simple(),
567            chidmessageid
568        );
569        let result = sqlx::query(&format!(
570            "DELETE FROM {tablename}
571             WHERE uaid=$1 AND chid_message_id = $2;",
572            tablename = self.message_table()
573        ))
574        .bind(uaid.simple().to_string())
575        .bind(chidmessageid)
576        .execute(&self.pool)
577        .await
578        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
579        debug!(
580            "📮 Deleted {} rows for user {} with chid_message_id {}",
581            result.rows_affected(),
582            uaid.simple(),
583            chidmessageid
584        );
585        Ok(())
586    }
587
    /// Write a batch of messages to the message table in one multi-row
    /// upsert (keyed on `chid_message_id`), mirroring `save_message`.
    ///
    /// NOTE(review): like `save_message`, the ON CONFLICT update list omits
    /// `reliability_id` — confirm a replacement should keep the original's.
    async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
        if messages.is_empty() {
            return Ok(());
        }
        if messages.len() == 1 {
            // Fast path: single message doesn't need batch construction
            return self
                .save_message(uaid, messages.into_iter().next().unwrap())
                .await;
        }

        let field_names = {
            #[allow(unused_mut)]
            let mut fields = vec![
                "uaid",
                "channel_id",
                "chid_message_id",
                "version",
                "ttl",
                "expiry",
                "topic",
                "timestamp",
                "data",
                "sortkey_timestamp",
                "headers",
            ];
            #[cfg(feature = "reliable_report")]
            fields.push("reliability_id");
            fields.join(",")
        };

        let uaid_str = uaid.simple().to_string();
        // Single "now" so every row in the batch gets a consistent expiry.
        let now = util::sec_since_epoch() as i64;

        // Pre-compute owned values for each message
        struct MessageParams {
            channel_id: String,
            chidmessageid: String,
            version: String,
            ttl: i64,
            expiry: i64,
            topic: Option<String>,
            timestamp: i64,
            data: String,
            sortkey_timestamp: Option<i64>,
            headers: String,
            #[cfg(feature = "reliable_report")]
            reliability_id: Option<String>,
        }

        let msg_params: Vec<MessageParams> = messages
            .into_iter()
            .map(|m| {
                // Empty topics are normalized to NULL, as in save_message.
                let topic = m.topic.as_ref().filter(|v| !v.is_empty()).cloned();
                MessageParams {
                    channel_id: m.channel_id.simple().to_string(),
                    chidmessageid: m.chidmessageid(),
                    version: m.version,
                    ttl: m.ttl as i64,
                    expiry: now + m.ttl as i64,
                    topic,
                    timestamp: m.timestamp as i64,
                    data: m.data.as_deref().unwrap_or_default().to_owned(),
                    sortkey_timestamp: m.sortkey_timestamp.map(|v| v as i64),
                    headers: json!(m.headers).to_string(),
                    #[cfg(feature = "reliable_report")]
                    reliability_id: m.reliability_id,
                }
            })
            .collect();

        let mut query_builder: sqlx::QueryBuilder<sqlx::Postgres> =
            sqlx::QueryBuilder::new(format!(
                "INSERT INTO {tablename} ({field_names}) ",
                tablename = &self.message_table(),
            ));
        // `query` is provided by `push_values` to bind parameters for each row
        // (bind order must match `field_names` above).
        query_builder.push_values(msg_params.iter(), |mut query, mp| {
            query.push_bind(uaid_str.clone());
            query.push_bind(mp.channel_id.clone());
            query.push_bind(mp.chidmessageid.clone());
            query.push_bind(mp.version.clone());
            query.push_bind(mp.ttl);
            query.push_bind(mp.expiry);
            query.push_bind(mp.topic.clone());
            query.push_bind(mp.timestamp);
            query.push_bind(mp.data.clone());
            query.push_bind(mp.sortkey_timestamp);
            query.push_bind(mp.headers.clone());
            #[cfg(feature = "reliable_report")]
            query.push_bind(mp.reliability_id.clone());
        });
        query_builder.push(
            " ON CONFLICT (chid_message_id) DO UPDATE SET
                    uaid=EXCLUDED.uaid,
                    channel_id=EXCLUDED.channel_id,
                    version=EXCLUDED.version,
                    ttl=EXCLUDED.ttl,
                    expiry=EXCLUDED.expiry,
                    topic=EXCLUDED.topic,
                    timestamp=EXCLUDED.timestamp,
                    data=EXCLUDED.data,
                    sortkey_timestamp=EXCLUDED.sortkey_timestamp,
                    headers=EXCLUDED.headers",
        );

        query_builder
            .build()
            .execute(&self.pool)
            .await
            .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
        Ok(())
    }
701
702    /// fetch topic messages for the user up to {limit}
703    /// Topic messages are auto-replacing singleton messages for a given user.
704    async fn fetch_topic_messages(
705        &self,
706        uaid: &Uuid,
707        limit: usize,
708    ) -> DbResult<FetchMessageResponse> {
709        let messages: Vec<Notification> = sqlx::query(&format!(
710            "SELECT channel_id, version, ttl, topic, timestamp, data, sortkey_timestamp, headers
711             FROM {tablename}
712             WHERE uaid=$1 AND expiry >= $2 AND (topic IS NOT NULL AND topic != '')
713             ORDER BY timestamp DESC
714             LIMIT $3",
715            tablename = &self.message_table(),
716        ))
717        .bind(uaid.simple().to_string())
718        .bind(util::sec_since_epoch() as i64)
719        .bind(limit as i64)
720        .fetch_all(&self.pool)
721        .await
722        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?
723        .iter()
724        .map(|row: &PgRow| row.try_into())
725        .collect::<Result<Vec<Notification>, DbError>>()?;
726
727        if messages.is_empty() {
728            Ok(Default::default())
729        } else {
730            Ok(FetchMessageResponse {
731                timestamp: Some(messages[0].timestamp),
732                messages,
733            })
734        }
735    }
736
737    /// Fetch messages for a user on or after a given timestamp up to {limit}
738    async fn fetch_timestamp_messages(
739        &self,
740        uaid: &Uuid,
741        timestamp: Option<u64>,
742        limit: usize,
743    ) -> DbResult<FetchMessageResponse> {
744        let uaid = uaid.simple().to_string();
745        let response: Vec<PgRow> = if let Some(ts) = timestamp {
746            trace!("📮 Fetching messages for user {} since {}", &uaid, ts);
747            sqlx::query(&format!(
748                "SELECT * FROM {}
749                 WHERE uaid = $1 AND timestamp > $2 AND expiry >= $3
750                 ORDER BY timestamp
751                 LIMIT $4",
752                self.message_table()
753            ))
754            .bind(&uaid)
755            .bind(ts as i64)
756            .bind(util::sec_since_epoch() as i64)
757            .bind(limit as i64)
758            .fetch_all(&self.pool)
759            .await
760        } else {
761            trace!("📮 Fetching messages for user {}", &uaid);
762            sqlx::query(&format!(
763                "SELECT *
764                 FROM {}
765                 WHERE uaid = $1
766                 AND expiry >= $2
767                 LIMIT $3",
768                self.message_table()
769            ))
770            .bind(&uaid)
771            .bind(util::sec_since_epoch() as i64)
772            .bind(limit as i64)
773            .fetch_all(&self.pool)
774            .await
775        }
776        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
777        let messages: Vec<Notification> = response
778            .iter()
779            .map(|row: &PgRow| row.try_into())
780            .collect::<Result<Vec<Notification>, DbError>>()?;
781        let timestamp = if !messages.is_empty() {
782            Some(messages[0].timestamp)
783        } else {
784            None
785        };
786
787        Ok(FetchMessageResponse {
788            timestamp,
789            messages,
790        })
791    }
792
793    /// Convenience function to check if the router table exists
794    async fn router_table_exists(&self) -> DbResult<bool> {
795        self.table_exists(self.router_table()).await
796    }
797
798    /// Convenience function to check if the message table exists
799    async fn message_table_exists(&self) -> DbResult<bool> {
800        self.table_exists(self.message_table()).await
801    }
802
    #[cfg(feature = "reliable_report")]
    /// Record a reliability state transition for `reliability_id`.
    ///
    /// Upserts a row into the reliability table keyed by id, with `states`
    /// set to a one-entry JSON object `{new_state: timestamp_epoch}`.
    async fn log_report(
        &self,
        reliability_id: &str,
        new_state: crate::reliability::ReliabilityState,
    ) -> DbResult<()> {
        // NOTE(review): this is "now + RELIABLE_LOG_TTL" (60 days), yet it is
        // stored in a column named `last_update_timestamp` — it looks like an
        // expiry/GC horizon rather than an update time; confirm against the
        // schema and whatever consumes this column.
        let timestamp_epoch = (std::time::SystemTime::now()
            + Duration::from_secs(RELIABLE_LOG_TTL.num_seconds() as u64))
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| DbError::General(format!("System time before UNIX epoch: {e}")))?
        .as_secs() as i64;
        debug!("📮 Logging report for {reliability_id} as {new_state}");

        let tablename = &self.reliability_table();
        let state = new_state.to_string();
        // NOTE(review): on conflict, `states = EXCLUDED.states` *replaces* the
        // whole JSON object with the single new state rather than merging it
        // into the prior history. If the intent is to accumulate states per
        // id, this likely wants a jsonb concatenation — verify.
        sqlx::query(&format!(
            "INSERT INTO {tablename} (id, states, last_update_timestamp) VALUES ($1, json_build_object($2, $3), $3)
             ON CONFLICT (id) DO
             UPDATE SET states = EXCLUDED.states,
             last_update_timestamp = EXCLUDED.last_update_timestamp;",
            tablename = tablename
        ))
        .bind(reliability_id)
        .bind(&state)
        .bind(timestamp_epoch)
        .execute(&self.pool)
        .await
        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
        Ok(())
    }
833
834    async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
835        debug!("📮 Updating {uaid} current_timestamp:{timestamp}");
836        let tablename = &self.router_table();
837
838        trace!("📮 Purging git{uaid} for < {timestamp}");
839        let mut tx = self.pool.begin().await.map_err(DbError::PgError)?;
840        // Try to garbage collect old messages first.
841        sqlx::query(&format!(
842            "DELETE FROM {} WHERE uaid = $1 and expiry < $2",
843            &self.message_table()
844        ))
845        .bind(uaid.simple().to_string())
846        .bind(util::sec_since_epoch() as i64)
847        .execute(&mut *tx)
848        .await
849        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
850        // Now, delete messages that we've already delivered.
851        sqlx::query(&format!(
852            "DELETE FROM {} WHERE uaid = $1 AND timestamp IS NOT NULL AND timestamp < $2",
853            &self.message_table()
854        ))
855        .bind(uaid.simple().to_string())
856        .bind(timestamp as i64)
857        .execute(&mut *tx)
858        .await
859        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
860        sqlx::query(&format!(
861            "UPDATE {tablename} SET last_update = $2::BIGINT, expiry= $3::BIGINT WHERE uaid = $1"
862        ))
863        .bind(uaid.simple().to_string())
864        .bind(timestamp as i64)
865        .bind(self.router_expiry() as i64)
866        .execute(&mut *tx)
867        .await
868        .map_err(|e| DbError::PgDbError(sqlx_err_to_string(&e)))?;
869        tx.commit().await.map_err(DbError::PgError)?;
870        Ok(())
871    }
872
873    fn name(&self) -> String {
874        "Postgres".to_owned()
875    }
876
877    async fn health_check(&self) -> DbResult<bool> {
878        let row: (bool,) = sqlx::query_as("select true")
879            .fetch_one(&self.pool)
880            .await
881            .map_err(DbError::PgError)?;
882        if !row.0 {
883            error!("📮 Failed to fetch from database");
884            return Ok(false);
885        }
886        if !self.router_table_exists().await? {
887            error!("📮 Router table does not exist");
888            return Ok(false);
889        }
890        if !self.message_table_exists().await? {
891            error!("📮 Message table does not exist");
892            return Ok(false);
893        }
894        Ok(true)
895    }
896
897    /// Convenience function to return self as a Boxed DbClient
898    fn box_clone(&self) -> Box<dyn DbClient> {
899        Box::new(self.clone())
900    }
901}
902
903/* Note:
904 * For preliminary testing, you will need to start a local postgres instance (see
905 * https://www.docker.com/blog/how-to-use-the-postgres-docker-official-image/) and initialize the
906 * database with `schema.psql`.
907 * Once you have, you can define the environment variable `POSTGRES_HOST` to point to the
908 * appropriate host (e.g. `postgres:post_pass@localhost:/autopush`). `new_client` will add the
909 * `postgres://` prefix automatically.
910 *
 * TODO: The bulk of these tests should move to a higher, backend-agnostic level, with
 * backend-specific versions of `new_client` supplied for each storage implementation.
913 *
914 */
#[cfg(test)]
mod tests {
    use crate::util::sec_since_epoch;
    use crate::{logging::init_test_logging, util::ms_since_epoch};
    use rand::prelude::*;
    use serde_json::json;
    use std::env;

    use super::*;
    // Fixed channel ids reused across tests; the topic variant differs in the
    // second group so the two never collide.
    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";

    /// Build a test client pointed at a local postgres instance.
    ///
    /// The host defaults to `localhost` and can be overridden via the
    /// `POSTGRES_HOST` environment variable (see the module note above).
    fn new_client() -> DbResult<PgClientImpl> {
        // Use an environment variable to potentially override the default storage test host.
        let host = env::var("POSTGRES_HOST").unwrap_or("localhost".into());
        let env_dsn = format!("postgres://{host}");
        debug!("📮 Connecting to {env_dsn}");
        let settings = DbSettings {
            dsn: Some(env_dsn),
            db_settings: json!(PostgresDbSettings {
                schema: Some("autopush".to_owned()),
                ..Default::default()
            })
            .to_string(),
        };
        // Metrics go to a no-op sink; these tests never assert on metrics.
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        PgClientImpl::new(metrics, &settings)
    }

    /// Generate a semi-unique UAID string (random byte + current epoch
    /// seconds) so concurrent/repeated test runs don't collide.
    fn gen_test_user() -> String {
        // Create a semi-unique test user to avoid conflicting test values.
        let mut rng = rand::rng();
        let test_num = rng.random::<u8>();
        format!(
            "DEADBEEF-0000-0000-{:04}-{:012}",
            test_num,
            sec_since_epoch()
        )
    }

    /// Smoke test: the database answers and the required tables exist.
    #[actix_rt::test]
    async fn health_check() {
        let client = new_client().unwrap();

        let result = client.health_check().await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    /// Test if [increment_storage] correctly wipes expired messages
    #[actix_rt::test]
    async fn wipe_expired() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // purge the user record if it exists.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // purge the old user (if present)
        // in case a prior test failed for whatever reason.
        let _ = client.remove_user(&uaid).await;

        // can we add the user?
        let timestamp = sec_since_epoch();
        client.add_user(&test_user).await?;
        // A short-lived (ttl: 1) message stamped "now".
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 1,
            timestamp,
            data: Some("Encrypted".into()),
            sortkey_timestamp: Some(timestamp),
            ..Default::default()
        };
        client.save_message(&uaid, test_notification).await?;
        // Advancing storage past the message's timestamp should purge it.
        client.increment_storage(&uaid, timestamp + 1).await?;
        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(msgs.messages.len(), 0);
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }

    /// run a gauntlet of testing. These are a bit linear because they need
    /// to run in sequence.
    #[actix_rt::test]
    async fn run_gauntlet() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let user_id = &gen_test_user();
        let uaid = Uuid::parse_str(user_id).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // purge the user record if it exists.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // purge the old user (if present)
        // in case a prior test failed for whatever reason.
        let _ = client.remove_user(&uaid).await;

        // can we add the user?
        trace!("📮 Adding user {}", &user_id);
        client.add_user(&test_user).await?;
        let fetched = client.get_user(&uaid).await?;
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.router_type, "webpush".to_owned());

        // can we add channels?
        trace!("📮 Adding channel {} to user {}", &chid, &user_id);
        client.add_channel(&uaid, &chid).await?;
        let channels = client.get_channels(&uaid).await?;
        assert!(channels.contains(&chid));

        // can we add lots of channels?
        let mut new_channels: HashSet<Uuid> = HashSet::new();
        trace!("📮 Adding multiple channels to user {}", &user_id);
        new_channels.insert(chid);
        for _ in 1..10 {
            new_channels.insert(uuid::Uuid::new_v4());
        }
        let chid_to_remove = uuid::Uuid::new_v4();
        trace!(
            "📮 Adding removable channel {} to user {}",
            &chid_to_remove,
            &user_id
        );
        new_channels.insert(chid_to_remove);
        client.add_channels(&uaid, new_channels.clone()).await?;
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // can we remove a channel?
        trace!(
            "📮 Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
        // Removing an already-removed channel reports `false`, not an error.
        trace!(
            "📮 retrying Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
        new_channels.remove(&chid_to_remove);
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // now ensure that we can update a user that's after the time we set
        // prior. first ensure that we can't update a user that's before the
        // time we set prior to the last write
        let mut updated = User {
            connected_at,
            ..test_user.clone()
        };
        trace!(
            "📮 Attempting to update user {} with old connected_at: {}",
            &user_id,
            &updated.connected_at
        );
        // A stale `connected_at` must be rejected (update_user returns false).
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());

        // Make sure that the `connected_at` wasn't modified
        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
        assert_eq!(fetched.connected_at, fetched2.connected_at);

        // and make sure we can update a record with a later connected_at time.
        let mut updated = User {
            connected_at: fetched.connected_at + 300,
            ..fetched2
        };
        trace!(
            "📮 Attempting to update user {} with new connected_at",
            &user_id
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_ne!(
            fetched2.connected_at,
            client.get_user(&uaid).await?.unwrap().connected_at
        );
        // can we increment the storage for the user?
        trace!("📮 Incrementing storage timestamp for user {}", &user_id);
        client
            .increment_storage(&fetched.uaid, sec_since_epoch())
            .await?;

        let test_data = "An_encrypted_pile_of_crap".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();
        let fetch_timestamp = timestamp;
        // Can we store a message?
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 300,
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        trace!("📮 Saving message for user {}", &user_id);
        let res = client.save_message(&uaid, test_notification.clone()).await;
        assert!(res.is_ok());

        trace!("📮 Fetching all messages for user {}", &user_id);
        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        // Grab all 1 of the messages that were submitted within the past 10 seconds.
        trace!(
            "📮 Fetching messages for user {} within the past 10 seconds",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
            .await?;
        assert_ne!(fetched.messages.len(), 0);

        // Try grabbing a message for 10 seconds from now.
        trace!(
            "📮 Fetching messages for user {} 10 seconds in the future",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
            .await?;
        assert_eq!(fetched.messages.len(), 0);

        // can we clean up our toys?
        trace!(
            "📮 Removing message for user {} :: {}",
            &user_id,
            &test_notification.chidmessageid()
        );
        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        trace!("📮 Removing channel for user {}", &user_id);
        assert!(client.remove_channel(&uaid, &chid).await.is_ok());

        trace!("📮 Making sure no messages remain for user {}", &user_id);
        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        // Now, can we do all that with topic messages
        // Unlike bigtable, we don't use [fetch_topic_messages]: it always returns None;
        // topic messages are handled as usual messages.
        client.add_channel(&uaid, &topic_chid).await?;
        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();

        // We store 2 messages, with a single topic
        let test_notification_0 = crate::db::Notification {
            channel_id: topic_chid,
            version: "version0".to_owned(),
            ttl: 300,
            topic: Some("topic".to_owned()),
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        assert!(client
            .save_message(&uaid, test_notification_0.clone())
            .await
            .is_ok());

        let test_notification = crate::db::Notification {
            timestamp: sec_since_epoch(),
            version: "version1".to_owned(),
            sortkey_timestamp: Some(sort_key + 10),
            ..test_notification_0
        };

        assert!(client
            .save_message(&uaid, test_notification.clone())
            .await
            .is_ok());

        // NOTE(review): presumably the second save (same topic) replaces the
        // first, which is why exactly one message remains — confirm against
        // save_message's upsert behavior.
        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(fetched.messages.len(), 1);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        // Grab the message that was submitted.
        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);

        // can we clean up our toys?
        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());

        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        let fetched = client.get_user(&uaid).await?.unwrap();
        assert!(client
            .remove_node_id(&uaid, &node_id, fetched.connected_at, &fetched.version)
            .await
            .is_ok());
        // did we remove it?
        let fetched = client.get_user(&uaid).await?.unwrap();
        assert_eq!(fetched.node_id, None);

        assert!(client.remove_user(&uaid).await.is_ok());

        assert!(client.get_user(&uaid).await?.is_none());
        Ok(())
    }

    /// End-to-end expiry check: a saved message is visible, then disappears
    /// once storage is incremented past its timestamp.
    #[actix_rt::test]
    async fn test_expiry() -> DbResult<()> {
        // Make sure that we really are purging messages correctly
        init_test_logging();
        let client = new_client()?;

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let now = sec_since_epoch();

        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 2,
            timestamp: now,
            data: Some("SomeData".into()),
            sortkey_timestamp: Some(now),
            ..Default::default()
        };
        client
            .add_user(&User {
                uaid,
                router_type: "test".to_owned(),
                connected_at: ms_since_epoch(),
                ..Default::default()
            })
            .await?;
        client.add_channel(&uaid, &chid).await?;
        debug!("🧪Writing test notif");
        client
            .save_message(&uaid, test_notification.clone())
            .await?;
        let key = uaid.simple().to_string();
        debug!("🧪Checking {}...", &key);
        // The freshly-saved message must be fetchable.
        let msg = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(msg.is_some());
        debug!("🧪Purging...");
        client.increment_storage(&uaid, now + 2).await?;
        debug!("🧪Checking for empty {}...", &key);
        // ...and gone after the purge.
        let cc = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(cc.is_none());
        // clean up after the test.
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }
}