1use std::sync::{Arc, LazyLock};
8
9use axum::{
10 Form,
11 extract::{Path, State},
12 response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17 FancyError, SessionInfoExt,
18 cookies::CookieJar,
19 csrf::{CsrfExt, ProtectedForm},
20 record_error,
21};
22use mas_jose::jwt::Jwt;
23use mas_matrix::HomeserverConnection;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27 BoxClock, BoxRepository, BoxRng, RepositoryAccess,
28 queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
29 upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
30 user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
31};
32use mas_templates::{
33 AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
34 ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
35};
36use minijinja::Environment;
37use opentelemetry::{Key, KeyValue, metrics::Counter};
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tracing::warn;
41use ulid::Ulid;
42
43use super::{
44 UpstreamSessionsCookie,
45 template::{AttributeMappingContext, environment},
46};
47use crate::{
48 BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49 views::shared::OptionalPostAuthAction,
50};
51
52static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
53 METER
54 .u64_counter("mas.upstream_oauth2.login")
55 .with_description("Successful upstream OAuth 2.0 login to existing accounts")
56 .with_unit("{login}")
57 .build()
58});
59static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
60 METER
61 .u64_counter("mas.upstream_oauth2.registration")
62 .with_description("Successful upstream OAuth 2.0 registration")
63 .with_unit("{registration}")
64 .build()
65});
66const PROVIDER: Key = Key::from_static_str("provider");
67
68const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
69const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
70const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
71
/// Errors returned by the upstream OAuth 2.0 link handlers (`get` and `post`).
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// The link ID from the URL does not match any upstream OAuth link.
    #[error("Link not found")]
    LinkNotFound,

    /// The upstream session referenced by the cookie does not exist, or it is
    /// not attached to the link from the URL.
    #[error("Session {0} not found")]
    SessionNotFound(Ulid),

    /// The user the link is associated with could not be found.
    #[error("User {0} not found")]
    UserNotFound(Ulid),

    /// The upstream provider of the link could not be found.
    #[error("Upstream provider {0} not found")]
    ProviderNotFound(Ulid),

    /// The template of a required attribute rendered to an empty string.
    #[error("Template {template:?} rendered to an empty string")]
    RequiredAttributeEmpty { template: String },

    /// The template of a required attribute failed to render against the
    /// upstream provider's claims.
    #[error(
        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
    )]
    RequiredAttributeRender {
        template: String,

        #[source]
        source: minijinja::Error,
    },

    /// The upstream session has already been consumed.
    #[error("Session {0} already consumed")]
    SessionConsumed(Ulid),

    /// The upstream sessions cookie has no entry for the requested link.
    #[error("Missing session cookie")]
    MissingCookie,

    /// The submitted form action does not fit the current state (logged-in
    /// user / link association).
    #[error("Invalid form action")]
    InvalidFormAction,

    /// Querying the homeserver (e.g. for localpart availability) failed.
    #[error("Homeserver connection error")]
    HomeserverConnection(#[source] anyhow::Error),

    /// Any other error; `Display`/`source` are forwarded to the wrapped error.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}

// `From` conversions so these errors can be propagated with `?`.
// NOTE(review): presumably they end up in `RouteError::Internal` — confirm
// against the `impl_from_error_for_route!` macro definition.
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
128
129impl IntoResponse for RouteError {
130 fn into_response(self) -> axum::response::Response {
131 let sentry_event_id = record_error!(
132 self,
133 Self::Internal(_)
134 | Self::RequiredAttributeEmpty { .. }
135 | Self::RequiredAttributeRender { .. }
136 | Self::SessionNotFound(_)
137 | Self::ProviderNotFound(_)
138 | Self::UserNotFound(_)
139 | Self::HomeserverConnection(_)
140 );
141 let response = match self {
142 Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
143 Self::Internal(e) => FancyError::from(e).into_response(),
144 e => FancyError::from(e).into_response(),
145 };
146
147 (sentry_event_id, response).into_response()
148 }
149}
150
151fn render_attribute_template(
164 environment: &Environment,
165 template: &str,
166 context: &minijinja::Value,
167 required: bool,
168) -> Result<Option<String>, RouteError> {
169 match environment.render_str(template, context) {
170 Ok(value) if value.is_empty() => {
171 if required {
172 return Err(RouteError::RequiredAttributeEmpty {
173 template: template.to_owned(),
174 });
175 }
176
177 Ok(None)
178 }
179
180 Ok(value) => Ok(Some(value)),
181
182 Err(source) => {
183 if required {
184 return Err(RouteError::RequiredAttributeRender {
185 template: template.to_owned(),
186 source,
187 });
188 }
189
190 tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
191 Ok(None)
192 }
193 }
194}
195
/// Form submitted from the link page. The `action` field selects the variant
/// (`register` or `link`).
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
    /// Create a new account from the upstream identity and link it.
    Register {
        /// Requested username; replaced by the rendered localpart template when
        /// the provider forces the localpart.
        #[serde(default)]
        username: Option<String>,
        /// Checkbox: only its presence is checked (`is_some()`) to decide
        /// whether to import the email.
        #[serde(default)]
        import_email: Option<String>,
        /// Checkbox: only its presence is checked to decide whether to import
        /// the display name.
        #[serde(default)]
        import_display_name: Option<String>,
        /// Checkbox: only its presence is checked; required when the site has
        /// terms of service configured.
        #[serde(default)]
        accept_terms: Option<String>,
    },
    /// Link the upstream account to the currently logged-in user.
    Link,
}

