1use std::sync::LazyLock;
8
9use axum::{
10 Form,
11 extract::{Path, State},
12 http::Method,
13 response::{Html, IntoResponse, Response},
14};
15use hyper::StatusCode;
16use mas_axum_utils::{cookies::CookieJar, record_error};
17use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
18use mas_jose::claims::TokenHash;
19use mas_keystore::{Encrypter, Keystore};
20use mas_oidc_client::requests::jose::JwtVerificationData;
21use mas_router::UrlBuilder;
22use mas_storage::{
23 BoxClock, BoxRepository, BoxRng, Clock,
24 upstream_oauth2::{
25 UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
26 UpstreamOAuthSessionRepository,
27 },
28};
29use mas_templates::{FormPostContext, Templates};
30use oauth2_types::{errors::ClientErrorCode, requests::AccessTokenRequest};
31use opentelemetry::{Key, KeyValue, metrics::Counter};
32use serde::{Deserialize, Serialize};
33use serde_json::json;
34use thiserror::Error;
35use ulid::Ulid;
36
37use super::{
38 UpstreamSessionsCookie,
39 cache::LazyProviderInfos,
40 client_credentials_for_provider,
41 template::{AttributeMappingContext, environment},
42};
43use crate::{
44 METER, PreferredLanguage, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
45};
46
47static CALLBACK_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
48 METER
49 .u64_counter("mas.upstream_oauth2.callback")
50 .with_description("Number of requests to the upstream OAuth2 callback endpoint")
51 .build()
52});
53const PROVIDER: Key = Key::from_static_str("provider");
54const RESULT: Key = Key::from_static_str("result");
55
/// Parameters received on the upstream OAuth 2.0 callback.
///
/// They are extracted with axum's `Form` extractor, i.e. from the query string
/// for `GET` requests and from the form body for `POST` requests. The struct is
/// also serialized back into a form when the handler re-posts the callback to
/// itself, which is why unset optional fields are skipped on serialization.
#[derive(Serialize, Deserialize)]
pub struct Params {
    /// The `state` value, matched against the sessions cookie and the stored
    /// upstream session.
    #[serde(skip_serializing_if = "Option::is_none")]
    state: Option<String>,

    /// Set to `true` by this service when it re-posts the callback to the
    /// same URL, so that the re-post happens at most once. Defaults to `false`
    /// when absent from the request.
    #[serde(default)]
    did_mas_repost_to_itself: bool,

    /// The authorization code to exchange at the provider's token endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    code: Option<String>,

    /// Error code returned by the provider when the authorization failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<ClientErrorCode>,
    /// Optional human-readable description accompanying `error`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_description: Option<String>,
    /// Optional URI accompanying `error`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error_uri: Option<String>,

