mas_handlers/upstream_oauth2/
link.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::{Arc, LazyLock};
8
9use axum::{
10    Form,
11    extract::{Path, State},
12    response::{Html, IntoResponse, Response},
13};
14use axum_extra::typed_header::TypedHeader;
15use hyper::StatusCode;
16use mas_axum_utils::{
17    FancyError, SessionInfoExt,
18    cookies::CookieJar,
19    csrf::{CsrfExt, ProtectedForm},
20    record_error,
21};
22use mas_jose::jwt::Jwt;
23use mas_matrix::HomeserverConnection;
24use mas_policy::Policy;
25use mas_router::UrlBuilder;
26use mas_storage::{
27    BoxClock, BoxRepository, BoxRng, RepositoryAccess,
28    queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
29    upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
30    user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
31};
32use mas_templates::{
33    AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
34    ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
35};
36use minijinja::Environment;
37use opentelemetry::{Key, KeyValue, metrics::Counter};
38use serde::{Deserialize, Serialize};
39use thiserror::Error;
40use tracing::warn;
41use ulid::Ulid;
42
43use super::{
44    UpstreamSessionsCookie,
45    template::{AttributeMappingContext, environment},
46};
47use crate::{
48    BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49    views::shared::OptionalPostAuthAction,
50};
51
/// Counter of successful upstream OAuth 2.0 logins to existing accounts.
///
/// Built lazily from the crate-wide [`METER`] on first use. It is incremented
/// by the `get` handler with a [`PROVIDER`] attribute once a login through an
/// already-linked upstream account completes.
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.upstream_oauth2.login")
        .with_description("Successful upstream OAuth 2.0 login to existing accounts")
        .with_unit("{login}")
        .build()
});
/// Counter of successful registrations through an upstream OAuth 2.0
/// provider.
///
/// Built lazily from the crate-wide [`METER`] on first use. It is incremented
/// by the `post` handler with a [`PROVIDER`] attribute.
static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("mas.upstream_oauth2.registration")
        .with_description("Successful upstream OAuth 2.0 registration")
        .with_unit("{registration}")
        .build()
});
/// Attribute key under which the upstream provider ID is attached to the
/// counters above.
const PROVIDER: Key = Key::from_static_str("provider");
67
// Fallback attribute templates, used when the provider's `claims_imports`
// entry for the corresponding attribute has no `template` set.

