mas_handlers/upstream_oauth2/callback.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::LazyLock;
8
9use axum::{
10    Form,
11    extract::{Path, State},
12    http::Method,
13    response::{Html, IntoResponse, Response},
14};
15use hyper::StatusCode;
16use mas_axum_utils::{cookies::CookieJar, record_error};
17use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
18use mas_jose::claims::TokenHash;
19use mas_keystore::{Encrypter, Keystore};
20use mas_oidc_client::requests::jose::JwtVerificationData;
21use mas_router::UrlBuilder;
22use mas_storage::{
23    BoxClock, BoxRepository, BoxRng, Clock,
24    upstream_oauth2::{
25        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
26        UpstreamOAuthSessionRepository,
27    },
28};
29use mas_templates::{FormPostContext, Templates};
30use oauth2_types::{errors::ClientErrorCode, requests::AccessTokenRequest};
31use opentelemetry::{Key, KeyValue, metrics::Counter};
32use serde::{Deserialize, Serialize};
33use serde_json::json;
34use thiserror::Error;
35use ulid::Ulid;
36
37use super::{
38    UpstreamSessionsCookie,
39    cache::LazyProviderInfos,
40    client_credentials_for_provider,
41    template::{AttributeMappingContext, environment},
42};
43use crate::{
44    METER, PreferredLanguage, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
45};
46
47static CALLBACK_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
48    METER
49        .u64_counter("mas.upstream_oauth2.callback")
50        .with_description("Number of requests to the upstream OAuth2 callback endpoint")
51        .build()
52});
53const PROVIDER: Key = Key::from_static_str("provider");
54const RESULT: Key = Key::from_static_str("result");
55
/// Parameters received on the upstream OAuth 2.0 callback endpoint.
///
/// They are read by the `Form` extractor in `handler` (from the query string
/// for `GET` requests, from the body for `POST` requests). They are also handed
/// to `FormPostContext` when the handler re-posts the form to itself, hence the
/// `Serialize` derive.
#[derive(Serialize, Deserialize)]
pub struct Params {
    /// The `state` value returned by the provider; used to look up the
    /// matching upstream session in the sessions cookie.
    #[serde(skip_serializing_if = "Option::is_none")]
    state: Option<String>,

    /// An extra parameter to track whether the POST request was re-made by us
    /// to the same URL to escape Same-Site cookies restrictions
    #[serde(default)]
    did_mas_repost_to_itself: bool,

    /// The authorization code, exchanged at the provider's token endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    code: Option<String>,

    /// Error code reported by the provider instead of a `code`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<ClientErrorCode>,
    /// Optional human-readable description accompanying `error`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_description: Option<String>,
    /// Optional URI accompanying `error`; the handler only looks at it in
    /// `is_empty`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_uri: Option<String>,