    /// Every other parameter sent by the provider, captured as-is and exposed
    /// to the attribute-mapping templates.
    #[serde(flatten)]
    extra_callback_parameters: Option<serde_json::Value>,
}
79
80impl Params {
81 pub fn is_empty(&self) -> bool {
83 self.state.is_none()
84 && self.code.is_none()
85 && self.error.is_none()
86 && self.error_description.is_none()
87 && self.error_uri.is_none()
88 }
89}
90
/// Errors returned by the upstream OAuth 2.0 callback handler.
///
/// `ProviderNotFound` and `SessionNotFound` are rendered as `404 Not Found`,
/// `Internal` as `500 Internal Server Error`, and every other variant as
/// `400 Bad Request` (see the `IntoResponse` impl).
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// The upstream session referenced by the sessions cookie does not exist
    /// in the repository.
    #[error("Session not found")]
    SessionNotFound,

    /// The provider does not exist or is disabled.
    #[error("Provider not found")]
    ProviderNotFound,

    /// The stored session belongs to a different provider than the one in the
    /// URL.
    #[error("Provider mismatch")]
    ProviderMismatch,

    /// The session is no longer pending, i.e. it was already used.
    #[error("Session already completed")]
    AlreadyCompleted,

    /// The `state` parameter does not match the state stored on the session.
    #[error("State parameter mismatch")]
    StateMismatch,

    /// The callback carried no `state` parameter.
    #[error("Missing state parameter")]
    MissingState,

    /// The callback carried no `code` parameter (and no `error`).
    #[error("Missing code parameter")]
    MissingCode,

    /// Rendering the subject template failed.
    ///
    /// NOTE(review): the message mentions the ID token, but the template
    /// context may also contain userinfo claims and extra callback parameters.
    #[error("Could not extract subject from ID token")]
    ExtractSubject(#[source] minijinja::Error),

    /// The subject template rendered to an empty string.
    #[error("Subject is empty")]
    EmptySubject,

    /// The provider reported an error in the callback parameters.
    #[error("Error from the provider: {error}")]
    ClientError {
        error: ClientErrorCode,
        error_description: Option<String>,
    },

    /// No session matching the provider and `state` was found in the sessions
    /// cookie. Any failure of that cookie lookup is mapped to this variant.
    #[error("Missing session cookie")]
    MissingCookie,

    /// A `GET` callback carried none of the expected query parameters.
    #[error("Missing query parameters")]
    MissingQueryParams,

    /// A non-`GET` callback carried none of the expected form parameters.
    #[error("Missing form parameters")]
    MissingFormParams,

    /// The HTTP method does not match the response mode configured for the
    /// provider.
    #[error("Invalid response mode, expected '{expected}'")]
    InvalidResponseMode {
        expected: UpstreamOAuthProviderResponseMode,
    },

    /// Any other (unexpected) failure; reported as a server error.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
143
// Allow `?` on these error types inside the handler. The macro is imported
// from the crate root and its body is not visible here; presumably it
// converts each error into `RouteError::Internal` (TODO confirm).
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
impl_from_error_for_route!(mas_oidc_client::error::TokenRequestError);
impl_from_error_for_route!(mas_oidc_client::error::IdTokenError);
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
impl_from_error_for_route!(super::ProviderCredentialsError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
153
154impl IntoResponse for RouteError {
155 fn into_response(self) -> axum::response::Response {
156 let sentry_event_id = record_error!(self, Self::Internal(_));
157 let response = match self {
158 Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
159 Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
160 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
161 e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
162 };
163
164 (sentry_event_id, response).into_response()
165 }
166}
167
168#[tracing::instrument(
169 name = "handlers.upstream_oauth2.callback.handler",
170 fields(upstream_oauth_provider.id = %provider_id),
171 skip_all,
172)]
173#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
174pub(crate) async fn handler(
175 mut rng: BoxRng,
176 clock: BoxClock,
177 State(metadata_cache): State<MetadataCache>,
178 mut repo: BoxRepository,
179 State(url_builder): State<UrlBuilder>,
180 State(encrypter): State<Encrypter>,
181 State(keystore): State<Keystore>,
182 State(client): State<reqwest::Client>,
183 State(templates): State<Templates>,
184 method: Method,
185 PreferredLanguage(locale): PreferredLanguage,
186 cookie_jar: CookieJar,
187 Path(provider_id): Path<Ulid>,
188 Form(params): Form<Params>,
189) -> Result<Response, RouteError> {
190 let provider = repo
191 .upstream_oauth_provider()
192 .lookup(provider_id)
193 .await?
194 .filter(UpstreamOAuthProvider::enabled)
195 .ok_or(RouteError::ProviderNotFound)?;
196
197 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
198
199 if params.is_empty() {
200 if let Method::GET = method {
201 return Err(RouteError::MissingQueryParams);
202 }
203
204 return Err(RouteError::MissingFormParams);
205 }
206
207 match (provider.response_mode, method) {
211 (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => {
212 if sessions_cookie.is_empty() && !params.did_mas_repost_to_itself {
218 let params = Params {
219 did_mas_repost_to_itself: true,
220 ..params
221 };
222 let context = FormPostContext::new_for_current_url(params).with_language(&locale);
223 let html = templates.render_form_post(&context)?;
224 return Ok(Html(html).into_response());
225 }
226 }
227 (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {}
228 (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
229 }
230
231 if let Some(error) = params.error {
232 CALLBACK_COUNTER.add(
233 1,
234 &[
235 KeyValue::new(PROVIDER, provider_id.to_string()),
236 KeyValue::new(RESULT, "error"),
237 ],
238 );
239
240 return Err(RouteError::ClientError {
241 error,
242 error_description: params.error_description.clone(),
243 });
244 }
245
246 let Some(state) = params.state else {
247 return Err(RouteError::MissingState);
248 };
249
250 let (session_id, _post_auth_action) = sessions_cookie
251 .find_session(provider_id, &state)
252 .map_err(|_| RouteError::MissingCookie)?;
253
254 let session = repo
255 .upstream_oauth_session()
256 .lookup(session_id)
257 .await?
258 .ok_or(RouteError::SessionNotFound)?;
259
260 if provider.id != session.provider_id {
261 return Err(RouteError::ProviderMismatch);
263 }
264
265 if state != session.state_str {
266 return Err(RouteError::StateMismatch);
268 }
269
270 if !session.is_pending() {
271 return Err(RouteError::AlreadyCompleted);
273 }
274
275 let Some(code) = params.code else {
277 return Err(RouteError::MissingCode);
278 };
279
280 CALLBACK_COUNTER.add(
281 1,
282 &[
283 KeyValue::new(PROVIDER, provider_id.to_string()),
284 KeyValue::new(RESULT, "success"),
285 ],
286 );
287
288 let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
289
290 let client_credentials = client_credentials_for_provider(
292 &provider,
293 lazy_metadata.token_endpoint().await?,
294 &keystore,
295 &encrypter,
296 )?;
297
298 let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
299
300 let token_response = mas_oidc_client::requests::token::request_access_token(
301 &client,
302 client_credentials,
303 lazy_metadata.token_endpoint().await?,
304 AccessTokenRequest::AuthorizationCode(oauth2_types::requests::AuthorizationCodeGrant {
305 code: code.clone(),
306 redirect_uri: Some(redirect_uri),
307 code_verifier: session.code_challenge_verifier.clone(),
308 }),
309 clock.now(),
310 &mut rng,
311 )
312 .await?;
313
314 let mut jwks = None;
315
316 let mut context = AttributeMappingContext::new();
317 if let Some(id_token) = token_response.id_token.as_ref() {
318 jwks = Some(
319 mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
320 .await?,
321 );
322
323 let id_token_verification_data = JwtVerificationData {
324 issuer: provider.issuer.as_deref(),
325 jwks: jwks.as_ref().unwrap(),
326 signing_algorithm: &provider.id_token_signed_response_alg,
327 client_id: &provider.client_id,
328 };
329
330 let id_token = mas_oidc_client::requests::jose::verify_id_token(
332 id_token,
333 id_token_verification_data,
334 None,
335 clock.now(),
336 )?;
337
338 let (_headers, mut claims) = id_token.into_parts();
339
340 mas_jose::claims::AT_HASH
342 .extract_optional_with_options(
343 &mut claims,
344 TokenHash::new(
345 id_token_verification_data.signing_algorithm,
346 &token_response.access_token,
347 ),
348 )
349 .map_err(mas_oidc_client::error::IdTokenError::from)?;
350
351 mas_jose::claims::C_HASH
353 .extract_optional_with_options(
354 &mut claims,
355 TokenHash::new(id_token_verification_data.signing_algorithm, &code),
356 )
357 .map_err(mas_oidc_client::error::IdTokenError::from)?;
358
359 mas_jose::claims::NONCE
361 .extract_required_with_options(&mut claims, session.nonce.as_str())
362 .map_err(mas_oidc_client::error::IdTokenError::from)?;
363
364 context = context.with_id_token_claims(claims);
365 }
366
367 if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
368 context = context.with_extra_callback_parameters(extra_callback_parameters);
369 }
370
371 let userinfo = if provider.fetch_userinfo {
372 Some(json!(match &provider.userinfo_signed_response_alg {
373 Some(signing_algorithm) => {
374 let jwks = match jwks {
375 Some(jwks) => jwks,
376 None => {
377 mas_oidc_client::requests::jose::fetch_jwks(
378 &client,
379 lazy_metadata.jwks_uri().await?,
380 )
381 .await?
382 }
383 };
384
385 mas_oidc_client::requests::userinfo::fetch_userinfo(
386 &client,
387 lazy_metadata.userinfo_endpoint().await?,
388 token_response.access_token.as_str(),
389 Some(JwtVerificationData {
390 issuer: provider.issuer.as_deref(),
391 jwks: &jwks,
392 signing_algorithm,
393 client_id: &provider.client_id,
394 }),
395 )
396 .await?
397 }
398 None => {
399 mas_oidc_client::requests::userinfo::fetch_userinfo(
400 &client,
401 lazy_metadata.userinfo_endpoint().await?,
402 token_response.access_token.as_str(),
403 None,
404 )
405 .await?
406 }
407 }))
408 } else {
409 None
410 };
411
412 if let Some(userinfo) = userinfo.clone() {
413 context = context.with_userinfo_claims(userinfo);
414 }
415
416 let context = context.build();
417
418 let env = environment();
419
420 let template = provider
421 .claims_imports
422 .subject
423 .template
424 .as_deref()
425 .unwrap_or("{{ user.sub }}");
426 let subject = env
427 .render_str(template, context.clone())
428 .map_err(RouteError::ExtractSubject)?;
429
430 if subject.is_empty() {
431 return Err(RouteError::EmptySubject);
432 }
433
434 let maybe_link = repo
436 .upstream_oauth_link()
437 .find_by_subject(&provider, &subject)
438 .await?;
439
440 let link = if let Some(link) = maybe_link {
441 link
442 } else {
443 let human_account_name = provider
446 .claims_imports
447 .account_name
448 .template
449 .as_deref()
450 .and_then(|template| match env.render_str(template, context) {
451 Ok(name) => Some(name),
452 Err(e) => {
453 tracing::warn!(
454 error = &e as &dyn std::error::Error,
455 "Failed to render account name"
456 );
457 None
458 }
459 });
460
461 repo.upstream_oauth_link()
462 .add(&mut rng, &clock, &provider, subject, human_account_name)
463 .await?
464 };
465
466 let session = repo
467 .upstream_oauth_session()
468 .complete_with_link(
469 &clock,
470 session,
471 &link,
472 token_response.id_token,
473 params.extra_callback_parameters,
474 userinfo,
475 )
476 .await?;
477
478 let cookie_jar = sessions_cookie
479 .add_link_to_session(session.id, link.id)?
480 .save(cookie_jar, &clock);
481
482 repo.save().await?;
483
484 Ok((
485 cookie_jar,
486 url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
487 )
488 .into_response())
489}