// Lets the registration form be re-rendered with per-field errors.
impl ToFormState for FormData {
    type Field = mas_templates::UpstreamRegisterFormField;
}
215
216#[tracing::instrument(
217 name = "handlers.upstream_oauth2.link.get",
218 fields(upstream_oauth_link.id = %link_id),
219 skip_all,
220)]
221pub(crate) async fn get(
222 mut rng: BoxRng,
223 clock: BoxClock,
224 mut repo: BoxRepository,
225 mut policy: Policy,
226 PreferredLanguage(locale): PreferredLanguage,
227 State(templates): State<Templates>,
228 State(url_builder): State<UrlBuilder>,
229 State(homeserver): State<Arc<dyn HomeserverConnection>>,
230 cookie_jar: CookieJar,
231 activity_tracker: BoundActivityTracker,
232 user_agent: Option<TypedHeader<headers::UserAgent>>,
233 Path(link_id): Path<Ulid>,
234) -> Result<impl IntoResponse, RouteError> {
235 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
236 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
237 let (session_id, post_auth_action) = sessions_cookie
238 .lookup_link(link_id)
239 .map_err(|_| RouteError::MissingCookie)?;
240
241 let post_auth_action = OptionalPostAuthAction {
242 post_auth_action: post_auth_action.cloned(),
243 };
244
245 let link = repo
246 .upstream_oauth_link()
247 .lookup(link_id)
248 .await?
249 .ok_or(RouteError::LinkNotFound)?;
250
251 let upstream_session = repo
252 .upstream_oauth_session()
253 .lookup(session_id)
254 .await?
255 .ok_or(RouteError::SessionNotFound(session_id))?;
256
257 if upstream_session.link_id() != Some(link.id) {
260 return Err(RouteError::SessionNotFound(session_id));
261 }
262
263 if upstream_session.is_consumed() {
264 return Err(RouteError::SessionConsumed(session_id));
265 }
266
267 let (user_session_info, cookie_jar) = cookie_jar.session_info();
268 let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
269 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
270
271 let response = match (maybe_user_session, link.user_id) {
272 (Some(session), Some(user_id)) if session.user.id == user_id => {
273 let upstream_session = repo
276 .upstream_oauth_session()
277 .consume(&clock, upstream_session)
278 .await?;
279
280 repo.browser_session()
281 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
282 .await?;
283
284 cookie_jar = cookie_jar.set_session(&session);
285
286 repo.save().await?;
287
288 post_auth_action.go_next(&url_builder).into_response()
289 }
290
291 (Some(user_session), Some(user_id)) => {
292 let user = repo
296 .user()
297 .lookup(user_id)
298 .await?
299 .ok_or(RouteError::UserNotFound(user_id))?;
300
301 let ctx = UpstreamExistingLinkContext::new(user)
302 .with_session(user_session)
303 .with_csrf(csrf_token.form_value())
304 .with_language(locale);
305
306 Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
307 }
308
309 (Some(user_session), None) => {
310 let ctx = UpstreamSuggestLink::new(&link)
312 .with_session(user_session)
313 .with_csrf(csrf_token.form_value())
314 .with_language(locale);
315
316 Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
317 }
318
319 (None, Some(user_id)) => {
320 let user = repo
322 .user()
323 .lookup(user_id)
324 .await?
325 .ok_or(RouteError::UserNotFound(user_id))?;
326
327 if user.deactivated_at.is_some() {
329 let ctx = AccountInactiveContext::new(user)
331 .with_csrf(csrf_token.form_value())
332 .with_language(locale);
333 let fallback = templates.render_account_deactivated(&ctx)?;
334 return Ok((cookie_jar, Html(fallback).into_response()));
335 }
336
337 if user.locked_at.is_some() {
338 let ctx = AccountInactiveContext::new(user)
340 .with_csrf(csrf_token.form_value())
341 .with_language(locale);
342 let fallback = templates.render_account_locked(&ctx)?;
343 return Ok((cookie_jar, Html(fallback).into_response()));
344 }
345
346 let session = repo
347 .browser_session()
348 .add(&mut rng, &clock, &user, user_agent)
349 .await?;
350
351 let upstream_session = repo
352 .upstream_oauth_session()
353 .consume(&clock, upstream_session)
354 .await?;
355
356 repo.browser_session()
357 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
358 .await?;
359
360 cookie_jar = sessions_cookie
361 .consume_link(link_id)?
362 .save(cookie_jar, &clock);
363 cookie_jar = cookie_jar.set_session(&session);
364
365 repo.save().await?;
366
367 LOGIN_COUNTER.add(
368 1,
369 &[KeyValue::new(
370 PROVIDER,
371 upstream_session.provider_id.to_string(),
372 )],
373 );
374
375 post_auth_action.go_next(&url_builder).into_response()
376 }
377
378 (None, None) => {
379 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
382
383 let provider = repo
384 .upstream_oauth_provider()
385 .lookup(link.provider_id)
386 .await?
387 .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
388
389 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
390
391 let env = environment();
392
393 let mut context = AttributeMappingContext::new();
394 if let Some(id_token) = id_token {
395 let (_, payload) = id_token.into_parts();
396 context = context.with_id_token_claims(payload);
397 }
398 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
399 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
400 }
401 if let Some(userinfo) = upstream_session.userinfo() {
402 context = context.with_userinfo_claims(userinfo.clone());
403 }
404 let context = context.build();
405
406 let ctx = if provider.claims_imports.displayname.ignore() {
407 ctx
408 } else {
409 let template = provider
410 .claims_imports
411 .displayname
412 .template
413 .as_deref()
414 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
415
416 match render_attribute_template(
417 &env,
418 template,
419 &context,
420 provider.claims_imports.displayname.is_required(),
421 )? {
422 Some(value) => ctx
423 .with_display_name(value, provider.claims_imports.displayname.is_forced()),
424 None => ctx,
425 }
426 };
427
428 let ctx = if provider.claims_imports.email.ignore() {
429 ctx
430 } else {
431 let template = provider
432 .claims_imports
433 .email
434 .template
435 .as_deref()
436 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
437
438 match render_attribute_template(
439 &env,
440 template,
441 &context,
442 provider.claims_imports.email.is_required(),
443 )? {
444 Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
445 None => ctx,
446 }
447 };
448
449 let ctx = if provider.claims_imports.localpart.ignore() {
450 ctx
451 } else {
452 let template = provider
453 .claims_imports
454 .localpart
455 .template
456 .as_deref()
457 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
458
459 match render_attribute_template(
460 &env,
461 template,
462 &context,
463 provider.claims_imports.localpart.is_required(),
464 )? {
465 Some(localpart) => {
466 let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
470 let is_available = homeserver
471 .is_localpart_available(&localpart)
472 .await
473 .map_err(RouteError::HomeserverConnection)?;
474
475 if maybe_existing_user.is_some() || !is_available {
476 if let Some(existing_user) = maybe_existing_user {
477 warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
480 }
481
482 let ctx = ErrorContext::new()
484 .with_code("User exists")
485 .with_description(format!(
486 r"Upstream account provider returned {localpart:?} as username,
487 which is not linked to that upstream account"
488 ))
489 .with_language(&locale);
490
491 return Ok((
492 cookie_jar,
493 Html(templates.render_error(&ctx)?).into_response(),
494 ));
495 }
496
497 let res = policy
498 .evaluate_register(mas_policy::RegisterInput {
499 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
500 username: &localpart,
501 email: None,
502 requester: mas_policy::Requester {
503 ip_address: activity_tracker.ip(),
504 user_agent: user_agent.clone(),
505 },
506 })
507 .await?;
508
509 if res.valid() {
510 ctx.with_localpart(
512 localpart,
513 provider.claims_imports.localpart.is_forced(),
514 )
515 } else if provider.claims_imports.localpart.is_forced() {
516 let ctx = ErrorContext::new()
520 .with_code("Policy error")
521 .with_description(format!(
522 r"Upstream account provider returned {localpart:?} as username,
523 which does not pass the policy check: {res}"
524 ))
525 .with_language(&locale);
526
527 return Ok((
528 cookie_jar,
529 Html(templates.render_error(&ctx)?).into_response(),
530 ));
531 } else {
532 ctx
534 }
535 }
536 None => ctx,
537 }
538 };
539
540 let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
541
542 Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
543 }
544 };
545
546 Ok((cookie_jar, response))
547}
548
549#[tracing::instrument(
550 name = "handlers.upstream_oauth2.link.post",
551 fields(upstream_oauth_link.id = %link_id),
552 skip_all,
553)]
554pub(crate) async fn post(
555 mut rng: BoxRng,
556 clock: BoxClock,
557 mut repo: BoxRepository,
558 cookie_jar: CookieJar,
559 user_agent: Option<TypedHeader<headers::UserAgent>>,
560 mut policy: Policy,
561 PreferredLanguage(locale): PreferredLanguage,
562 activity_tracker: BoundActivityTracker,
563 State(templates): State<Templates>,
564 State(homeserver): State<Arc<dyn HomeserverConnection>>,
565 State(url_builder): State<UrlBuilder>,
566 State(site_config): State<SiteConfig>,
567 Path(link_id): Path<Ulid>,
568 Form(form): Form<ProtectedForm<FormData>>,
569) -> Result<Response, RouteError> {
570 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
571 let form = cookie_jar.verify_form(&clock, form)?;
572
573 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
574 let (session_id, post_auth_action) = sessions_cookie
575 .lookup_link(link_id)
576 .map_err(|_| RouteError::MissingCookie)?;
577
578 let post_auth_action = OptionalPostAuthAction {
579 post_auth_action: post_auth_action.cloned(),
580 };
581
582 let link = repo
583 .upstream_oauth_link()
584 .lookup(link_id)
585 .await?
586 .ok_or(RouteError::LinkNotFound)?;
587
588 let upstream_session = repo
589 .upstream_oauth_session()
590 .lookup(session_id)
591 .await?
592 .ok_or(RouteError::SessionNotFound(session_id))?;
593
594 if upstream_session.link_id() != Some(link.id) {
597 return Err(RouteError::SessionNotFound(session_id));
598 }
599
600 if upstream_session.is_consumed() {
601 return Err(RouteError::SessionConsumed(session_id));
602 }
603
604 let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
605 let (user_session_info, cookie_jar) = cookie_jar.session_info();
606 let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
607 let form_state = form.to_form_state();
608
609 let session = match (maybe_user_session, link.user_id, form) {
610 (Some(session), None, FormData::Link) => {
611 repo.upstream_oauth_link()
614 .associate_to_user(&link, &session.user)
615 .await?;
616
617 session
618 }
619
620 (
621 None,
622 None,
623 FormData::Register {
624 username,
625 import_email,
626 import_display_name,
627 accept_terms,
628 },
629 ) => {
630 let import_email = import_email.is_some();
637 let import_display_name = import_display_name.is_some();
638 let accept_terms = accept_terms.is_some();
639
640 let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
641
642 let provider = repo
643 .upstream_oauth_provider()
644 .lookup(link.provider_id)
645 .await?
646 .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
647
648 let env = environment();
650
651 let mut context = AttributeMappingContext::new();
652 if let Some(id_token) = id_token {
653 let (_, payload) = id_token.into_parts();
654 context = context.with_id_token_claims(payload);
655 }
656 if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
657 context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
658 }
659 if let Some(userinfo) = upstream_session.userinfo() {
660 context = context.with_userinfo_claims(userinfo.clone());
661 }
662 let context = context.build();
663
664 let ctx = UpstreamRegister::new(link.clone(), provider.clone());
666
667 let display_name = if provider
668 .claims_imports
669 .displayname
670 .should_import(import_display_name)
671 {
672 let template = provider
673 .claims_imports
674 .displayname
675 .template
676 .as_deref()
677 .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
678
679 render_attribute_template(
680 &env,
681 template,
682 &context,
683 provider.claims_imports.displayname.is_required(),
684 )?
685 } else {
686 None
687 };
688
689 let ctx = if let Some(ref display_name) = display_name {
690 ctx.with_display_name(
691 display_name.clone(),
692 provider.claims_imports.email.is_forced(),
693 )
694 } else {
695 ctx
696 };
697
698 let email = if provider.claims_imports.email.should_import(import_email) {
699 let template = provider
700 .claims_imports
701 .email
702 .template
703 .as_deref()
704 .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
705
706 render_attribute_template(
707 &env,
708 template,
709 &context,
710 provider.claims_imports.email.is_required(),
711 )?
712 } else {
713 None
714 };
715
716 let ctx = if let Some(ref email) = email {
717 ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
718 } else {
719 ctx
720 };
721
722 let username = if provider.claims_imports.localpart.is_forced() {
723 let template = provider
724 .claims_imports
725 .localpart
726 .template
727 .as_deref()
728 .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
729
730 render_attribute_template(&env, template, &context, true)?
731 } else {
732 username
734 }
735 .unwrap_or_default();
736
737 let ctx = ctx.with_localpart(
738 username.clone(),
739 provider.claims_imports.localpart.is_forced(),
740 );
741
742 let form_state = {
744 let mut form_state = form_state;
745 let mut homeserver_denied_username = false;
746 if username.is_empty() {
747 form_state.add_error_on_field(
748 mas_templates::UpstreamRegisterFormField::Username,
749 FieldError::Required,
750 );
751 } else if repo.user().exists(&username).await? {
752 form_state.add_error_on_field(
753 mas_templates::UpstreamRegisterFormField::Username,
754 FieldError::Exists,
755 );
756 } else if !homeserver
757 .is_localpart_available(&username)
758 .await
759 .map_err(RouteError::HomeserverConnection)?
760 {
761 tracing::warn!(
763 %username,
764 "Homeserver denied username provided by user"
765 );
766
767 homeserver_denied_username = true;
770 }
771
772 if site_config.tos_uri.is_some() && !accept_terms {
774 form_state.add_error_on_field(
775 mas_templates::UpstreamRegisterFormField::AcceptTerms,
776 FieldError::Required,
777 );
778 }
779
780 let res = policy
782 .evaluate_register(mas_policy::RegisterInput {
783 registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
784 username: &username,
785 email: email.as_deref(),
786 requester: mas_policy::Requester {
787 ip_address: activity_tracker.ip(),
788 user_agent: user_agent.clone(),
789 },
790 })
791 .await?;
792
793 for violation in res.violations {
794 match violation.field.as_deref() {
795 Some("username") => {
796 homeserver_denied_username = false;
800 form_state.add_error_on_field(
801 mas_templates::UpstreamRegisterFormField::Username,
802 FieldError::Policy {
803 code: violation.code.map(|c| c.as_str()),
804 message: violation.msg,
805 },
806 );
807 }
808 _ => form_state.add_error_on_form(FormError::Policy {
809 code: violation.code.map(|c| c.as_str()),
810 message: violation.msg,
811 }),
812 }
813 }
814
815 if homeserver_denied_username {
816 form_state.add_error_on_field(
818 mas_templates::UpstreamRegisterFormField::Username,
819 FieldError::Exists,
820 );
821 }
822
823 form_state
824 };
825
826 if !form_state.is_valid() {
827 let ctx = ctx
828 .with_form_state(form_state)
829 .with_csrf(csrf_token.form_value())
830 .with_language(locale);
831
832 return Ok((
833 cookie_jar,
834 Html(templates.render_upstream_oauth2_do_register(&ctx)?),
835 )
836 .into_response());
837 }
838
839 REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
840
841 let user = repo.user().add(&mut rng, &clock, username).await?;
843
844 if let Some(terms_url) = &site_config.tos_uri {
845 repo.user_terms()
846 .accept_terms(&mut rng, &clock, &user, terms_url.clone())
847 .await?;
848 }
849
850 let mut job = ProvisionUserJob::new(&user);
852
853 if let Some(name) = display_name {
855 job = job.set_display_name(name);
856 }
857
858 repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
859
860 if let Some(email) = email {
862 repo.user_email()
863 .add(&mut rng, &clock, &user, email)
864 .await?;
865 }
866
867 repo.upstream_oauth_link()
868 .associate_to_user(&link, &user)
869 .await?;
870
871 repo.browser_session()
872 .add(&mut rng, &clock, &user, user_agent)
873 .await?
874 }
875
876 _ => return Err(RouteError::InvalidFormAction),
877 };
878
879 let upstream_session = repo
880 .upstream_oauth_session()
881 .consume(&clock, upstream_session)
882 .await?;
883
884 repo.browser_session()
885 .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
886 .await?;
887
888 let cookie_jar = sessions_cookie
889 .consume_link(link_id)?
890 .save(cookie_jar, &clock);
891 let cookie_jar = cookie_jar.set_session(&session);
892
893 repo.save().await?;
894
895 Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
896}
897
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// End-to-end registration through an upstream provider that forces both
    /// the localpart and the email import. It checks that the user is created
    /// with the username from the id_token, that the link is associated with
    /// them and that the email is added.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Force both localpart and email imports, using the default templates.
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: UpstreamOAuthProviderImportPreference {
                action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
                template: None,
            },
            email: UpstreamOAuthProviderImportPreference {
                action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
                template: None,
            },
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        // Claims consumed by the default localpart and email templates.
        let id_token = serde_json::json!({
            "preferred_username": "john",
            "email": "john@example.com",
            "email_verified": true,
        });

        // Sign the id_token with a key from the test key store.
        let key = state
            .key_store
            .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256)
            .unwrap();

        let signer = key
            .params()
            .signing_key_for_alg(&JsonWebSignatureAlg::Rs256)
            .unwrap();
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256);
        let id_token = Jwt::sign_with_rng(&mut rng, header, id_token, &signer).unwrap();

        // Set up the provider, an upstream session, and a link not yet
        // associated with any user.
        let mut repo = state.repository().await.unwrap();
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                UpstreamOAuthProviderParams {
                    issuer: Some("https://example.com/".to_owned()),
                    human_name: Some("Example Ltd.".to_owned()),
                    brand_name: None,
                    scope: Scope::from_iter([OPENID]),
                    token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
                    token_endpoint_signing_alg: None,
                    id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
                    client_id: "client".to_owned(),
                    encrypted_client_secret: None,
                    claims_imports,
                    authorization_endpoint_override: None,
                    token_endpoint_override: None,
                    userinfo_endpoint_override: None,
                    fetch_userinfo: false,
                    userinfo_signed_response_alg: None,
                    jwks_uri_override: None,
                    discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
                    pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
                    response_mode: None,
                    additional_authorization_parameters: Vec::new(),
                    ui_order: 0,
                },
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "state".to_owned(),
                None,
                "nonce".to_owned(),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                "subject".to_owned(),
                None,
            )
            .await
            .unwrap();

        // Complete the session with the signed id_token, attaching it to the link.
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                session,
                &link,
                Some(id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Put the session -> link mapping in the upstream sessions cookie, as the
        // callback handler would.
        let cookie_jar = state.cookie_jar();
        let upstream_sessions = UpstreamSessionsCookie::default()
            .add(session.id, provider.id, "state".to_owned(), None)
            .add_link_to_session(session.id, link.id)
            .unwrap();
        let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock);
        cookies.import(cookie_jar);

        // GET renders the registration form (not logged in, link not associated).
        let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty();
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::OK);
        response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        // Scrape the CSRF token out of the rendered form.
        let csrf_token = response
            .body()
            .split("name=\"csrf\" value=\"")
            .nth(1)
            .unwrap()
            .split('\"')
            .next()
            .unwrap();

        // Submit the registration; success redirects to the next step.
        let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form(
            serde_json::json!({
                "csrf": csrf_token,
                "action": "register",
                "import_email": "on",
                "accept_terms": "on",
            }),
        );
        let request = cookies.with_cookies(request);
        let response = state.request(request).await;
        cookies.save_cookies(&response);
        response.assert_status(StatusCode::SEE_OTHER);

        // The user was created with the forced localpart from the id_token.
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .find_by_username("john")
            .await
            .unwrap()
            .expect("user exists");

        // The link is now associated with that user.
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "subject")
            .await
            .unwrap()
            .expect("link exists");

        assert_eq!(link.user_id, Some(user.id));

        // The email from the id_token was added to the user.
        let page = repo
            .user_email()
            .list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
            .await
            .unwrap();
        let email = page.edges.first().expect("email exists");

        assert_eq!(email.email, "john@example.com");
    }
}