    /// Any other parameters sent by the provider, captured via
    /// `serde(flatten)`. They are exposed to the attribute-mapping templates
    /// and persisted when the upstream session is completed.
    #[serde(flatten)]
    extra_callback_parameters: Option<serde_json::Value>,
}
79
80impl Params {
81    /// Returns true if none of the fields are set
82    pub fn is_empty(&self) -> bool {
83        self.state.is_none()
84            && self.code.is_none()
85            && self.error.is_none()
86            && self.error_description.is_none()
87            && self.error_uri.is_none()
88    }
89}
90
/// Errors that can occur while handling the upstream OAuth 2.0 callback.
///
/// The `IntoResponse` implementation in this file maps `ProviderNotFound` and
/// `SessionNotFound` to 404, `Internal` to 500, and every other variant to 400.
/// The response body is the error's display message.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// The session ID taken from the cookie has no matching upstream session
    /// in the repository.
    #[error("Session not found")]
    SessionNotFound,

    /// The provider ID from the URL is unknown, or the provider is disabled.
    #[error("Provider not found")]
    ProviderNotFound,

    /// The stored upstream session belongs to a different provider than the
    /// one in the URL.
    #[error("Provider mismatch")]
    ProviderMismatch,

    /// The upstream session is no longer pending.
    #[error("Session already completed")]
    AlreadyCompleted,

    /// The `state` parameter differs from the one stored on the session.
    #[error("State parameter mismatch")]
    StateMismatch,

    /// The callback carried no `state` parameter.
    #[error("Missing state parameter")]
    MissingState,

    /// The callback carried neither an `error` nor a `code` parameter.
    #[error("Missing code parameter")]
    MissingCode,

    /// Rendering the subject template failed. Despite the message, the
    /// template context may also include userinfo and extra callback
    /// parameters, not only ID token claims.
    #[error("Could not extract subject from ID token")]
    ExtractSubject(#[source] minijinja::Error),

    /// The subject template rendered to an empty string.
    #[error("Subject is empty")]
    EmptySubject,

    /// The provider reported an error in the callback parameters.
    #[error("Error from the provider: {error}")]
    ClientError {
        /// The `error` code sent by the provider.
        error: ClientErrorCode,
        /// The `error_description` sent by the provider, if any. Not included
        /// in the display message.
        error_description: Option<String>,
    },

    /// No session for this provider and `state` was found in the upstream
    /// sessions cookie.
    #[error("Missing session cookie")]
    MissingCookie,

    /// A `GET` callback carried none of the standard parameters.
    #[error("Missing query parameters")]
    MissingQueryParams,

    /// A non-`GET` callback carried none of the standard parameters.
    #[error("Missing form parameters")]
    MissingFormParams,

    /// The request method does not match the provider's configured response
    /// mode.
    #[error("Invalid response mode, expected '{expected}'")]
    InvalidResponseMode {
        /// The response mode configured on the provider.
        expected: UpstreamOAuthProviderResponseMode,
    },

    /// Catch-all for unexpected errors. Presumably the target of the
    /// `impl_from_error_for_route!` conversions (macro defined in the crate
    /// root; confirm).
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
143
144impl_from_error_for_route!(mas_templates::TemplateError);
145impl_from_error_for_route!(mas_storage::RepositoryError);
146impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
147impl_from_error_for_route!(mas_oidc_client::error::JwksError);
148impl_from_error_for_route!(mas_oidc_client::error::TokenRequestError);
149impl_from_error_for_route!(mas_oidc_client::error::IdTokenError);
150impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
151impl_from_error_for_route!(super::ProviderCredentialsError);
152impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
153
154impl IntoResponse for RouteError {
155    fn into_response(self) -> axum::response::Response {
156        let sentry_event_id = record_error!(self, Self::Internal(_));
157        let response = match self {
158            Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
159            Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
160            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
161            e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
162        };
163
164        (sentry_event_id, response).into_response()
165    }
166}
167
168#[tracing::instrument(
169    name = "handlers.upstream_oauth2.callback.handler",
170    fields(upstream_oauth_provider.id = %provider_id),
171    skip_all,
172)]
173#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
174pub(crate) async fn handler(
175    mut rng: BoxRng,
176    clock: BoxClock,
177    State(metadata_cache): State<MetadataCache>,
178    mut repo: BoxRepository,
179    State(url_builder): State<UrlBuilder>,
180    State(encrypter): State<Encrypter>,
181    State(keystore): State<Keystore>,
182    State(client): State<reqwest::Client>,
183    State(templates): State<Templates>,
184    method: Method,
185    PreferredLanguage(locale): PreferredLanguage,
186    cookie_jar: CookieJar,
187    Path(provider_id): Path<Ulid>,
188    Form(params): Form<Params>,
189) -> Result<Response, RouteError> {
190    let provider = repo
191        .upstream_oauth_provider()
192        .lookup(provider_id)
193        .await?
194        .filter(UpstreamOAuthProvider::enabled)
195        .ok_or(RouteError::ProviderNotFound)?;
196
197    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
198
199    if params.is_empty() {
200        if let Method::GET = method {
201            return Err(RouteError::MissingQueryParams);
202        }
203
204        return Err(RouteError::MissingFormParams);
205    }
206
207    // The `Form` extractor will use the body of the request for POST requests and
208    // the query parameters for GET requests. We need to then look at the method do
209    // make sure it matches the expected `response_mode`
210    match (provider.response_mode, method) {
211        (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => {
212            // We set the cookies with a `Same-Site` policy set to `Lax`, so because this is
213            // usually a cross-site form POST, we need to render a form with the
214            // same values, which posts back to the same URL. However, there are
215            // other valid reasons for the cookie to be missing, so to track whether we did
216            // this POST ourselves, we set a flag.
217            if sessions_cookie.is_empty() && !params.did_mas_repost_to_itself {
218                let params = Params {
219                    did_mas_repost_to_itself: true,
220                    ..params
221                };
222                let context = FormPostContext::new_for_current_url(params).with_language(&locale);
223                let html = templates.render_form_post(&context)?;
224                return Ok(Html(html).into_response());
225            }
226        }
227        (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {}
228        (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
229    }
230
231    if let Some(error) = params.error {
232        CALLBACK_COUNTER.add(
233            1,
234            &[
235                KeyValue::new(PROVIDER, provider_id.to_string()),
236                KeyValue::new(RESULT, "error"),
237            ],
238        );
239
240        return Err(RouteError::ClientError {
241            error,
242            error_description: params.error_description.clone(),
243        });
244    }
245
246    let Some(state) = params.state else {
247        return Err(RouteError::MissingState);
248    };
249
250    let (session_id, _post_auth_action) = sessions_cookie
251        .find_session(provider_id, &state)
252        .map_err(|_| RouteError::MissingCookie)?;
253
254    let session = repo
255        .upstream_oauth_session()
256        .lookup(session_id)
257        .await?
258        .ok_or(RouteError::SessionNotFound)?;
259
260    if provider.id != session.provider_id {
261        // The provider in the session cookie should match the one from the URL
262        return Err(RouteError::ProviderMismatch);
263    }
264
265    if state != session.state_str {
266        // The state in the session cookie should match the one from the params
267        return Err(RouteError::StateMismatch);
268    }
269
270    if !session.is_pending() {
271        // The session was already completed
272        return Err(RouteError::AlreadyCompleted);
273    }
274
275    // Let's extract the code from the params, and return if there was an error
276    let Some(code) = params.code else {
277        return Err(RouteError::MissingCode);
278    };
279
280    CALLBACK_COUNTER.add(
281        1,
282        &[
283            KeyValue::new(PROVIDER, provider_id.to_string()),
284            KeyValue::new(RESULT, "success"),
285        ],
286    );
287
288    let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
289
290    // Figure out the client credentials
291    let client_credentials = client_credentials_for_provider(
292        &provider,
293        lazy_metadata.token_endpoint().await?,
294        &keystore,
295        &encrypter,
296    )?;
297
298    let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
299
300    let token_response = mas_oidc_client::requests::token::request_access_token(
301        &client,
302        client_credentials,
303        lazy_metadata.token_endpoint().await?,
304        AccessTokenRequest::AuthorizationCode(oauth2_types::requests::AuthorizationCodeGrant {
305            code: code.clone(),
306            redirect_uri: Some(redirect_uri),
307            code_verifier: session.code_challenge_verifier.clone(),
308        }),
309        clock.now(),
310        &mut rng,
311    )
312    .await?;
313
314    let mut jwks = None;
315
316    let mut context = AttributeMappingContext::new();
317    if let Some(id_token) = token_response.id_token.as_ref() {
318        jwks = Some(
319            mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
320                .await?,
321        );
322
323        let id_token_verification_data = JwtVerificationData {
324            issuer: provider.issuer.as_deref(),
325            jwks: jwks.as_ref().unwrap(),
326            signing_algorithm: &provider.id_token_signed_response_alg,
327            client_id: &provider.client_id,
328        };
329
330        // Decode and verify the ID token
331        let id_token = mas_oidc_client::requests::jose::verify_id_token(
332            id_token,
333            id_token_verification_data,
334            None,
335            clock.now(),
336        )?;
337
338        let (_headers, mut claims) = id_token.into_parts();
339
340        // Access token hash must match.
341        mas_jose::claims::AT_HASH
342            .extract_optional_with_options(
343                &mut claims,
344                TokenHash::new(
345                    id_token_verification_data.signing_algorithm,
346                    &token_response.access_token,
347                ),
348            )
349            .map_err(mas_oidc_client::error::IdTokenError::from)?;
350
351        // Code hash must match.
352        mas_jose::claims::C_HASH
353            .extract_optional_with_options(
354                &mut claims,
355                TokenHash::new(id_token_verification_data.signing_algorithm, &code),
356            )
357            .map_err(mas_oidc_client::error::IdTokenError::from)?;
358
359        // Nonce must match.
360        mas_jose::claims::NONCE
361            .extract_required_with_options(&mut claims, session.nonce.as_str())
362            .map_err(mas_oidc_client::error::IdTokenError::from)?;
363
364        context = context.with_id_token_claims(claims);
365    }
366
367    if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
368        context = context.with_extra_callback_parameters(extra_callback_parameters);
369    }
370
371    let userinfo = if provider.fetch_userinfo {
372        Some(json!(match &provider.userinfo_signed_response_alg {
373            Some(signing_algorithm) => {
374                let jwks = match jwks {
375                    Some(jwks) => jwks,
376                    None => {
377                        mas_oidc_client::requests::jose::fetch_jwks(
378                            &client,
379                            lazy_metadata.jwks_uri().await?,
380                        )
381                        .await?
382                    }
383                };
384
385                mas_oidc_client::requests::userinfo::fetch_userinfo(
386                    &client,
387                    lazy_metadata.userinfo_endpoint().await?,
388                    token_response.access_token.as_str(),
389                    Some(JwtVerificationData {
390                        issuer: provider.issuer.as_deref(),
391                        jwks: &jwks,
392                        signing_algorithm,
393                        client_id: &provider.client_id,
394                    }),
395                )
396                .await?
397            }
398            None => {
399                mas_oidc_client::requests::userinfo::fetch_userinfo(
400                    &client,
401                    lazy_metadata.userinfo_endpoint().await?,
402                    token_response.access_token.as_str(),
403                    None,
404                )
405                .await?
406            }
407        }))
408    } else {
409        None
410    };
411
412    if let Some(userinfo) = userinfo.clone() {
413        context = context.with_userinfo_claims(userinfo);
414    }
415
416    let context = context.build();
417
418    let env = environment();
419
420    let template = provider
421        .claims_imports
422        .subject
423        .template
424        .as_deref()
425        .unwrap_or("{{ user.sub }}");
426    let subject = env
427        .render_str(template, context.clone())
428        .map_err(RouteError::ExtractSubject)?;
429
430    if subject.is_empty() {
431        return Err(RouteError::EmptySubject);
432    }
433
434    // Look for an existing link
435    let maybe_link = repo
436        .upstream_oauth_link()
437        .find_by_subject(&provider, &subject)
438        .await?;
439
440    let link = if let Some(link) = maybe_link {
441        link
442    } else {
443        // Try to render the human account name if we have one,
444        // but just log if it fails
445        let human_account_name = provider
446            .claims_imports
447            .account_name
448            .template
449            .as_deref()
450            .and_then(|template| match env.render_str(template, context) {
451                Ok(name) => Some(name),
452                Err(e) => {
453                    tracing::warn!(
454                        error = &e as &dyn std::error::Error,
455                        "Failed to render account name"
456                    );
457                    None
458                }
459            });
460
461        repo.upstream_oauth_link()
462            .add(&mut rng, &clock, &provider, subject, human_account_name)
463            .await?
464    };
465
466    let session = repo
467        .upstream_oauth_session()
468        .complete_with_link(
469            &clock,
470            session,
471            &link,
472            token_response.id_token,
473            params.extra_callback_parameters,
474            userinfo,
475        )
476        .await?;
477
478    let cookie_jar = sessions_cookie
479        .add_link_to_session(session.id, link.id)?
480        .save(cookie_jar, &clock);
481
482    repo.save().await?;
483
484    Ok((
485        cookie_jar,
486        url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
487    )
488        .into_response())
489}