mas_storage_pg/user/
registration.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{UserEmailAuthentication, UserRegistration, UserRegistrationPassword};
11use mas_storage::{Clock, user::UserRegistrationRepository};
12use rand::RngCore;
13use sqlx::PgConnection;
14use ulid::Ulid;
15use url::Url;
16use uuid::Uuid;
17
18use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
19
20/// An implementation of [`UserRegistrationRepository`] for a PostgreSQL
21/// connection
22pub struct PgUserRegistrationRepository<'c> {
23    conn: &'c mut PgConnection,
24}
25
26impl<'c> PgUserRegistrationRepository<'c> {
27    /// Create a new [`PgUserRegistrationRepository`] from an active PostgreSQL
28    /// connection
29    pub fn new(conn: &'c mut PgConnection) -> Self {
30        Self { conn }
31    }
32}
33
34struct UserRegistrationLookup {
35    user_registration_id: Uuid,
36    ip_address: Option<IpAddr>,
37    user_agent: Option<String>,
38    post_auth_action: Option<serde_json::Value>,
39    username: String,
40    display_name: Option<String>,
41    terms_url: Option<String>,
42    email_authentication_id: Option<Uuid>,
43    hashed_password: Option<String>,
44    hashed_password_version: Option<i32>,
45    created_at: DateTime<Utc>,
46    completed_at: Option<DateTime<Utc>>,
47}
48
49impl TryFrom<UserRegistrationLookup> for UserRegistration {
50    type Error = DatabaseInconsistencyError;
51
52    fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
53        let id = Ulid::from(value.user_registration_id);
54
55        let password = match (value.hashed_password, value.hashed_password_version) {
56            (Some(hashed_password), Some(version)) => {
57                let version = version.try_into().map_err(|e| {
58                    DatabaseInconsistencyError::on("user_registrations")
59                        .column("hashed_password_version")
60                        .row(id)
61                        .source(e)
62                })?;
63
64                Some(UserRegistrationPassword {
65                    hashed_password,
66                    version,
67                })
68            }
69            (None, None) => None,
70            _ => {
71                return Err(DatabaseInconsistencyError::on("user_registrations")
72                    .column("hashed_password")
73                    .row(id));
74            }
75        };
76
77        let terms_url = value
78            .terms_url
79            .map(|u| u.parse())
80            .transpose()
81            .map_err(|e| {
82                DatabaseInconsistencyError::on("user_registrations")
83                    .column("terms_url")
84                    .row(id)
85                    .source(e)
86            })?;
87
88        Ok(UserRegistration {
89            id,
90            ip_address: value.ip_address,
91            user_agent: value.user_agent,
92            post_auth_action: value.post_auth_action,
93            username: value.username,
94            display_name: value.display_name,
95            terms_url,
96            email_authentication_id: value.email_authentication_id.map(Ulid::from),
97            password,
98            created_at: value.created_at,
99            completed_at: value.completed_at,
100        })
101    }
102}
103
104#[async_trait]
105impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
106    type Error = DatabaseError;
107
108    #[tracing::instrument(
109        name = "db.user_registration.lookup",
110        skip_all,
111        fields(
112            db.query.text,
113            user_registration.id = %id,
114        ),
115        err,
116    )]
117    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
118        let res = sqlx::query_as!(
119            UserRegistrationLookup,
120            r#"
121                SELECT user_registration_id
122                     , ip_address as "ip_address: IpAddr"
123                     , user_agent
124                     , post_auth_action
125                     , username
126                     , display_name
127                     , terms_url
128                     , email_authentication_id
129                     , hashed_password
130                     , hashed_password_version
131                     , created_at
132                     , completed_at
133                FROM user_registrations
134                WHERE user_registration_id = $1
135            "#,
136            Uuid::from(id),
137        )
138        .traced()
139        .fetch_optional(&mut *self.conn)
140        .await?;
141
142        let Some(res) = res else { return Ok(None) };
143
144        Ok(Some(res.try_into()?))
145    }
146
147    #[tracing::instrument(
148        name = "db.user_registration.add",
149        skip_all,
150        fields(
151            db.query.text,
152            user_registration.id,
153        ),
154        err,
155    )]
156    async fn add(
157        &mut self,
158        rng: &mut (dyn RngCore + Send),
159        clock: &dyn Clock,
160        username: String,
161        ip_address: Option<IpAddr>,
162        user_agent: Option<String>,
163        post_auth_action: Option<serde_json::Value>,
164    ) -> Result<UserRegistration, Self::Error> {
165        let created_at = clock.now();
166        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
167        tracing::Span::current().record("user_registration.id", tracing::field::display(id));
168
169        sqlx::query!(
170            r#"
171                INSERT INTO user_registrations
172                  ( user_registration_id
173                  , ip_address
174                  , user_agent
175                  , post_auth_action
176                  , username
177                  , created_at
178                  )
179                VALUES ($1, $2, $3, $4, $5, $6)
180            "#,
181            Uuid::from(id),
182            ip_address as Option<IpAddr>,
183            user_agent.as_deref(),
184            post_auth_action,
185            username,
186            created_at,
187        )
188        .traced()
189        .execute(&mut *self.conn)
190        .await?;
191
192        Ok(UserRegistration {
193            id,
194            ip_address,
195            user_agent,
196            post_auth_action,
197            created_at,
198            completed_at: None,
199            username,
200            display_name: None,
201            terms_url: None,
202            email_authentication_id: None,
203            password: None,
204        })
205    }
206
207    #[tracing::instrument(
208        name = "db.user_registration.set_display_name",
209        skip_all,
210        fields(
211            db.query.text,
212            user_registration.id = %user_registration.id,
213            user_registration.display_name = display_name,
214        ),
215        err,
216    )]
217    async fn set_display_name(
218        &mut self,
219        mut user_registration: UserRegistration,
220        display_name: String,
221    ) -> Result<UserRegistration, Self::Error> {
222        let res = sqlx::query!(
223            r#"
224                UPDATE user_registrations
225                SET display_name = $2
226                WHERE user_registration_id = $1 AND completed_at IS NULL
227            "#,
228            Uuid::from(user_registration.id),
229            display_name,
230        )
231        .traced()
232        .execute(&mut *self.conn)
233        .await?;
234
235        DatabaseError::ensure_affected_rows(&res, 1)?;
236
237        user_registration.display_name = Some(display_name);
238
239        Ok(user_registration)
240    }
241
242    #[tracing::instrument(
243        name = "db.user_registration.set_terms_url",
244        skip_all,
245        fields(
246            db.query.text,
247            user_registration.id = %user_registration.id,
248            user_registration.terms_url = %terms_url,
249        ),
250        err,
251    )]
252    async fn set_terms_url(
253        &mut self,
254        mut user_registration: UserRegistration,
255        terms_url: Url,
256    ) -> Result<UserRegistration, Self::Error> {
257        let res = sqlx::query!(
258            r#"
259                UPDATE user_registrations
260                SET terms_url = $2
261                WHERE user_registration_id = $1 AND completed_at IS NULL
262            "#,
263            Uuid::from(user_registration.id),
264            terms_url.as_str(),
265        )
266        .traced()
267        .execute(&mut *self.conn)
268        .await?;
269
270        DatabaseError::ensure_affected_rows(&res, 1)?;
271
272        user_registration.terms_url = Some(terms_url);
273
274        Ok(user_registration)
275    }
276
277    #[tracing::instrument(
278        name = "db.user_registration.set_email_authentication",
279        skip_all,
280        fields(
281            db.query.text,
282            %user_registration.id,
283            %user_email_authentication.id,
284            %user_email_authentication.email,
285        ),
286        err,
287    )]
288    async fn set_email_authentication(
289        &mut self,
290        mut user_registration: UserRegistration,
291        user_email_authentication: &UserEmailAuthentication,
292    ) -> Result<UserRegistration, Self::Error> {
293        let res = sqlx::query!(
294            r#"
295                UPDATE user_registrations
296                SET email_authentication_id = $2
297                WHERE user_registration_id = $1 AND completed_at IS NULL
298            "#,
299            Uuid::from(user_registration.id),
300            Uuid::from(user_email_authentication.id),
301        )
302        .traced()
303        .execute(&mut *self.conn)
304        .await?;
305
306        DatabaseError::ensure_affected_rows(&res, 1)?;
307
308        user_registration.email_authentication_id = Some(user_email_authentication.id);
309
310        Ok(user_registration)
311    }
312
313    #[tracing::instrument(
314        name = "db.user_registration.set_password",
315        skip_all,
316        fields(
317            db.query.text,
318            user_registration.id = %user_registration.id,
319            user_registration.hashed_password = hashed_password,
320            user_registration.hashed_password_version = version,
321        ),
322        err,
323    )]
324    async fn set_password(
325        &mut self,
326        mut user_registration: UserRegistration,
327        hashed_password: String,
328        version: u16,
329    ) -> Result<UserRegistration, Self::Error> {
330        let res = sqlx::query!(
331            r#"
332                UPDATE user_registrations
333                SET hashed_password = $2, hashed_password_version = $3
334                WHERE user_registration_id = $1 AND completed_at IS NULL
335            "#,
336            Uuid::from(user_registration.id),
337            hashed_password,
338            i32::from(version),
339        )
340        .traced()
341        .execute(&mut *self.conn)
342        .await?;
343
344        DatabaseError::ensure_affected_rows(&res, 1)?;
345
346        user_registration.password = Some(UserRegistrationPassword {
347            hashed_password,
348            version,
349        });
350
351        Ok(user_registration)
352    }
353
354    #[tracing::instrument(
355        name = "db.user_registration.complete",
356        skip_all,
357        fields(
358            db.query.text,
359            user_registration.id = %user_registration.id,
360        ),
361        err,
362    )]
363    async fn complete(
364        &mut self,
365        clock: &dyn Clock,
366        mut user_registration: UserRegistration,
367    ) -> Result<UserRegistration, Self::Error> {
368        let completed_at = clock.now();
369        let res = sqlx::query!(
370            r#"
371                UPDATE user_registrations
372                SET completed_at = $2
373                WHERE user_registration_id = $1 AND completed_at IS NULL
374            "#,
375            Uuid::from(user_registration.id),
376            completed_at,
377        )
378        .traced()
379        .execute(&mut *self.conn)
380        .await?;
381
382        DatabaseError::ensure_affected_rows(&res, 1)?;
383
384        user_registration.completed_at = Some(completed_at);
385
386        Ok(user_registration)
387    }
388}
389
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::UserRegistrationPassword;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    // Full lifecycle: create a registration, read it back, complete it, and
    // check that completing it a second time is rejected
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let reg = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        // A fresh registration only carries its username and creation time
        assert_eq!(reg.username, "alice");
        assert_eq!(reg.created_at, clock.now());
        assert_eq!(reg.completed_at, None);
        assert_eq!(reg.display_name, None);
        assert_eq!(reg.terms_url, None);
        assert_eq!(reg.email_authentication_id, None);
        assert_eq!(reg.password, None);
        assert_eq!(reg.user_agent, None);
        assert_eq!(reg.ip_address, None);
        assert_eq!(reg.post_auth_action, None);

        // Reading it back must yield the same values
        let fetched = repo
            .user_registration()
            .lookup(reg.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(fetched.id, reg.id);
        assert_eq!(fetched.username, reg.username);
        assert_eq!(fetched.created_at, reg.created_at);
        assert_eq!(fetched.completed_at, reg.completed_at);
        assert_eq!(fetched.display_name, reg.display_name);
        assert_eq!(fetched.terms_url, reg.terms_url);
        assert_eq!(fetched.email_authentication_id, reg.email_authentication_id);
        assert_eq!(fetched.password, reg.password);
        assert_eq!(fetched.user_agent, reg.user_agent);
        assert_eq!(fetched.ip_address, reg.ip_address);
        assert_eq!(fetched.post_auth_action, reg.post_auth_action);

        // Completing stamps the current time...
        let reg = repo
            .user_registration()
            .complete(&clock, reg)
            .await
            .unwrap();
        assert_eq!(reg.completed_at, Some(clock.now()));

        // ...and that timestamp is persisted
        let fetched = repo
            .user_registration()
            .lookup(reg.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(fetched.completed_at, reg.completed_at);

        // A registration can only be completed once
        let second_attempt = repo.user_registration().complete(&clock, reg).await;
        assert!(second_attempt.is_err());
    }

    // User agent, IP address and post-auth action supplied at creation time
    // must round-trip through the database
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let user_agent = "Mozilla/5.0".to_owned();
        let action = serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"});

        let reg = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(localhost),
                Some(user_agent.clone()),
                Some(action.clone()),
            )
            .await
            .unwrap();

        assert_eq!(reg.user_agent, Some(user_agent));
        assert_eq!(reg.ip_address, Some(localhost));
        assert_eq!(reg.post_auth_action, Some(action));

        let fetched = repo
            .user_registration()
            .lookup(reg.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(fetched.user_agent, reg.user_agent);
        assert_eq!(fetched.ip_address, reg.ip_address);
        assert_eq!(fetched.post_auth_action, reg.post_auth_action);
    }

    // The display name can be set and overwritten until completion
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let mut reg = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();
        assert_eq!(reg.display_name, None);

        // The first pass sets the value, the second overwrites it
        for name in ["Alice", "Bob"] {
            reg = repo
                .user_registration()
                .set_display_name(reg, name.to_owned())
                .await
                .unwrap();
            assert_eq!(reg.display_name, Some(name.to_owned()));

            let fetched = repo
                .user_registration()
                .lookup(reg.id)
                .await
                .unwrap()
                .unwrap();
            assert_eq!(fetched.display_name, reg.display_name);
        }

        // Once completed, the registration is frozen
        let reg = repo
            .user_registration()
            .complete(&clock, reg)
            .await
            .unwrap();

        let late_update = repo
            .user_registration()
            .set_display_name(reg, "Charlie".to_owned())
            .await;
        assert!(late_update.is_err());
    }

    // The terms URL can be set and overwritten until completion
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let mut reg = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();
        assert_eq!(reg.terms_url, None);

        // The first pass sets the value, the second overwrites it
        for raw_url in ["https://example.com/terms", "https://example.com/terms2"] {
            reg = repo
                .user_registration()
                .set_terms_url(reg, raw_url.parse().unwrap())
                .await
                .unwrap();
            assert_eq!(reg.terms_url, Some(raw_url.parse().unwrap()));

            let fetched = repo
                .user_registration()
                .lookup(reg.id)
                .await
                .unwrap()
                .unwrap();
            assert_eq!(fetched.terms_url, reg.terms_url);
        }

        // Once completed, the registration is frozen
        let reg = repo
            .user_registration()
            .complete(&clock, reg)
            .await
            .unwrap();

        let late_update = repo
            .user_registration()
            .set_terms_url(reg, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(late_update.is_err());
    }

    // The email authentication link can be set (repeatedly) until completion
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let mut reg = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();
        assert_eq!(reg.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &reg,
            )
            .await
            .unwrap();

        // The first pass sets the link, the second re-sets the same one
        for _ in 0..2 {
            reg = repo
                .user_registration()
                .set_email_authentication(reg, &authentication)
                .await
                .unwrap();
            assert_eq!(reg.email_authentication_id, Some(authentication.id));

            let fetched = repo
                .user_registration()
                .lookup(reg.id)
                .await
                .unwrap()
                .unwrap();
            assert_eq!(fetched.email_authentication_id, reg.email_authentication_id);
        }

        // Once completed, the registration is frozen
        let reg = repo
            .user_registration()
            .complete(&clock, reg)
            .await
            .unwrap();

        let late_update = repo
            .user_registration()
            .set_email_authentication(reg, &authentication)
            .await;
        assert!(late_update.is_err());
    }

    // The password can be set and overwritten until completion
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let mut reg = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();
        assert_eq!(reg.password, None);

        // The first pass sets the password, the second overwrites it
        for (hash, version) in [("fakehashedpassword", 1), ("fakehashedpassword2", 2)] {
            reg = repo
                .user_registration()
                .set_password(reg, hash.to_owned(), version)
                .await
                .unwrap();
            assert_eq!(
                reg.password,
                Some(UserRegistrationPassword {
                    hashed_password: hash.to_owned(),
                    version,
                })
            );

            let fetched = repo
                .user_registration()
                .lookup(reg.id)
                .await
                .unwrap()
                .unwrap();
            assert_eq!(fetched.password, reg.password);
        }

        // Once completed, the registration is frozen
        let reg = repo
            .user_registration()
            .complete(&clock, reg)
            .await
            .unwrap();

        let late_update = repo
            .user_registration()
            .set_password(reg, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(late_update.is_err());
    }
}