1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration,
12 UserRegistrationPassword, UserRegistrationToken,
13};
14use mas_storage::user::UserRegistrationRepository;
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
22
23pub struct PgUserRegistrationRepository<'c> {
26 conn: &'c mut PgConnection,
27}
28
29impl<'c> PgUserRegistrationRepository<'c> {
30 pub fn new(conn: &'c mut PgConnection) -> Self {
33 Self { conn }
34 }
35}
36
/// Raw row of the `user_registrations` table, as selected by
/// `PgUserRegistrationRepository::lookup` through `sqlx::query_as!`.
///
/// Field names and types must match the selected columns one-to-one; the row
/// is turned into a domain [`UserRegistration`] through its `TryFrom` impl.
struct UserRegistrationLookup {
    // Primary key; converted to a `Ulid` on conversion.
    user_registration_id: Uuid,
    // Decoded through the `"ip_address: IpAddr"` type override in the query.
    ip_address: Option<IpAddr>,
    user_agent: Option<String>,
    // Arbitrary JSON stored alongside the registration; passed through as-is.
    post_auth_action: Option<serde_json::Value>,
    username: String,
    display_name: Option<String>,
    // Stored as text; parsed into a `Url` on conversion.
    terms_url: Option<String>,
    email_authentication_id: Option<Uuid>,
    user_registration_token_id: Option<Uuid>,
    // `hashed_password` and `hashed_password_version` must be both set or both
    // NULL; the conversion rejects any other combination.
    hashed_password: Option<String>,
    hashed_password_version: Option<i32>,
    upstream_oauth_authorization_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    // `None` while the registration is still pending.
    completed_at: Option<DateTime<Utc>>,
}
53
54impl TryFrom<UserRegistrationLookup> for UserRegistration {
55 type Error = DatabaseInconsistencyError;
56
57 fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
58 let id = Ulid::from(value.user_registration_id);
59
60 let password = match (value.hashed_password, value.hashed_password_version) {
61 (Some(hashed_password), Some(version)) => {
62 let version = version.try_into().map_err(|e| {
63 DatabaseInconsistencyError::on("user_registrations")
64 .column("hashed_password_version")
65 .row(id)
66 .source(e)
67 })?;
68
69 Some(UserRegistrationPassword {
70 hashed_password,
71 version,
72 })
73 }
74 (None, None) => None,
75 _ => {
76 return Err(DatabaseInconsistencyError::on("user_registrations")
77 .column("hashed_password")
78 .row(id));
79 }
80 };
81
82 let terms_url = value
83 .terms_url
84 .map(|u| u.parse())
85 .transpose()
86 .map_err(|e| {
87 DatabaseInconsistencyError::on("user_registrations")
88 .column("terms_url")
89 .row(id)
90 .source(e)
91 })?;
92
93 Ok(UserRegistration {
94 id,
95 ip_address: value.ip_address,
96 user_agent: value.user_agent,
97 post_auth_action: value.post_auth_action,
98 username: value.username,
99 display_name: value.display_name,
100 terms_url,
101 email_authentication_id: value.email_authentication_id.map(Ulid::from),
102 user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
103 password,
104 upstream_oauth_authorization_session_id: value
105 .upstream_oauth_authorization_session_id
106 .map(Ulid::from),
107 created_at: value.created_at,
108 completed_at: value.completed_at,
109 })
110 }
111}
112
113#[async_trait]
114impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
115 type Error = DatabaseError;
116
117 #[tracing::instrument(
118 name = "db.user_registration.lookup",
119 skip_all,
120 fields(
121 db.query.text,
122 user_registration.id = %id,
123 ),
124 err,
125 )]
126 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
127 let res = sqlx::query_as!(
128 UserRegistrationLookup,
129 r#"
130 SELECT user_registration_id
131 , ip_address as "ip_address: IpAddr"
132 , user_agent
133 , post_auth_action
134 , username
135 , display_name
136 , terms_url
137 , email_authentication_id
138 , user_registration_token_id
139 , hashed_password
140 , hashed_password_version
141 , upstream_oauth_authorization_session_id
142 , created_at
143 , completed_at
144 FROM user_registrations
145 WHERE user_registration_id = $1
146 "#,
147 Uuid::from(id),
148 )
149 .traced()
150 .fetch_optional(&mut *self.conn)
151 .await?;
152
153 let Some(res) = res else { return Ok(None) };
154
155 Ok(Some(res.try_into()?))
156 }
157
158 #[tracing::instrument(
159 name = "db.user_registration.add",
160 skip_all,
161 fields(
162 db.query.text,
163 user_registration.id,
164 ),
165 err,
166 )]
167 async fn add(
168 &mut self,
169 rng: &mut (dyn RngCore + Send),
170 clock: &dyn Clock,
171 username: String,
172 ip_address: Option<IpAddr>,
173 user_agent: Option<String>,
174 post_auth_action: Option<serde_json::Value>,
175 ) -> Result<UserRegistration, Self::Error> {
176 let created_at = clock.now();
177 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
178 tracing::Span::current().record("user_registration.id", tracing::field::display(id));
179
180 sqlx::query!(
181 r#"
182 INSERT INTO user_registrations
183 ( user_registration_id
184 , ip_address
185 , user_agent
186 , post_auth_action
187 , username
188 , created_at
189 )
190 VALUES ($1, $2, $3, $4, $5, $6)
191 "#,
192 Uuid::from(id),
193 ip_address as Option<IpAddr>,
194 user_agent.as_deref(),
195 post_auth_action,
196 username,
197 created_at,
198 )
199 .traced()
200 .execute(&mut *self.conn)
201 .await?;
202
203 Ok(UserRegistration {
204 id,
205 ip_address,
206 user_agent,
207 post_auth_action,
208 created_at,
209 completed_at: None,
210 username,
211 display_name: None,
212 terms_url: None,
213 email_authentication_id: None,
214 user_registration_token_id: None,
215 password: None,
216 upstream_oauth_authorization_session_id: None,
217 })
218 }
219
220 #[tracing::instrument(
221 name = "db.user_registration.set_display_name",
222 skip_all,
223 fields(
224 db.query.text,
225 user_registration.id = %user_registration.id,
226 user_registration.display_name = display_name,
227 ),
228 err,
229 )]
230 async fn set_display_name(
231 &mut self,
232 mut user_registration: UserRegistration,
233 display_name: String,
234 ) -> Result<UserRegistration, Self::Error> {
235 let res = sqlx::query!(
236 r#"
237 UPDATE user_registrations
238 SET display_name = $2
239 WHERE user_registration_id = $1 AND completed_at IS NULL
240 "#,
241 Uuid::from(user_registration.id),
242 display_name,
243 )
244 .traced()
245 .execute(&mut *self.conn)
246 .await?;
247
248 DatabaseError::ensure_affected_rows(&res, 1)?;
249
250 user_registration.display_name = Some(display_name);
251
252 Ok(user_registration)
253 }
254
255 #[tracing::instrument(
256 name = "db.user_registration.set_terms_url",
257 skip_all,
258 fields(
259 db.query.text,
260 user_registration.id = %user_registration.id,
261 user_registration.terms_url = %terms_url,
262 ),
263 err,
264 )]
265 async fn set_terms_url(
266 &mut self,
267 mut user_registration: UserRegistration,
268 terms_url: Url,
269 ) -> Result<UserRegistration, Self::Error> {
270 let res = sqlx::query!(
271 r#"
272 UPDATE user_registrations
273 SET terms_url = $2
274 WHERE user_registration_id = $1 AND completed_at IS NULL
275 "#,
276 Uuid::from(user_registration.id),
277 terms_url.as_str(),
278 )
279 .traced()
280 .execute(&mut *self.conn)
281 .await?;
282
283 DatabaseError::ensure_affected_rows(&res, 1)?;
284
285 user_registration.terms_url = Some(terms_url);
286
287 Ok(user_registration)
288 }
289
290 #[tracing::instrument(
291 name = "db.user_registration.set_email_authentication",
292 skip_all,
293 fields(
294 db.query.text,
295 %user_registration.id,
296 %user_email_authentication.id,
297 %user_email_authentication.email,
298 ),
299 err,
300 )]
301 async fn set_email_authentication(
302 &mut self,
303 mut user_registration: UserRegistration,
304 user_email_authentication: &UserEmailAuthentication,
305 ) -> Result<UserRegistration, Self::Error> {
306 let res = sqlx::query!(
307 r#"
308 UPDATE user_registrations
309 SET email_authentication_id = $2
310 WHERE user_registration_id = $1 AND completed_at IS NULL
311 "#,
312 Uuid::from(user_registration.id),
313 Uuid::from(user_email_authentication.id),
314 )
315 .traced()
316 .execute(&mut *self.conn)
317 .await?;
318
319 DatabaseError::ensure_affected_rows(&res, 1)?;
320
321 user_registration.email_authentication_id = Some(user_email_authentication.id);
322
323 Ok(user_registration)
324 }
325
326 #[tracing::instrument(
327 name = "db.user_registration.set_password",
328 skip_all,
329 fields(
330 db.query.text,
331 user_registration.id = %user_registration.id,
332 user_registration.hashed_password = hashed_password,
333 user_registration.hashed_password_version = version,
334 ),
335 err,
336 )]
337 async fn set_password(
338 &mut self,
339 mut user_registration: UserRegistration,
340 hashed_password: String,
341 version: u16,
342 ) -> Result<UserRegistration, Self::Error> {
343 let res = sqlx::query!(
344 r#"
345 UPDATE user_registrations
346 SET hashed_password = $2, hashed_password_version = $3
347 WHERE user_registration_id = $1 AND completed_at IS NULL
348 "#,
349 Uuid::from(user_registration.id),
350 hashed_password,
351 i32::from(version),
352 )
353 .traced()
354 .execute(&mut *self.conn)
355 .await?;
356
357 DatabaseError::ensure_affected_rows(&res, 1)?;
358
359 user_registration.password = Some(UserRegistrationPassword {
360 hashed_password,
361 version,
362 });
363
364 Ok(user_registration)
365 }
366
367 #[tracing::instrument(
368 name = "db.user_registration.set_registration_token",
369 skip_all,
370 fields(
371 db.query.text,
372 %user_registration.id,
373 %user_registration_token.id,
374 ),
375 err,
376 )]
377 async fn set_registration_token(
378 &mut self,
379 mut user_registration: UserRegistration,
380 user_registration_token: &UserRegistrationToken,
381 ) -> Result<UserRegistration, Self::Error> {
382 let res = sqlx::query!(
383 r#"
384 UPDATE user_registrations
385 SET user_registration_token_id = $2
386 WHERE user_registration_id = $1 AND completed_at IS NULL
387 "#,
388 Uuid::from(user_registration.id),
389 Uuid::from(user_registration_token.id),
390 )
391 .traced()
392 .execute(&mut *self.conn)
393 .await?;
394
395 DatabaseError::ensure_affected_rows(&res, 1)?;
396
397 user_registration.user_registration_token_id = Some(user_registration_token.id);
398
399 Ok(user_registration)
400 }
401
402 #[tracing::instrument(
403 name = "db.user_registration.set_upstream_oauth_authorization_session",
404 skip_all,
405 fields(
406 db.query.text,
407 %user_registration.id,
408 %upstream_oauth_authorization_session.id,
409 ),
410 err,
411 )]
412 async fn set_upstream_oauth_authorization_session(
413 &mut self,
414 mut user_registration: UserRegistration,
415 upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession,
416 ) -> Result<UserRegistration, Self::Error> {
417 let res = sqlx::query!(
418 r#"
419 UPDATE user_registrations
420 SET upstream_oauth_authorization_session_id = $2
421 WHERE user_registration_id = $1 AND completed_at IS NULL
422 "#,
423 Uuid::from(user_registration.id),
424 Uuid::from(upstream_oauth_authorization_session.id),
425 )
426 .traced()
427 .execute(&mut *self.conn)
428 .await?;
429
430 DatabaseError::ensure_affected_rows(&res, 1)?;
431
432 user_registration.upstream_oauth_authorization_session_id =
433 Some(upstream_oauth_authorization_session.id);
434
435 Ok(user_registration)
436 }
437
438 #[tracing::instrument(
439 name = "db.user_registration.complete",
440 skip_all,
441 fields(
442 db.query.text,
443 user_registration.id = %user_registration.id,
444 ),
445 err,
446 )]
447 async fn complete(
448 &mut self,
449 clock: &dyn Clock,
450 mut user_registration: UserRegistration,
451 ) -> Result<UserRegistration, Self::Error> {
452 let completed_at = clock.now();
453 let res = sqlx::query!(
454 r#"
455 UPDATE user_registrations
456 SET completed_at = $2
457 WHERE user_registration_id = $1 AND completed_at IS NULL
458 "#,
459 Uuid::from(user_registration.id),
460 completed_at,
461 )
462 .traced()
463 .execute(&mut *self.conn)
464 .await?;
465
466 DatabaseError::ensure_affected_rows(&res, 1)?;
467
468 user_registration.completed_at = Some(completed_at);
469
470 Ok(user_registration)
471 }
472}
473
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::{
        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
        UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
        UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams;
    use oauth2_types::scope::Scope;
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// A freshly created registration round-trips through `lookup`, can be
    /// completed exactly once, and rejects a second completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.created_at, clock.now());
        assert_eq!(registration.completed_at, None);
        assert_eq!(registration.username, "alice");
        assert_eq!(registration.display_name, None);
        assert_eq!(registration.terms_url, None);
        assert_eq!(registration.email_authentication_id, None);
        assert_eq!(registration.user_registration_token_id, None);
        assert_eq!(registration.password, None);
        assert_eq!(registration.upstream_oauth_authorization_session_id, None);
        assert_eq!(registration.user_agent, None);
        assert_eq!(registration.ip_address, None);
        assert_eq!(registration.post_auth_action, None);

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.id, registration.id);
        assert_eq!(lookup.created_at, registration.created_at);
        assert_eq!(lookup.completed_at, registration.completed_at);
        assert_eq!(lookup.username, registration.username);
        assert_eq!(lookup.display_name, registration.display_name);
        assert_eq!(lookup.terms_url, registration.terms_url);
        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );
        assert_eq!(
            lookup.user_registration_token_id,
            registration.user_registration_token_id
        );
        assert_eq!(lookup.password, registration.password);
        assert_eq!(
            lookup.upstream_oauth_authorization_session_id,
            registration.upstream_oauth_authorization_session_id
        );
        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);

        // Completing the registration should work
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();
        assert_eq!(registration.completed_at, Some(clock.now()));

        // The completion must be visible through a fresh lookup
        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.completed_at, registration.completed_at);

        // Completing an already-completed registration must fail
        let res = repo
            .user_registration()
            .complete(&clock, registration)
            .await;
        assert!(res.is_err());
    }

    /// User agent, IP address and post-auth action are persisted and read back.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                Some("Mozilla/5.0".to_owned()),
                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
            )
            .await
            .unwrap();

        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
        assert_eq!(
            registration.ip_address,
            Some(IpAddr::V4(Ipv4Addr::LOCALHOST))
        );
        assert_eq!(
            registration.post_auth_action,
            Some(
                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
            )
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
    }

    /// The display name can be set and overwritten until completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.display_name, None);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Alice".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Alice".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Setting it again should overwrite the previous value
        let registration = repo
            .user_registration()
            .set_display_name(registration, "Bob".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Bob".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_display_name(registration, "Charlie".to_owned())
            .await;
        assert!(res.is_err());
    }

    /// The terms URL can be set and overwritten until completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.terms_url, None);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Setting it again should overwrite the previous value
        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms2".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(res.is_err());
    }

    /// An email authentication can be attached (repeatedly) until completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &registration,
            )
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Setting the same authentication again should be accepted
        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await;
        assert!(res.is_err());
    }

    /// The password (hash + version) can be set and overwritten until
    /// completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.password, None);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword".to_owned(), 1)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword".to_owned(),
                version: 1,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Setting it again should overwrite both the hash and the version
        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword2".to_owned(),
                version: 2,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(res.is_err());
    }

    /// An upstream OAuth authorization session can be attached (repeatedly)
    /// until completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_upstream_oauth_link(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.upstream_oauth_authorization_session_id, None);

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                UpstreamOAuthProviderParams {
                    issuer: Some("https://example.com/".to_owned()),
                    human_name: Some("Example Ltd.".to_owned()),
                    brand_name: None,
                    scope: Scope::from_iter([oauth2_types::scope::OPENID]),
                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                    token_endpoint_signing_alg: None,
                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                    client_id: "client".to_owned(),
                    encrypted_client_secret: None,
                    claims_imports: UpstreamOAuthProviderClaimsImports::default(),
                    authorization_endpoint_override: None,
                    token_endpoint_override: None,
                    userinfo_endpoint_override: None,
                    fetch_userinfo: false,
                    userinfo_signed_response_alg: None,
                    jwks_uri_override: None,
                    discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                    pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
                    response_mode: None,
                    additional_authorization_parameters: Vec::new(),
                    forward_login_hint: false,
                    ui_order: 0,
                    on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
                },
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .add(&mut rng, &clock, &provider, "state".to_owned(), None, None)
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_upstream_oauth_authorization_session(registration, &session)
            .await
            .unwrap();

        assert_eq!(
            registration.upstream_oauth_authorization_session_id,
            Some(session.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.upstream_oauth_authorization_session_id,
            registration.upstream_oauth_authorization_session_id
        );

        // Setting the same session again should be accepted
        let registration = repo
            .user_registration()
            .set_upstream_oauth_authorization_session(registration, &session)
            .await
            .unwrap();

        assert_eq!(
            registration.upstream_oauth_authorization_session_id,
            Some(session.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.upstream_oauth_authorization_session_id,
            registration.upstream_oauth_authorization_session_id
        );

        // Once completed, the registration can no longer be modified
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_upstream_oauth_authorization_session(registration, &session)
            .await;
        assert!(res.is_err());
    }
}