mas_storage_pg/user/
registration.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
12    UserRegistrationPassword, UserRegistrationToken,
13};
14use mas_storage::user::UserRegistrationRepository;
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
22
23/// An implementation of [`UserRegistrationRepository`] for a PostgreSQL
24/// connection
25pub struct PgUserRegistrationRepository<'c> {
26    conn: &'c mut PgConnection,
27}
28
29impl<'c> PgUserRegistrationRepository<'c> {
30    /// Create a new [`PgUserRegistrationRepository`] from an active PostgreSQL
31    /// connection
32    pub fn new(conn: &'c mut PgConnection) -> Self {
33        Self { conn }
34    }
35}
36
37struct UserRegistrationLookup {
38    user_registration_id: Uuid,
39    ip_address: Option<IpAddr>,
40    user_agent: Option<String>,
41    post_auth_action: Option<serde_json::Value>,
42    username: String,
43    display_name: Option<String>,
44    terms_url: Option<String>,
45    email_authentication_id: Option<Uuid>,
46    user_registration_token_id: Option<Uuid>,
47    hashed_password: Option<String>,
48    hashed_password_version: Option<i32>,
49    upstream_oauth_authorization_session_id: Option<Uuid>,
50    created_at: DateTime<Utc>,
51    completed_at: Option<DateTime<Utc>>,
52}
53
54impl TryFrom<UserRegistrationLookup> for UserRegistration {
55    type Error = DatabaseInconsistencyError;
56
57    fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
58        let id = Ulid::from(value.user_registration_id);
59
60        let password = match (value.hashed_password, value.hashed_password_version) {
61            (Some(hashed_password), Some(version)) => {
62                let version = version.try_into().map_err(|e| {
63                    DatabaseInconsistencyError::on("user_registrations")
64                        .column("hashed_password_version")
65                        .row(id)
66                        .source(e)
67                })?;
68
69                Some(UserRegistrationPassword {
70                    hashed_password,
71                    version,
72                })
73            }
74            (None, None) => None,
75            _ => {
76                return Err(DatabaseInconsistencyError::on("user_registrations")
77                    .column("hashed_password")
78                    .row(id));
79            }
80        };
81
82        let terms_url = value
83            .terms_url
84            .map(|u| u.parse())
85            .transpose()
86            .map_err(|e| {
87                DatabaseInconsistencyError::on("user_registrations")
88                    .column("terms_url")
89                    .row(id)
90                    .source(e)
91            })?;
92
93        Ok(UserRegistration {
94            id,
95            ip_address: value.ip_address,
96            user_agent: value.user_agent,
97            post_auth_action: value.post_auth_action,
98            username: value.username,
99            display_name: value.display_name,
100            terms_url,
101            email_authentication_id: value.email_authentication_id.map(Ulid::from),
102            user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
103            password,
104            upstream_oauth_authorization_session_id: value
105                .upstream_oauth_authorization_session_id
106                .map(Ulid::from),
107            created_at: value.created_at,
108            completed_at: value.completed_at,
109        })
110    }
111}
112
113#[async_trait]
114impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
115    type Error = DatabaseError;
116
117    #[tracing::instrument(
118        name = "db.user_registration.lookup",
119        skip_all,
120        fields(
121            db.query.text,
122            user_registration.id = %id,
123        ),
124        err,
125    )]
126    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
127        let res = sqlx::query_as!(
128            UserRegistrationLookup,
129            r#"
130                SELECT user_registration_id
131                     , ip_address as "ip_address: IpAddr"
132                     , user_agent
133                     , post_auth_action
134                     , username
135                     , display_name
136                     , terms_url
137                     , email_authentication_id
138                     , user_registration_token_id
139                     , hashed_password
140                     , hashed_password_version
141                     , upstream_oauth_authorization_session_id
142                     , created_at
143                     , completed_at
144                FROM user_registrations
145                WHERE user_registration_id = $1
146            "#,
147            Uuid::from(id),
148        )
149        .traced()
150        .fetch_optional(&mut *self.conn)
151        .await?;
152
153        let Some(res) = res else { return Ok(None) };
154
155        Ok(Some(res.try_into()?))
156    }
157
158    #[tracing::instrument(
159        name = "db.user_registration.add",
160        skip_all,
161        fields(
162            db.query.text,
163            user_registration.id,
164        ),
165        err,
166    )]
167    async fn add(
168        &mut self,
169        rng: &mut (dyn RngCore + Send),
170        clock: &dyn Clock,
171        username: String,
172        ip_address: Option<IpAddr>,
173        user_agent: Option<String>,
174        post_auth_action: Option<serde_json::Value>,
175    ) -> Result<UserRegistration, Self::Error> {
176        let created_at = clock.now();
177        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
178        tracing::Span::current().record("user_registration.id", tracing::field::display(id));
179
180        sqlx::query!(
181            r#"
182                INSERT INTO user_registrations
183                  ( user_registration_id
184                  , ip_address
185                  , user_agent
186                  , post_auth_action
187                  , username
188                  , created_at
189                  )
190                VALUES ($1, $2, $3, $4, $5, $6)
191            "#,
192            Uuid::from(id),
193            ip_address as Option<IpAddr>,
194            user_agent.as_deref(),
195            post_auth_action,
196            username,
197            created_at,
198        )
199        .traced()
200        .execute(&mut *self.conn)
201        .await?;
202
203        Ok(UserRegistration {
204            id,
205            ip_address,
206            user_agent,
207            post_auth_action,
208            created_at,
209            completed_at: None,
210            username,
211            display_name: None,
212            terms_url: None,
213            email_authentication_id: None,
214            user_registration_token_id: None,
215            password: None,
216            upstream_oauth_authorization_session_id: None,
217        })
218    }
219
220    #[tracing::instrument(
221        name = "db.user_registration.set_display_name",
222        skip_all,
223        fields(
224            db.query.text,
225            user_registration.id = %user_registration.id,
226            user_registration.display_name = display_name,
227        ),
228        err,
229    )]
230    async fn set_display_name(
231        &mut self,
232        mut user_registration: UserRegistration,
233        display_name: String,
234    ) -> Result<UserRegistration, Self::Error> {
235        let res = sqlx::query!(
236            r#"
237                UPDATE user_registrations
238                SET display_name = $2
239                WHERE user_registration_id = $1 AND completed_at IS NULL
240            "#,
241            Uuid::from(user_registration.id),
242            display_name,
243        )
244        .traced()
245        .execute(&mut *self.conn)
246        .await?;
247
248        DatabaseError::ensure_affected_rows(&res, 1)?;
249
250        user_registration.display_name = Some(display_name);
251
252        Ok(user_registration)
253    }
254
255    #[tracing::instrument(
256        name = "db.user_registration.set_terms_url",
257        skip_all,
258        fields(
259            db.query.text,
260            user_registration.id = %user_registration.id,
261            user_registration.terms_url = %terms_url,
262        ),
263        err,
264    )]
265    async fn set_terms_url(
266        &mut self,
267        mut user_registration: UserRegistration,
268        terms_url: Url,
269    ) -> Result<UserRegistration, Self::Error> {
270        let res = sqlx::query!(
271            r#"
272                UPDATE user_registrations
273                SET terms_url = $2
274                WHERE user_registration_id = $1 AND completed_at IS NULL
275            "#,
276            Uuid::from(user_registration.id),
277            terms_url.as_str(),
278        )
279        .traced()
280        .execute(&mut *self.conn)
281        .await?;
282
283        DatabaseError::ensure_affected_rows(&res, 1)?;
284
285        user_registration.terms_url = Some(terms_url);
286
287        Ok(user_registration)
288    }
289
290    #[tracing::instrument(
291        name = "db.user_registration.set_email_authentication",
292        skip_all,
293        fields(
294            db.query.text,
295            %user_registration.id,
296            %user_email_authentication.id,
297            %user_email_authentication.email,
298        ),
299        err,
300    )]
301    async fn set_email_authentication(
302        &mut self,
303        mut user_registration: UserRegistration,
304        user_email_authentication: &UserEmailAuthentication,
305    ) -> Result<UserRegistration, Self::Error> {
306        let res = sqlx::query!(
307            r#"
308                UPDATE user_registrations
309                SET email_authentication_id = $2
310                WHERE user_registration_id = $1 AND completed_at IS NULL
311            "#,
312            Uuid::from(user_registration.id),
313            Uuid::from(user_email_authentication.id),
314        )
315        .traced()
316        .execute(&mut *self.conn)
317        .await?;
318
319        DatabaseError::ensure_affected_rows(&res, 1)?;
320
321        user_registration.email_authentication_id = Some(user_email_authentication.id);
322
323        Ok(user_registration)
324    }
325
326    #[tracing::instrument(
327        name = "db.user_registration.set_password",
328        skip_all,
329        fields(
330            db.query.text,
331            user_registration.id = %user_registration.id,
332            user_registration.hashed_password = hashed_password,
333            user_registration.hashed_password_version = version,
334        ),
335        err,
336    )]
337    async fn set_password(
338        &mut self,
339        mut user_registration: UserRegistration,
340        hashed_password: String,
341        version: u16,
342    ) -> Result<UserRegistration, Self::Error> {
343        let res = sqlx::query!(
344            r#"
345                UPDATE user_registrations
346                SET hashed_password = $2, hashed_password_version = $3
347                WHERE user_registration_id = $1 AND completed_at IS NULL
348            "#,
349            Uuid::from(user_registration.id),
350            hashed_password,
351            i32::from(version),
352        )
353        .traced()
354        .execute(&mut *self.conn)
355        .await?;
356
357        DatabaseError::ensure_affected_rows(&res, 1)?;
358
359        user_registration.password = Some(UserRegistrationPassword {
360            hashed_password,
361            version,
362        });
363
364        Ok(user_registration)
365    }
366
367    #[tracing::instrument(
368        name = "db.user_registration.set_registration_token",
369        skip_all,
370        fields(
371            db.query.text,
372            %user_registration.id,
373            %user_registration_token.id,
374        ),
375        err,
376    )]
377    async fn set_registration_token(
378        &mut self,
379        mut user_registration: UserRegistration,
380        user_registration_token: &UserRegistrationToken,
381    ) -> Result<UserRegistration, Self::Error> {
382        let res = sqlx::query!(
383            r#"
384                UPDATE user_registrations
385                SET user_registration_token_id = $2
386                WHERE user_registration_id = $1 AND completed_at IS NULL
387            "#,
388            Uuid::from(user_registration.id),
389            Uuid::from(user_registration_token.id),
390        )
391        .traced()
392        .execute(&mut *self.conn)
393        .await?;
394
395        DatabaseError::ensure_affected_rows(&res, 1)?;
396
397        user_registration.user_registration_token_id = Some(user_registration_token.id);
398
399        Ok(user_registration)
400    }
401
402    #[tracing::instrument(
403        name = "db.user_registration.set_upstream_oauth_authorization_session",
404        skip_all,
405        fields(
406            db.query.text,
407            %user_registration.id,
408            %upstream_oauth_authorization_session.id,
409        ),
410        err,
411    )]
412    async fn set_upstream_oauth_authorization_session(
413        &mut self,
414        mut user_registration: UserRegistration,
415        upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
416    ) -> Result<UserRegistration, Self::Error> {
417        let res = sqlx::query!(
418            r#"
419                UPDATE user_registrations
420                SET upstream_oauth_authorization_session_id = $2
421                WHERE user_registration_id = $1 AND completed_at IS NULL
422            "#,
423            Uuid::from(user_registration.id),
424            Uuid::from(upstream_oauth_authorization_session.id),
425        )
426        .traced()
427        .execute(&mut *self.conn)
428        .await?;
429
430        DatabaseError::ensure_affected_rows(&res, 1)?;
431
432        user_registration.upstream_oauth_authorization_session_id =
433            Some(upstream_oauth_authorization_session.id);
434
435        Ok(user_registration)
436    }
437
438    #[tracing::instrument(
439        name = "db.user_registration.complete",
440        skip_all,
441        fields(
442            db.query.text,
443            user_registration.id = %user_registration.id,
444        ),
445        err,
446    )]
447    async fn complete(
448        &mut self,
449        clock: &dyn Clock,
450        mut user_registration: UserRegistration,
451    ) -> Result<UserRegistration, Self::Error> {
452        let completed_at = clock.now();
453        let res = sqlx::query!(
454            r#"
455                UPDATE user_registrations
456                SET completed_at = $2
457                WHERE user_registration_id = $1 AND completed_at IS NULL
458            "#,
459            Uuid::from(user_registration.id),
460            completed_at,
461        )
462        .traced()
463        .execute(&mut *self.conn)
464        .await?;
465
466        DatabaseError::ensure_affected_rows(&res, 1)?;
467
468        user_registration.completed_at = Some(completed_at);
469
470        Ok(user_registration)
471    }
472}
473
474#[cfg(test)]
475mod tests {
476    use std::net::{IpAddr, Ipv4Addr};
477
478    use mas_data_model::{
479        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
480        UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
481        UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
482    };
483    use mas_iana::jose::JsonWebSignatureAlg;
484    use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
485    use oauth2_types::scope::Scope;
486    use rand::SeedableRng;
487    use rand_chacha::ChaChaRng;
488    use sqlx::PgPool;
489
490    use crate::PgRepository;
491
492    #[sqlx::test(migrator = "crate::MIGRATOR")]
493    async fn test_create_lookup_complete(pool: PgPool) {
494        let mut rng = ChaChaRng::seed_from_u64(42);
495        let clock = MockClock::default();
496
497        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
498
499        let registration = repo
500            .user_registration()
501            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
502            .await
503            .unwrap();
504
505        assert_eq!(registration.created_at, clock.now());
506        assert_eq!(registration.completed_at, None);
507        assert_eq!(registration.username, "alice");
508        assert_eq!(registration.display_name, None);
509        assert_eq!(registration.terms_url, None);
510        assert_eq!(registration.email_authentication_id, None);
511        assert_eq!(registration.password, None);
512        assert_eq!(registration.user_agent, None);
513        assert_eq!(registration.ip_address, None);
514        assert_eq!(registration.post_auth_action, None);
515
516        let lookup = repo
517            .user_registration()
518            .lookup(registration.id)
519            .await
520            .unwrap()
521            .unwrap();
522
523        assert_eq!(lookup.id, registration.id);
524        assert_eq!(lookup.created_at, registration.created_at);
525        assert_eq!(lookup.completed_at, registration.completed_at);
526        assert_eq!(lookup.username, registration.username);
527        assert_eq!(lookup.display_name, registration.display_name);
528        assert_eq!(lookup.terms_url, registration.terms_url);
529        assert_eq!(
530            lookup.email_authentication_id,
531            registration.email_authentication_id
532        );
533        assert_eq!(lookup.password, registration.password);
534        assert_eq!(lookup.user_agent, registration.user_agent);
535        assert_eq!(lookup.ip_address, registration.ip_address);
536        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
537
538        // Mark the registration as completed
539        let registration = repo
540            .user_registration()
541            .complete(&clock, registration)
542            .await
543            .unwrap();
544        assert_eq!(registration.completed_at, Some(clock.now()));
545
546        // Lookup the registration again
547        let lookup = repo
548            .user_registration()
549            .lookup(registration.id)
550            .await
551            .unwrap()
552            .unwrap();
553        assert_eq!(lookup.completed_at, registration.completed_at);
554
555        // Do it again, it should fail
556        let res = repo
557            .user_registration()
558            .complete(&clock, registration)
559            .await;
560        assert!(res.is_err());
561    }
562
563    #[sqlx::test(migrator = "crate::MIGRATOR")]
564    async fn test_create_useragent_ipaddress(pool: PgPool) {
565        let mut rng = ChaChaRng::seed_from_u64(42);
566        let clock = MockClock::default();
567
568        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
569
570        let registration = repo
571            .user_registration()
572            .add(
573                &mut rng,
574                &clock,
575                "alice".to_owned(),
576                Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
577                Some("Mozilla/5.0".to_owned()),
578                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
579            )
580            .await
581            .unwrap();
582
583        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
584        assert_eq!(
585            registration.ip_address,
586            Some(IpAddr::V4(Ipv4Addr::LOCALHOST))
587        );
588        assert_eq!(
589            registration.post_auth_action,
590            Some(
591                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
592            )
593        );
594
595        let lookup = repo
596            .user_registration()
597            .lookup(registration.id)
598            .await
599            .unwrap()
600            .unwrap();
601
602        assert_eq!(lookup.user_agent, registration.user_agent);
603        assert_eq!(lookup.ip_address, registration.ip_address);
604        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
605    }
606
607    #[sqlx::test(migrator = "crate::MIGRATOR")]
608    async fn test_set_display_name(pool: PgPool) {
609        let mut rng = ChaChaRng::seed_from_u64(42);
610        let clock = MockClock::default();
611
612        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
613
614        let registration = repo
615            .user_registration()
616            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
617            .await
618            .unwrap();
619
620        assert_eq!(registration.display_name, None);
621
622        let registration = repo
623            .user_registration()
624            .set_display_name(registration, "Alice".to_owned())
625            .await
626            .unwrap();
627
628        assert_eq!(registration.display_name, Some("Alice".to_owned()));
629
630        let lookup = repo
631            .user_registration()
632            .lookup(registration.id)
633            .await
634            .unwrap()
635            .unwrap();
636
637        assert_eq!(lookup.display_name, registration.display_name);
638
639        // Setting it again should work
640        let registration = repo
641            .user_registration()
642            .set_display_name(registration, "Bob".to_owned())
643            .await
644            .unwrap();
645
646        assert_eq!(registration.display_name, Some("Bob".to_owned()));
647
648        let lookup = repo
649            .user_registration()
650            .lookup(registration.id)
651            .await
652            .unwrap()
653            .unwrap();
654
655        assert_eq!(lookup.display_name, registration.display_name);
656
657        // Can't set it once completed
658        let registration = repo
659            .user_registration()
660            .complete(&clock, registration)
661            .await
662            .unwrap();
663
664        let res = repo
665            .user_registration()
666            .set_display_name(registration, "Charlie".to_owned())
667            .await;
668        assert!(res.is_err());
669    }
670
671    #[sqlx::test(migrator = "crate::MIGRATOR")]
672    async fn test_set_terms_url(pool: PgPool) {
673        let mut rng = ChaChaRng::seed_from_u64(42);
674        let clock = MockClock::default();
675
676        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
677
678        let registration = repo
679            .user_registration()
680            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
681            .await
682            .unwrap();
683
684        assert_eq!(registration.terms_url, None);
685
686        let registration = repo
687            .user_registration()
688            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
689            .await
690            .unwrap();
691
692        assert_eq!(
693            registration.terms_url,
694            Some("https://example.com/terms".parse().unwrap())
695        );
696
697        let lookup = repo
698            .user_registration()
699            .lookup(registration.id)
700            .await
701            .unwrap()
702            .unwrap();
703
704        assert_eq!(lookup.terms_url, registration.terms_url);
705
706        // Setting it again should work
707        let registration = repo
708            .user_registration()
709            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
710            .await
711            .unwrap();
712
713        assert_eq!(
714            registration.terms_url,
715            Some("https://example.com/terms2".parse().unwrap())
716        );
717
718        let lookup = repo
719            .user_registration()
720            .lookup(registration.id)
721            .await
722            .unwrap()
723            .unwrap();
724
725        assert_eq!(lookup.terms_url, registration.terms_url);
726
727        // Can't set it once completed
728        let registration = repo
729            .user_registration()
730            .complete(&clock, registration)
731            .await
732            .unwrap();
733
734        let res = repo
735            .user_registration()
736            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
737            .await;
738        assert!(res.is_err());
739    }
740
741    #[sqlx::test(migrator = "crate::MIGRATOR")]
742    async fn test_set_email_authentication(pool: PgPool) {
743        let mut rng = ChaChaRng::seed_from_u64(42);
744        let clock = MockClock::default();
745
746        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
747
748        let registration = repo
749            .user_registration()
750            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
751            .await
752            .unwrap();
753
754        assert_eq!(registration.email_authentication_id, None);
755
756        let authentication = repo
757            .user_email()
758            .add_authentication_for_registration(
759                &mut rng,
760                &clock,
761                "alice@example.com".to_owned(),
762                &registration,
763            )
764            .await
765            .unwrap();
766
767        let registration = repo
768            .user_registration()
769            .set_email_authentication(registration, &authentication)
770            .await
771            .unwrap();
772
773        assert_eq!(
774            registration.email_authentication_id,
775            Some(authentication.id)
776        );
777
778        let lookup = repo
779            .user_registration()
780            .lookup(registration.id)
781            .await
782            .unwrap()
783            .unwrap();
784
785        assert_eq!(
786            lookup.email_authentication_id,
787            registration.email_authentication_id
788        );
789
790        // Setting it again should work
791        let registration = repo
792            .user_registration()
793            .set_email_authentication(registration, &authentication)
794            .await
795            .unwrap();
796
797        assert_eq!(
798            registration.email_authentication_id,
799            Some(authentication.id)
800        );
801
802        let lookup = repo
803            .user_registration()
804            .lookup(registration.id)
805            .await
806            .unwrap()
807            .unwrap();
808
809        assert_eq!(
810            lookup.email_authentication_id,
811            registration.email_authentication_id
812        );
813
814        // Can't set it once completed
815        let registration = repo
816            .user_registration()
817            .complete(&clock, registration)
818            .await
819            .unwrap();
820
821        let res = repo
822            .user_registration()
823            .set_email_authentication(registration, &authentication)
824            .await;
825        assert!(res.is_err());
826    }
827
828    #[sqlx::test(migrator = "crate::MIGRATOR")]
829    async fn test_set_password(pool: PgPool) {
830        let mut rng = ChaChaRng::seed_from_u64(42);
831        let clock = MockClock::default();
832
833        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
834
835        let registration = repo
836            .user_registration()
837            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
838            .await
839            .unwrap();
840
841        assert_eq!(registration.password, None);
842
843        let registration = repo
844            .user_registration()
845            .set_password(registration, "fakehashedpassword".to_owned(), 1)
846            .await
847            .unwrap();
848
849        assert_eq!(
850            registration.password,
851            Some(UserRegistrationPassword {
852                hashed_password: "fakehashedpassword".to_owned(),
853                version: 1,
854            })
855        );
856
857        let lookup = repo
858            .user_registration()
859            .lookup(registration.id)
860            .await
861            .unwrap()
862            .unwrap();
863
864        assert_eq!(lookup.password, registration.password);
865
866        // Setting it again should work
867        let registration = repo
868            .user_registration()
869            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
870            .await
871            .unwrap();
872
873        assert_eq!(
874            registration.password,
875            Some(UserRegistrationPassword {
876                hashed_password: "fakehashedpassword2".to_owned(),
877                version: 2,
878            })
879        );
880
881        let lookup = repo
882            .user_registration()
883            .lookup(registration.id)
884            .await
885            .unwrap()
886            .unwrap();
887
888        assert_eq!(lookup.password, registration.password);
889
890        // Can't set it once completed
891        let registration = repo
892            .user_registration()
893            .complete(&clock, registration)
894            .await
895            .unwrap();
896
897        let res = repo
898            .user_registration()
899            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
900            .await;
901        assert!(res.is_err());
902    }
903
904    #[sqlx::test(migrator = "crate::MIGRATOR")]
905    async fn test_set_upstream_oauth_link(pool: PgPool) {
906        let mut rng = ChaChaRng::seed_from_u64(42);
907        let clock = MockClock::default();
908
909        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
910
911        let registration = repo
912            .user_registration()
913            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
914            .await
915            .unwrap();
916
917        assert_eq!(registration.upstream_oauth_authorization_session_id, None);
918
919        let provider = repo
920            .upstream_oauth_provider()
921            .add(
922                &mut rng,
923                &clock,
924                UpstreamOAuthProviderParams {
925                    issuer: Some("https://example.com/".to_owned()),
926                    human_name: Some("Example Ltd.".to_owned()),
927                    brand_name: None,
928                    scope: Scope::from_iter([oauth2_types::scope::OPENID]),
929                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
930                    token_endpoint_signing_alg: None,
931                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
932                    client_id: "client".to_owned(),
933                    encrypted_client_secret: None,
934                    claims_imports: UpstreamOAuthProviderClaimsImports::default(),
935                    authorization_endpoint_override: None,
936                    token_endpoint_override: None,
937                    userinfo_endpoint_override: None,
938                    fetch_userinfo: false,
939                    userinfo_signed_response_alg: None,
940                    jwks_uri_override: None,
941                    discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
942                    pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
943                    response_mode: None,
944                    additional_authorization_parameters: Vec::new(),
945                    forward_login_hint: false,
946                    ui_order: 0,
947                    on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
948                },
949            )
950            .await
951            .unwrap();
952
953        let session = repo
954            .upstream_oauth_session()
955            .add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
956            .await
957            .unwrap();
958
959        let registration = repo
960            .user_registration()
961            .set_upstream_oauth_authorization_session(registration, &session)
962            .await
963            .unwrap();
964
965        assert_eq!(
966            registration.upstream_oauth_authorization_session_id,
967            Some(session.id)
968        );
969
970        let lookup = repo
971            .user_registration()
972            .lookup(registration.id)
973            .await
974            .unwrap()
975            .unwrap();
976
977        assert_eq!(
978            lookup.upstream_oauth_authorization_session_id,
979            registration.upstream_oauth_authorization_session_id
980        );
981
982        // Setting it again should work
983        let registration = repo
984            .user_registration()
985            .set_upstream_oauth_authorization_session(registration, &session)
986            .await
987            .unwrap();
988
989        assert_eq!(
990            registration.upstream_oauth_authorization_session_id,
991            Some(session.id)
992        );
993
994        let lookup = repo
995            .user_registration()
996            .lookup(registration.id)
997            .await
998            .unwrap()
999            .unwrap();
1000
1001        assert_eq!(
1002            lookup.upstream_oauth_authorization_session_id,
1003            registration.upstream_oauth_authorization_session_id
1004        );
1005
1006        // Can't set it once completed
1007        let registration = repo
1008            .user_registration()
1009            .complete(&clock, registration)
1010            .await
1011            .unwrap();
1012
1013        let res = repo
1014            .user_registration()
1015            .set_upstream_oauth_authorization_session(registration, &session)
1016            .await;
1017        assert!(res.is_err());
1018    }
1019}