mas_handlers/oauth2/registration.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::sync::LazyLock;
8
9use axum::{Json, extract::State, response::IntoResponse};
10use axum_extra::TypedHeader;
11use hyper::StatusCode;
12use mas_axum_utils::record_error;
13use mas_iana::oauth::OAuthClientAuthenticationMethod;
14use mas_keystore::Encrypter;
15use mas_policy::{EvaluationResult, Policy};
16use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
17use oauth2_types::{
18    errors::{ClientError, ClientErrorCode},
19    registration::{
20        ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
21        VerifiedClientMetadata,
22    },
23};
24use opentelemetry::{Key, KeyValue, metrics::Counter};
25use psl::Psl;
26use rand::distributions::{Alphanumeric, DistString};
27use serde::Serialize;
28use sha2::Digest as _;
29use thiserror::Error;
30use tracing::info;
31use url::Url;
32
33use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
34
35static REGISTRATION_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
36    METER
37        .u64_counter("mas.oauth2.registration_request")
38        .with_description("Number of OAuth2 registration requests")
39        .with_unit("{request}")
40        .build()
41});
42const RESULT: Key = Key::from_static_str("result");
43
/// Errors that can be returned by the client registration handler.
///
/// Each variant is turned into an HTTP error response by the `IntoResponse`
/// implementation for this type.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// An unexpected failure, answered with `500 Internal Server Error`.
    /// NOTE(review): presumably this is what the `impl_from_error_for_route!`
    /// conversions (storage, policy, encryption, serialization errors) produce — confirm.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body could not be extracted as JSON client metadata.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// The client metadata failed validation.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// The named metadata field contains a URL whose host is a public suffix
    /// (for example `github.io`) rather than a real domain.
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The registration policy rejected the client. The evaluation result
    /// carries the list of violations.
    #[error("client registration denied by the policy: {0}")]
    PolicyDenied(EvaluationResult),
}
61
// Conversions into `RouteError` for the error types that the handler propagates
// with `?` (repository, policy loading/evaluation, encryption, JSON serialization).
// NOTE(review): the macro is defined in the crate root; presumably it wraps the
// error into `RouteError::Internal` — confirm.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
67
68impl IntoResponse for RouteError {
69    fn into_response(self) -> axum::response::Response {
70        let sentry_event_id = record_error!(self, Self::Internal(_));
71
72        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "denied")]);
73
74        let response = match self {
75            Self::Internal(_) => (
76                StatusCode::INTERNAL_SERVER_ERROR,
77                Json(ClientError::from(ClientErrorCode::ServerError)),
78            )
79                .into_response(),
80
81            // This error happens if we managed to parse the incomiong JSON but it can't be
82            // deserialized to the expected type. In this case we return an
83            // `invalid_client_metadata` error with the details of the error.
84            Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
85                StatusCode::BAD_REQUEST,
86                Json(
87                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
88                        .with_description(e.to_string()),
89                ),
90            )
91                .into_response(),
92
93            // For all other JSON errors we return a `invalid_request` error, since this is
94            // probably due to a malformed request.
95            Self::JsonExtract(_) => (
96                StatusCode::BAD_REQUEST,
97                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
98            )
99                .into_response(),
100
101            // This error comes from the `ClientMetadata::validate` method. We return an
102            // `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
103            // return an `invalid_client_metadata` error.
104            Self::InvalidClientMetadata(
105                ClientMetadataVerificationError::MissingRedirectUris
106                | ClientMetadataVerificationError::RedirectUriWithFragment(_),
107            ) => (
108                StatusCode::BAD_REQUEST,
109                Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
110            )
111                .into_response(),
112
113            Self::InvalidClientMetadata(e) => (
114                StatusCode::BAD_REQUEST,
115                Json(
116                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
117                        .with_description(e.to_string()),
118                ),
119            )
120                .into_response(),
121
122            // This error happens if the any of the client's URIs are public suffixes. We return
123            // an `invalid_redirect_uri` error if it's a `redirect_uri`, else we return an
124            // `invalid_client_metadata` error.
125            Self::UrlIsPublicSuffix("redirect_uri") => (
126                StatusCode::BAD_REQUEST,
127                Json(
128                    ClientError::from(ClientErrorCode::InvalidRedirectUri)
129                        .with_description("redirect_uri is not using a valid domain".to_owned()),
130                ),
131            )
132                .into_response(),
133
134            Self::UrlIsPublicSuffix(field) => (
135                StatusCode::BAD_REQUEST,
136                Json(
137                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
138                        .with_description(format!("{field} is not using a valid domain")),
139                ),
140            )
141                .into_response(),
142
143            // For policy violations, we return an `invalid_client_metadata` error with the details
144            // of the violations in most cases. If a violation includes `redirect_uri` in the
145            // message, we return an `invalid_redirect_uri` error instead.
146            Self::PolicyDenied(evaluation) => {
147                // TODO: detect them better
148                let code = if evaluation
149                    .violations
150                    .iter()
151                    .any(|v| v.msg.contains("redirect_uri"))
152                {
153                    ClientErrorCode::InvalidRedirectUri
154                } else {
155                    ClientErrorCode::InvalidClientMetadata
156                };
157
158                let collected = &evaluation
159                    .violations
160                    .iter()
161                    .map(|v| v.msg.clone())
162                    .collect::<Vec<String>>();
163                let joined = collected.join("; ");
164
165                (
166                    StatusCode::BAD_REQUEST,
167                    Json(ClientError::from(code).with_description(joined)),
168                )
169                    .into_response()
170            }
171        };
172
173        (sentry_event_id, response).into_response()
174    }
175}
176
/// Body of a successful registration response.
///
/// The registration fields and the client metadata are flattened into a single
/// JSON object.
#[derive(Serialize)]
struct RouteResponse {
    /// Client identifier, optional client secret and issuance information
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    /// The client's metadata, as read back from the stored client
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
184
185/// Check if the host of the given URL is a public suffix
186fn host_is_public_suffix(url: &Url) -> bool {
187    let host = url.host_str().unwrap_or_default().as_bytes();
188    let Some(suffix) = psl::List.suffix(host) else {
189        // There is no suffix, which is the case for empty hosts, like with custom
190        // schemes
191        return false;
192    };
193
194    if !suffix.is_known() {
195        // The suffix is not known, so it's not a public suffix
196        return false;
197    }
198
199    // We want to cover two cases:
200    // - The host is the suffix itself, like `com`
201    // - The host is a dot followed by the suffix, like `.com`
202    if host.len() <= suffix.as_bytes().len() + 1 {
203        // The host only has the suffix in it, so it's a public suffix
204        return true;
205    }
206
207    false
208}
209
210/// Check if any of the URLs in the given `Localized` field is a public suffix
211fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
212    url.iter().any(|(_lang, url)| host_is_public_suffix(url))
213}
214
215#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all)]
216pub(crate) async fn post(
217    mut rng: BoxRng,
218    clock: BoxClock,
219    mut repo: BoxRepository,
220    mut policy: Policy,
221    activity_tracker: BoundActivityTracker,
222    user_agent: Option<TypedHeader<headers::UserAgent>>,
223    State(encrypter): State<Encrypter>,
224    body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
225) -> Result<impl IntoResponse, RouteError> {
226    // Propagate any JSON extraction error
227    let Json(body) = body?;
228
229    // Sort the properties to ensure a stable serialisation order for hashing
230    let body = body.sorted();
231
232    // We need to serialize the body to compute the hash, and to log it
233    let body_json = serde_json::to_string(&body)?;
234
235    info!(body = body_json, "Client registration");
236
237    let user_agent = user_agent.map(|ua| ua.to_string());
238
239    // Validate the body
240    let metadata = body.validate()?;
241
242    // Some extra validation that is hard to do in OPA and not done by the
243    // `validate` method either
244    if let Some(client_uri) = &metadata.client_uri {
245        if localised_url_has_public_suffix(client_uri) {
246            return Err(RouteError::UrlIsPublicSuffix("client_uri"));
247        }
248    }
249
250    if let Some(logo_uri) = &metadata.logo_uri {
251        if localised_url_has_public_suffix(logo_uri) {
252            return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
253        }
254    }
255
256    if let Some(policy_uri) = &metadata.policy_uri {
257        if localised_url_has_public_suffix(policy_uri) {
258            return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
259        }
260    }
261
262    if let Some(tos_uri) = &metadata.tos_uri {
263        if localised_url_has_public_suffix(tos_uri) {
264            return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
265        }
266    }
267
268    if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
269        if host_is_public_suffix(initiate_login_uri) {
270            return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
271        }
272    }
273
274    for redirect_uri in metadata.redirect_uris() {
275        if host_is_public_suffix(redirect_uri) {
276            return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
277        }
278    }
279
280    let res = policy
281        .evaluate_client_registration(mas_policy::ClientRegistrationInput {
282            client_metadata: &metadata,
283            requester: mas_policy::Requester {
284                ip_address: activity_tracker.ip(),
285                user_agent,
286            },
287        })
288        .await?;
289    if !res.valid() {
290        return Err(RouteError::PolicyDenied(res));
291    }
292
293    let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
294        Some(
295            OAuthClientAuthenticationMethod::ClientSecretJwt
296            | OAuthClientAuthenticationMethod::ClientSecretPost
297            | OAuthClientAuthenticationMethod::ClientSecretBasic,
298        ) => {
299            // Let's generate a random client secret
300            let client_secret = Alphanumeric.sample_string(&mut rng, 20);
301            let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
302            (Some(client_secret), Some(encrypted_client_secret))
303        }
304        _ => (None, None),
305    };
306
307    // If the client doesn't have a secret, we may be able to deduplicate it. To
308    // do so, we hash the client metadata, and look for it in the database
309    let (digest_hash, existing_client) = if client_secret.is_none() {
310        // XXX: One interesting caveat is that we hash *before* saving to the database.
311        // It means it takes into account fields that we don't care about *yet*.
312        //
313        // This means that if later we start supporting a particular field, we
314        // will still serve the 'old' client_id, without updating the client in the
315        // database
316        let hash = sha2::Sha256::digest(body_json);
317        let hash = hex::encode(hash);
318        let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
319        (Some(hash), client)
320    } else {
321        (None, None)
322    };
323
324    let client = if let Some(client) = existing_client {
325        tracing::info!(%client.id, "Reusing existing client");
326        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "reused")]);
327        client
328    } else {
329        let client = repo
330            .oauth2_client()
331            .add(
332                &mut rng,
333                &clock,
334                metadata.redirect_uris().to_vec(),
335                digest_hash,
336                encrypted_client_secret,
337                metadata.application_type.clone(),
338                //&metadata.response_types(),
339                metadata.grant_types().to_vec(),
340                metadata
341                    .client_name
342                    .clone()
343                    .map(Localized::to_non_localized),
344                metadata.logo_uri.clone().map(Localized::to_non_localized),
345                metadata.client_uri.clone().map(Localized::to_non_localized),
346                metadata.policy_uri.clone().map(Localized::to_non_localized),
347                metadata.tos_uri.clone().map(Localized::to_non_localized),
348                metadata.jwks_uri.clone(),
349                metadata.jwks.clone(),
350                // XXX: those might not be right, should be function calls
351                metadata.id_token_signed_response_alg.clone(),
352                metadata.userinfo_signed_response_alg.clone(),
353                metadata.token_endpoint_auth_method.clone(),
354                metadata.token_endpoint_auth_signing_alg.clone(),
355                metadata.initiate_login_uri.clone(),
356            )
357            .await?;
358        tracing::info!(%client.id, "Registered new client");
359        REGISTRATION_COUNTER.add(1, &[KeyValue::new(RESULT, "created")]);
360        client
361    };
362
363    let response = ClientRegistrationResponse {
364        client_id: client.client_id.clone(),
365        client_secret,
366        // XXX: we should have a `created_at` field on the clients
367        client_id_issued_at: Some(client.id.datetime().into()),
368        client_secret_expires_at: None,
369    };
370
371    // We round-trip back to the metadata to output it in the response
372    // This should never fail, as the client is valid
373    let metadata = client.into_metadata().validate()?;
374
375    repo.save().await?;
376
377    let response = RouteResponse { response, metadata };
378
379    Ok((StatusCode::CREATED, Json(response)))
380}
381
/// Tests for the client registration endpoint and its public-suffix check.
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Bare public suffixes (optionally with leading or trailing dots) are
    /// detected, while registrable domains, single-label hosts and
    /// custom-scheme URLs are not.
    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));
        assert!(!url_is_public_suffix("http://localhost"));
        // Custom scheme without a host
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    /// Malformed or disallowed registration requests are rejected with
    /// `400 Bad Request` and the matching error code.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Body is not a JSON
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // Invalid client metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Invalid redirect URI: a `web` application with a plain `http` redirect
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // Incoherent response types
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Using a public suffix: `client_uri` is checked first, so it is the
        // field reported even though the redirect URI is also a public suffix
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Using a public suffix in a translated URL
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    /// Successful registrations return `201 Created`. A client secret is
    /// returned for `client_secret_basic` but not for `none`.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A successful registration with no authentication should not return a client
        // secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        // A successful registration with client_secret based authentication should
        // return a client secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }
    /// Identical registrations of secret-less clients reuse the same client ID,
    /// regardless of the order of properties and list entries. Clients that get
    /// a client secret are never deduplicated. The steps run in sequence against
    /// the same database.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Post a client registration twice, we should get the same client ID
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Check that the order of some properties doesn't matter
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Doing that with a client that has a client_secret should not deduplicate
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        // Sanity check that the client_id is different
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }
}