/// Default template used to derive the localpart (username).
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
/// Default template used to derive the display name.
const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}";
/// Default template used to derive the email address.
const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}";
71
/// Errors which can be returned by the upstream OAuth 2.0 link handlers.
///
/// See the `IntoResponse` implementation for how each variant is turned into
/// an HTTP response.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// Couldn't find the link specified in the URL
    #[error("Link not found")]
    LinkNotFound,

    /// Couldn't find the session on the link, or the session found does not
    /// belong to the link being accessed
    #[error("Session {0} not found")]
    SessionNotFound(Ulid),

    /// Couldn't find the user
    #[error("User {0} not found")]
    UserNotFound(Ulid),

    /// Couldn't find upstream provider
    #[error("Upstream provider {0} not found")]
    ProviderNotFound(Ulid),

    /// Required attribute rendered to an empty string
    #[error("Template {template:?} rendered to an empty string")]
    RequiredAttributeEmpty { template: String },

    /// The template of a required attribute failed to render against the
    /// data returned by the upstream provider
    #[error(
        "Template {template:?} could not be rendered from the upstream provider's response for required claim"
    )]
    RequiredAttributeRender {
        /// The source of the template which failed to render
        template: String,

        /// The underlying rendering error
        #[source]
        source: minijinja::Error,
    },

    /// Session was already consumed
    #[error("Session {0} already consumed")]
    SessionConsumed(Ulid),

    /// The upstream sessions cookie has no usable entry for the requested
    /// link
    #[error("Missing session cookie")]
    MissingCookie,

    /// The submitted form action is not valid in the current state
    // NOTE(review): this variant is not constructed in the visible part of
    // this file; confirm the exact conditions in the `post` handler.
    #[error("Invalid form action")]
    InvalidFormAction,

    /// Talking to the homeserver failed, e.g. when checking whether a
    /// localpart is available
    #[error("Homeserver connection error")]
    HomeserverConnection(#[source] anyhow::Error),

    /// Any other error, displayed transparently as the wrapped error
    // NOTE(review): presumably the target of the `impl_from_error_for_route!`
    // conversions below; confirm against the macro definition.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
121
// Error types which the handlers below propagate with `?`.
// NOTE(review): the macro is defined elsewhere in the crate; it presumably
// generates a `From<E> for RouteError` implementation for each listed type
// (wrapping into `RouteError::Internal`) -- confirm against its definition.
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
128
129impl IntoResponse for RouteError {
130    fn into_response(self) -> axum::response::Response {
131        let sentry_event_id = record_error!(
132            self,
133            Self::Internal(_)
134                | Self::RequiredAttributeEmpty { .. }
135                | Self::RequiredAttributeRender { .. }
136                | Self::SessionNotFound(_)
137                | Self::ProviderNotFound(_)
138                | Self::UserNotFound(_)
139                | Self::HomeserverConnection(_)
140        );
141        let response = match self {
142            Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
143            Self::Internal(e) => FancyError::from(e).into_response(),
144            e => FancyError::from(e).into_response(),
145        };
146
147        (sentry_event_id, response).into_response()
148    }
149}
150
151/// Utility function to render an attribute template.
152///
153/// # Parameters
154///
155/// * `environment` - The minijinja environment to use to render the template
156/// * `template` - The template to use to render the claim
157/// * `required` - Whether the attribute is required or not
158///
159/// # Errors
160///
161/// Returns an error if the attribute is required but fails to render or is
162/// empty
163fn render_attribute_template(
164    environment: &Environment,
165    template: &str,
166    context: &minijinja::Value,
167    required: bool,
168) -> Result<Option<String>, RouteError> {
169    match environment.render_str(template, context) {
170        Ok(value) if value.is_empty() => {
171            if required {
172                return Err(RouteError::RequiredAttributeEmpty {
173                    template: template.to_owned(),
174                });
175            }
176
177            Ok(None)
178        }
179
180        Ok(value) => Ok(Some(value)),
181
182        Err(source) => {
183            if required {
184                return Err(RouteError::RequiredAttributeRender {
185                    template: template.to_owned(),
186                    source,
187                });
188            }
189
190            tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template");
191            Ok(None)
192        }
193    }
194}
195
/// Form submitted to the `post` handler.
///
/// The variant is selected by the `action` field of the form, whose value is
/// the lowercased variant name (`register` or `link`).
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
    /// Register a new account from the upstream account.
    ///
    /// `import_email`, `import_display_name` and `accept_terms` are
    /// checkboxes: the handler only looks at whether they are present.
    Register {
        /// Username chosen by the user; it is used only when the provider
        /// does not force the localpart
        #[serde(default)]
        username: Option<String>,
        /// Present if the user asked to import the email address
        #[serde(default)]
        import_email: Option<String>,
        /// Present if the user asked to import the display name
        #[serde(default)]
        import_display_name: Option<String>,
        /// Present if the user accepted the terms of service
        #[serde(default)]
        accept_terms: Option<String>,
    },
    /// Link the upstream account to the currently logged-in user
    Link,
}
211
// Ties `FormData` to the field identifiers of the upstream registration form;
// the `post` handler uses them to attach validation errors to individual
// fields.
impl ToFormState for FormData {
    type Field = mas_templates::UpstreamRegisterFormField;
}
215
/// Handles the display of an upstream OAuth 2.0 link.
///
/// The upstream session for `link_id` is looked up through the upstream
/// sessions cookie, and must belong to that link and must not have been
/// consumed yet. What happens next depends on whether the browser has an
/// active user session, and whether the link is already associated with a
/// user:
///
/// * logged in, link associated with the same user: the upstream session is
///   consumed, the browser session is re-authenticated, and the user is sent
///   to the post-auth action;
/// * logged in, link associated with another user: the 'link mismatch' page
///   is rendered;
/// * logged in, link not associated: the 'suggest link' page is rendered;
/// * not logged in, link associated: unless the account is deactivated or
///   locked (in which case the matching fallback page is rendered), a new
///   browser session is started for the user, who is then sent to the
///   post-auth action;
/// * not logged in, link not associated: the registration form is rendered,
///   pre-filled with the attributes derived from the upstream provider data.
///
/// # Errors
///
/// Returns a [`RouteError`] if the cookie, link, session, user or provider
/// can't be found, if the session doesn't belong to the link or was already
/// consumed, if a required attribute can't be derived, or if any of the
/// underlying repository, policy, template or homeserver calls fail.
#[tracing::instrument(
    name = "handlers.upstream_oauth2.link.get",
    fields(upstream_oauth_link.id = %link_id),
    skip_all,
)]
pub(crate) async fn get(
    mut rng: BoxRng,
    clock: BoxClock,
    mut repo: BoxRepository,
    mut policy: Policy,
    PreferredLanguage(locale): PreferredLanguage,
    State(templates): State<Templates>,
    State(url_builder): State<UrlBuilder>,
    State(homeserver): State<Arc<dyn HomeserverConnection>>,
    cookie_jar: CookieJar,
    activity_tracker: BoundActivityTracker,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    Path(link_id): Path<Ulid>,
) -> Result<impl IntoResponse, RouteError> {
    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
    // Find which upstream session (and optional post-auth action) this browser
    // has recorded for the link
    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
    let (session_id, post_auth_action) = sessions_cookie
        .lookup_link(link_id)
        .map_err(|_| RouteError::MissingCookie)?;

    let post_auth_action = OptionalPostAuthAction {
        post_auth_action: post_auth_action.cloned(),
    };

    let link = repo
        .upstream_oauth_link()
        .lookup(link_id)
        .await?
        .ok_or(RouteError::LinkNotFound)?;

    let upstream_session = repo
        .upstream_oauth_session()
        .lookup(session_id)
        .await?
        .ok_or(RouteError::SessionNotFound(session_id))?;

    // This checks that we're in a browser session which is allowed to consume this
    // link: the upstream auth session should have been started in this browser.
    if upstream_session.link_id() != Some(link.id) {
        return Err(RouteError::SessionNotFound(session_id));
    }

    if upstream_session.is_consumed() {
        return Err(RouteError::SessionConsumed(session_id));
    }

    // Extract the browser session info and a CSRF token from the cookie jar,
    // then load the active browser session, if there is one
    let (user_session_info, cookie_jar) = cookie_jar.session_info();
    let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;

    let response = match (maybe_user_session, link.user_id) {
        (Some(session), Some(user_id)) if session.user.id == user_id => {
            // Session already linked, and link matches the currently logged
            // user. Mark the session as consumed and renew the authentication.
            let upstream_session = repo
                .upstream_oauth_session()
                .consume(&clock, upstream_session)
                .await?;

            repo.browser_session()
                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
                .await?;

            cookie_jar = cookie_jar.set_session(&session);

            repo.save().await?;

            post_auth_action.go_next(&url_builder).into_response()
        }

        (Some(user_session), Some(user_id)) => {
            // Session already linked, but link doesn't match the currently
            // logged user. Suggest logging out of the current user
            // and logging in with the new one
            let user = repo
                .user()
                .lookup(user_id)
                .await?
                .ok_or(RouteError::UserNotFound(user_id))?;

            let ctx = UpstreamExistingLinkContext::new(user)
                .with_session(user_session)
                .with_csrf(csrf_token.form_value())
                .with_language(locale);

            Html(templates.render_upstream_oauth2_link_mismatch(&ctx)?).into_response()
        }

        (Some(user_session), None) => {
            // Session not linked, but user logged in: suggest linking account
            let ctx = UpstreamSuggestLink::new(&link)
                .with_session(user_session)
                .with_csrf(csrf_token.form_value())
                .with_language(locale);

            Html(templates.render_upstream_oauth2_suggest_link(&ctx)?).into_response()
        }

        (None, Some(user_id)) => {
            // Session linked, but user not logged in: do the login
            let user = repo
                .user()
                .lookup(user_id)
                .await?
                .ok_or(RouteError::UserNotFound(user_id))?;

            // Check that the user is not locked or deactivated
            if user.deactivated_at.is_some() {
                // The account is deactivated, show the 'account deactivated' fallback
                let ctx = AccountInactiveContext::new(user)
                    .with_csrf(csrf_token.form_value())
                    .with_language(locale);
                let fallback = templates.render_account_deactivated(&ctx)?;
                return Ok((cookie_jar, Html(fallback).into_response()));
            }

            if user.locked_at.is_some() {
                // The account is locked, show the 'account locked' fallback
                let ctx = AccountInactiveContext::new(user)
                    .with_csrf(csrf_token.form_value())
                    .with_language(locale);
                let fallback = templates.render_account_locked(&ctx)?;
                return Ok((cookie_jar, Html(fallback).into_response()));
            }

            // Start a new browser session for the user, authenticated by the
            // upstream session which gets consumed in the process
            let session = repo
                .browser_session()
                .add(&mut rng, &clock, &user, user_agent)
                .await?;

            let upstream_session = repo
                .upstream_oauth_session()
                .consume(&clock, upstream_session)
                .await?;

            repo.browser_session()
                .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
                .await?;

            // Record the link as consumed in the upstream sessions cookie, and
            // store the new browser session in the cookie jar
            cookie_jar = sessions_cookie
                .consume_link(link_id)?
                .save(cookie_jar, &clock);
            cookie_jar = cookie_jar.set_session(&session);

            repo.save().await?;

            // Only count the login once everything has been saved
            LOGIN_COUNTER.add(
                1,
                &[KeyValue::new(
                    PROVIDER,
                    upstream_session.provider_id.to_string(),
                )],
            );

            post_auth_action.go_next(&url_builder).into_response()
        }

        (None, None) => {
            // Session not linked and user not logged in: suggest creating an
            // account or logging in an existing user
            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;

            let provider = repo
                .upstream_oauth_provider()
                .lookup(link.provider_id)
                .await?
                .ok_or(RouteError::ProviderNotFound(link.provider_id))?;

            let ctx = UpstreamRegister::new(link.clone(), provider.clone());

            let env = environment();

            // Gather everything we got from the upstream provider (ID token
            // claims, extra callback parameters, userinfo claims) into the
            // context the attribute templates are rendered against
            let mut context = AttributeMappingContext::new();
            if let Some(id_token) = id_token {
                let (_, payload) = id_token.into_parts();
                context = context.with_id_token_claims(payload);
            }
            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
            }
            if let Some(userinfo) = upstream_session.userinfo() {
                context = context.with_userinfo_claims(userinfo.clone());
            }
            let context = context.build();

            // Display name: rendered unless the provider is set to ignore it
            let ctx = if provider.claims_imports.displayname.ignore() {
                ctx
            } else {
                let template = provider
                    .claims_imports
                    .displayname
                    .template
                    .as_deref()
                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);

                match render_attribute_template(
                    &env,
                    template,
                    &context,
                    provider.claims_imports.displayname.is_required(),
                )? {
                    Some(value) => ctx
                        .with_display_name(value, provider.claims_imports.displayname.is_forced()),
                    None => ctx,
                }
            };

            // Email: rendered unless the provider is set to ignore it
            let ctx = if provider.claims_imports.email.ignore() {
                ctx
            } else {
                let template = provider
                    .claims_imports
                    .email
                    .template
                    .as_deref()
                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);

                match render_attribute_template(
                    &env,
                    template,
                    &context,
                    provider.claims_imports.email.is_required(),
                )? {
                    Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()),
                    None => ctx,
                }
            };

            // Localpart: rendered unless the provider is set to ignore it, and
            // checked for availability and against the policy
            let ctx = if provider.claims_imports.localpart.ignore() {
                ctx
            } else {
                let template = provider
                    .claims_imports
                    .localpart
                    .template
                    .as_deref()
                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);

                match render_attribute_template(
                    &env,
                    template,
                    &context,
                    provider.claims_imports.localpart.is_required(),
                )? {
                    Some(localpart) => {
                        // We could run policy & existing user checks when the user submits the
                        // form, but this leads to poor UX. This is why we do
                        // it ahead of time here.
                        let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
                        let is_available = homeserver
                            .is_localpart_available(&localpart)
                            .await
                            .map_err(RouteError::HomeserverConnection)?;

                        if maybe_existing_user.is_some() || !is_available {
                            if let Some(existing_user) = maybe_existing_user {
                                // The mapper returned a username which already exists, but isn't
                                // linked to this upstream user.
                                warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");
                            }

                            // TODO: translate
                            let ctx = ErrorContext::new()
                                .with_code("User exists")
                                .with_description(format!(
                                    r"Upstream account provider returned {localpart:?} as username,
                                    which is not linked to that upstream account"
                                ))
                                .with_language(&locale);

                            return Ok((
                                cookie_jar,
                                Html(templates.render_error(&ctx)?).into_response(),
                            ));
                        }

                        // Note that the email is not part of this ahead-of-time
                        // policy evaluation
                        let res = policy
                            .evaluate_register(mas_policy::RegisterInput {
                                registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
                                username: &localpart,
                                email: None,
                                requester: mas_policy::Requester {
                                    ip_address: activity_tracker.ip(),
                                    user_agent: user_agent.clone(),
                                },
                            })
                            .await?;

                        if res.valid() {
                            // The username passes the policy check, add it to the context
                            ctx.with_localpart(
                                localpart,
                                provider.claims_imports.localpart.is_forced(),
                            )
                        } else if provider.claims_imports.localpart.is_forced() {
                            // If the username claim is 'forced' but doesn't pass the policy check,
                            // we display an error message.
                            // TODO: translate
                            let ctx = ErrorContext::new()
                                .with_code("Policy error")
                                .with_description(format!(
                                    r"Upstream account provider returned {localpart:?} as username,
                                    which does not pass the policy check: {res}"
                                ))
                                .with_language(&locale);

                            return Ok((
                                cookie_jar,
                                Html(templates.render_error(&ctx)?).into_response(),
                            ));
                        } else {
                            // Else, we just ignore it when it doesn't pass the policy check.
                            ctx
                        }
                    }
                    None => ctx,
                }
            };

            let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);

            Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
        }
    };

    // Send back the (possibly updated) cookie jar along with the response
    Ok((cookie_jar, response))
}
548
549#[tracing::instrument(
550    name = "handlers.upstream_oauth2.link.post",
551    fields(upstream_oauth_link.id = %link_id),
552    skip_all,
553)]
554pub(crate) async fn post(
555    mut rng: BoxRng,
556    clock: BoxClock,
557    mut repo: BoxRepository,
558    cookie_jar: CookieJar,
559    user_agent: Option<TypedHeader<headers::UserAgent>>,
560    mut policy: Policy,
561    PreferredLanguage(locale): PreferredLanguage,
562    activity_tracker: BoundActivityTracker,
563    State(templates): State<Templates>,
564    State(homeserver): State<Arc<dyn HomeserverConnection>>,
565    State(url_builder): State<UrlBuilder>,
566    State(site_config): State<SiteConfig>,
567    Path(link_id): Path<Ulid>,
568    Form(form): Form<ProtectedForm<FormData>>,
569) -> Result<Response, RouteError> {
570    let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
571    let form = cookie_jar.verify_form(&clock, form)?;
572
573    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
574    let (session_id, post_auth_action) = sessions_cookie
575        .lookup_link(link_id)
576        .map_err(|_| RouteError::MissingCookie)?;
577
578    let post_auth_action = OptionalPostAuthAction {
579        post_auth_action: post_auth_action.cloned(),
580    };
581
582    let link = repo
583        .upstream_oauth_link()
584        .lookup(link_id)
585        .await?
586        .ok_or(RouteError::LinkNotFound)?;
587
588    let upstream_session = repo
589        .upstream_oauth_session()
590        .lookup(session_id)
591        .await?
592        .ok_or(RouteError::SessionNotFound(session_id))?;
593
594    // This checks that we're in a browser session which is allowed to consume this
595    // link: the upstream auth session should have been started in this browser.
596    if upstream_session.link_id() != Some(link.id) {
597        return Err(RouteError::SessionNotFound(session_id));
598    }
599
600    if upstream_session.is_consumed() {
601        return Err(RouteError::SessionConsumed(session_id));
602    }
603
604    let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
605    let (user_session_info, cookie_jar) = cookie_jar.session_info();
606    let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
607    let form_state = form.to_form_state();
608
609    let session = match (maybe_user_session, link.user_id, form) {
610        (Some(session), None, FormData::Link) => {
611            // The user is already logged in, the link is not linked to any user, and the
612            // user asked to link their account.
613            repo.upstream_oauth_link()
614                .associate_to_user(&link, &session.user)
615                .await?;
616
617            session
618        }
619
620        (
621            None,
622            None,
623            FormData::Register {
624                username,
625                import_email,
626                import_display_name,
627                accept_terms,
628            },
629        ) => {
630            // The user got the form to register a new account, and is not logged in.
631            // Depending on the claims_imports, we've let the user choose their username,
632            // choose whether they want to import the email and display name, or
633            // not.
634
635            // Those fields are Some("on") if the checkbox is checked
636            let import_email = import_email.is_some();
637            let import_display_name = import_display_name.is_some();
638            let accept_terms = accept_terms.is_some();
639
640            let id_token = upstream_session.id_token().map(Jwt::try_from).transpose()?;
641
642            let provider = repo
643                .upstream_oauth_provider()
644                .lookup(link.provider_id)
645                .await?
646                .ok_or(RouteError::ProviderNotFound(link.provider_id))?;
647
648            // Let's try to import the claims from the ID token
649            let env = environment();
650
651            let mut context = AttributeMappingContext::new();
652            if let Some(id_token) = id_token {
653                let (_, payload) = id_token.into_parts();
654                context = context.with_id_token_claims(payload);
655            }
656            if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
657                context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
658            }
659            if let Some(userinfo) = upstream_session.userinfo() {
660                context = context.with_userinfo_claims(userinfo.clone());
661            }
662            let context = context.build();
663
664            // Create a template context in case we need to re-render because of an error
665            let ctx = UpstreamRegister::new(link.clone(), provider.clone());
666
667            let display_name = if provider
668                .claims_imports
669                .displayname
670                .should_import(import_display_name)
671            {
672                let template = provider
673                    .claims_imports
674                    .displayname
675                    .template
676                    .as_deref()
677                    .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE);
678
679                render_attribute_template(
680                    &env,
681                    template,
682                    &context,
683                    provider.claims_imports.displayname.is_required(),
684                )?
685            } else {
686                None
687            };
688
689            let ctx = if let Some(ref display_name) = display_name {
690                ctx.with_display_name(
691                    display_name.clone(),
692                    provider.claims_imports.email.is_forced(),
693                )
694            } else {
695                ctx
696            };
697
698            let email = if provider.claims_imports.email.should_import(import_email) {
699                let template = provider
700                    .claims_imports
701                    .email
702                    .template
703                    .as_deref()
704                    .unwrap_or(DEFAULT_EMAIL_TEMPLATE);
705
706                render_attribute_template(
707                    &env,
708                    template,
709                    &context,
710                    provider.claims_imports.email.is_required(),
711                )?
712            } else {
713                None
714            };
715
716            let ctx = if let Some(ref email) = email {
717                ctx.with_email(email.clone(), provider.claims_imports.email.is_forced())
718            } else {
719                ctx
720            };
721
722            let username = if provider.claims_imports.localpart.is_forced() {
723                let template = provider
724                    .claims_imports
725                    .localpart
726                    .template
727                    .as_deref()
728                    .unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
729
730                render_attribute_template(&env, template, &context, true)?
731            } else {
732                // If there is no forced username, we can use the one the user entered
733                username
734            }
735            .unwrap_or_default();
736
737            let ctx = ctx.with_localpart(
738                username.clone(),
739                provider.claims_imports.localpart.is_forced(),
740            );
741
742            // Validate the form
743            let form_state = {
744                let mut form_state = form_state;
745                let mut homeserver_denied_username = false;
746                if username.is_empty() {
747                    form_state.add_error_on_field(
748                        mas_templates::UpstreamRegisterFormField::Username,
749                        FieldError::Required,
750                    );
751                } else if repo.user().exists(&username).await? {
752                    form_state.add_error_on_field(
753                        mas_templates::UpstreamRegisterFormField::Username,
754                        FieldError::Exists,
755                    );
756                } else if !homeserver
757                    .is_localpart_available(&username)
758                    .await
759                    .map_err(RouteError::HomeserverConnection)?
760                {
761                    // The user already exists on the homeserver
762                    tracing::warn!(
763                        %username,
764                        "Homeserver denied username provided by user"
765                    );
766
767                    // We defer adding the error on the field, until we know whether we had another
768                    // error from the policy, to avoid showing both
769                    homeserver_denied_username = true;
770                }
771
772                // If we have a TOS in the config, make sure the user has accepted it
773                if site_config.tos_uri.is_some() && !accept_terms {
774                    form_state.add_error_on_field(
775                        mas_templates::UpstreamRegisterFormField::AcceptTerms,
776                        FieldError::Required,
777                    );
778                }
779
780                // Policy check
781                let res = policy
782                    .evaluate_register(mas_policy::RegisterInput {
783                        registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
784                        username: &username,
785                        email: email.as_deref(),
786                        requester: mas_policy::Requester {
787                            ip_address: activity_tracker.ip(),
788                            user_agent: user_agent.clone(),
789                        },
790                    })
791                    .await?;
792
793                for violation in res.violations {
794                    match violation.field.as_deref() {
795                        Some("username") => {
796                            // If the homeserver denied the username, but we also had an error on
797                            // the policy side, we don't want to show
798                            // both, so we reset the state here
799                            homeserver_denied_username = false;
800                            form_state.add_error_on_field(
801                                mas_templates::UpstreamRegisterFormField::Username,
802                                FieldError::Policy {
803                                    code: violation.code.map(|c| c.as_str()),
804                                    message: violation.msg,
805                                },
806                            );
807                        }
808                        _ => form_state.add_error_on_form(FormError::Policy {
809                            code: violation.code.map(|c| c.as_str()),
810                            message: violation.msg,
811                        }),
812                    }
813                }
814
815                if homeserver_denied_username {
816                    // XXX: we may want to return different errors like "this username is reserved"
817                    form_state.add_error_on_field(
818                        mas_templates::UpstreamRegisterFormField::Username,
819                        FieldError::Exists,
820                    );
821                }
822
823                form_state
824            };
825
826            if !form_state.is_valid() {
827                let ctx = ctx
828                    .with_form_state(form_state)
829                    .with_csrf(csrf_token.form_value())
830                    .with_language(locale);
831
832                return Ok((
833                    cookie_jar,
834                    Html(templates.render_upstream_oauth2_do_register(&ctx)?),
835                )
836                    .into_response());
837            }
838
839            REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
840
841            // Now we can create the user
842            let user = repo.user().add(&mut rng, &clock, username).await?;
843
844            if let Some(terms_url) = &site_config.tos_uri {
845                repo.user_terms()
846                    .accept_terms(&mut rng, &clock, &user, terms_url.clone())
847                    .await?;
848            }
849
850            // And schedule the job to provision it
851            let mut job = ProvisionUserJob::new(&user);
852
853            // If we have a display name, set it during provisioning
854            if let Some(name) = display_name {
855                job = job.set_display_name(name);
856            }
857
858            repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
859
860            // If we have an email, add it to the user
861            if let Some(email) = email {
862                repo.user_email()
863                    .add(&mut rng, &clock, &user, email)
864                    .await?;
865            }
866
867            repo.upstream_oauth_link()
868                .associate_to_user(&link, &user)
869                .await?;
870
871            repo.browser_session()
872                .add(&mut rng, &clock, &user, user_agent)
873                .await?
874        }
875
876        _ => return Err(RouteError::InvalidFormAction),
877    };
878
879    let upstream_session = repo
880        .upstream_oauth_session()
881        .consume(&clock, upstream_session)
882        .await?;
883
884    repo.browser_session()
885        .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
886        .await?;
887
888    let cookie_jar = sessions_cookie
889        .consume_link(link_id)?
890        .save(cookie_jar, &clock);
891    let cookie_jar = cookie_jar.set_session(&session);
892
893    repo.save().await?;
894
895    Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
896}
897
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode, header::CONTENT_TYPE};
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
    use mas_router::Route;
    use mas_storage::{
        Pagination, upstream_oauth2::UpstreamOAuthProviderParams, user::UserEmailFilter,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use sqlx::PgPool;

    use super::UpstreamSessionsCookie;
    use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};

    /// Builds an import preference with the `Force` action and no template.
    fn forced_import() -> UpstreamOAuthProviderImportPreference {
        UpstreamOAuthProviderImportPreference {
            action: mas_data_model::UpstreamOAuthProviderImportAction::Force,
            template: None,
        }
    }

    /// Scrapes the value of the `csrf` form field out of a rendered HTML page.
    ///
    /// # Panics
    ///
    /// Panics if the page does not contain a `name="csrf" value="` attribute
    /// pair.
    fn csrf_token_in(html: &str) -> &str {
        let after_marker = html.split("name=\"csrf\" value=\"").nth(1).unwrap();
        after_marker.split('"').next().unwrap()
    }

    /// Walks through the upstream registration form end to end: renders the
    /// link page, submits the `register` action, then checks that the user was
    /// created, that the upstream link points at it, and that the email claim
    /// was stored on the account.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_register(pool: PgPool) {
        // Fixture values shared between the setup and the final assertions
        const USERNAME: &str = "john";
        const EMAIL: &str = "john@example.com";
        const SUBJECT: &str = "subject";
        const STATE: &str = "state";
        const NONCE: &str = "nonce";

        setup();
        let state = TestState::from_pool(pool).await.unwrap();
        let mut rng = state.rng();
        let cookies = CookieHelper::new();

        // Both the localpart and the email use the `Force` import action;
        // everything else keeps the default preference
        let claims_imports = UpstreamOAuthProviderClaimsImports {
            localpart: forced_import(),
            email: forced_import(),
            ..UpstreamOAuthProviderClaimsImports::default()
        };

        let claims = serde_json::json!({
            "preferred_username": USERNAME,
            "email": EMAIL,
            "email_verified": true,
        });

        // Sign the ID token with a key already present in the test key store,
        // rather than generating a fresh one
        let alg = JsonWebSignatureAlg::Rs256;
        let signer = state
            .key_store
            .signing_key_for_algorithm(&alg)
            .unwrap()
            .params()
            .signing_key_for_alg(&alg)
            .unwrap();
        let id_token = Jwt::sign_with_rng(
            &mut rng,
            JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256),
            claims,
            &signer,
        )
        .unwrap();

        // Set up the provider, then an upstream session and a link for it
        let mut repo = state.repository().await.unwrap();
        let provider_params = UpstreamOAuthProviderParams {
            issuer: Some("https://example.com/".to_owned()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            token_endpoint_signing_alg: None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            client_id: "client".to_owned(),
            encrypted_client_secret: None,
            claims_imports,
            authorization_endpoint_override: None,
            token_endpoint_override: None,
            userinfo_endpoint_override: None,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            ui_order: 0,
        };
        let provider = repo
            .upstream_oauth_provider()
            .add(&mut rng, &state.clock, provider_params)
            .await
            .unwrap();

        let pending_session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                STATE.to_owned(),
                None,
                NONCE.to_owned(),
            )
            .await
            .unwrap();

        let upstream_link = repo
            .upstream_oauth_link()
            .add(&mut rng, &state.clock, &provider, SUBJECT.to_owned(), None)
            .await
            .unwrap();

        // Attach the link and the signed ID token to the session
        let completed_session = repo
            .upstream_oauth_session()
            .complete_with_link(
                &state.clock,
                pending_session,
                &upstream_link,
                Some(id_token.into_string()),
                None,
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        // Seed the browser cookies with the upstream session and its link
        let empty_jar = state.cookie_jar();
        let sessions_cookie = UpstreamSessionsCookie::default()
            .add(completed_session.id, provider.id, STATE.to_owned(), None)
            .add_link_to_session(completed_session.id, upstream_link.id)
            .unwrap();
        cookies.import(sessions_cookie.save(empty_jar, &state.clock));

        // The same route serves both the form (GET) and its submission (POST)
        let link_path = mas_router::UpstreamOAuth2Link::new(upstream_link.id).path();

        let form_request = cookies.with_cookies(Request::get(&*link_path).empty());
        let form_page = state.request(form_request).await;
        cookies.save_cookies(&form_page);
        form_page.assert_status(StatusCode::OK);
        form_page.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");

        // The form is CSRF-protected, so pull the token out of the page
        let csrf_token = csrf_token_in(form_page.body());

        let register_form = serde_json::json!({
            "csrf": csrf_token,
            "action": "register",
            "import_email": "on",
            "accept_terms": "on",
        });
        let register_request = cookies.with_cookies(Request::post(&*link_path).form(register_form));
        let register_response = state.request(register_request).await;
        cookies.save_cookies(&register_response);
        register_response.assert_status(StatusCode::SEE_OTHER);

        // The user must now exist, be attached to the link, and own the email
        let mut repo = state.repository().await.unwrap();
        let user = repo
            .user()
            .find_by_username(USERNAME)
            .await
            .unwrap()
            .expect("user exists");

        let stored_link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, SUBJECT)
            .await
            .unwrap()
            .expect("link exists");

        assert_eq!(stored_link.user_id, Some(user.id));

        let emails = repo
            .user_email()
            .list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
            .await
            .unwrap();
        let email = emails.edges.first().expect("email exists");

        assert_eq!(email.email, EMAIL);
    }
}