mas_handlers/oauth2/device/
authorize.rs1use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::typed_header::TypedHeader;
9use chrono::Duration;
10use headers::{CacheControl, Pragma};
11use hyper::StatusCode;
12use mas_axum_utils::{
13 client_authorization::{ClientAuthorization, CredentialsVerificationError},
14 record_error,
15};
16use mas_keystore::Encrypter;
17use mas_router::UrlBuilder;
18use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2DeviceCodeGrantParams};
19use oauth2_types::{
20 errors::{ClientError, ClientErrorCode},
21 requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse, GrantType},
22 scope::ScopeToken,
23};
24use rand::distributions::{Alphanumeric, DistString};
25use thiserror::Error;
26use ulid::Ulid;
27
28use crate::{BoundActivityTracker, impl_from_error_for_route};
29
30#[derive(Debug, Error)]
31pub(crate) enum RouteError {
32 #[error(transparent)]
33 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
34
35 #[error("client not found")]
36 ClientNotFound,
37
38 #[error("client {0} is not allowed to use the device code grant")]
39 ClientNotAllowed(Ulid),
40
41 #[error("invalid client credentials for client {client_id}")]
42 InvalidClientCredentials {
43 client_id: Ulid,
44 #[source]
45 source: CredentialsVerificationError,
46 },
47
48 #[error("could not verify client credentials for client {client_id}")]
49 ClientCredentialsVerification {
50 client_id: Ulid,
51 #[source]
52 source: CredentialsVerificationError,
53 },
54}
55
56impl_from_error_for_route!(mas_storage::RepositoryError);
57
58impl IntoResponse for RouteError {
59 fn into_response(self) -> axum::response::Response {
60 let sentry_event_id = record_error!(self, Self::Internal(_));
61
62 let response = match self {
63 Self::Internal(_) | Self::ClientCredentialsVerification { .. } => (
64 StatusCode::INTERNAL_SERVER_ERROR,
65 Json(ClientError::from(ClientErrorCode::ServerError)),
66 ),
67 Self::ClientNotFound | Self::InvalidClientCredentials { .. } => (
68 StatusCode::UNAUTHORIZED,
69 Json(ClientError::from(ClientErrorCode::InvalidClient)),
70 ),
71 Self::ClientNotAllowed(_) => (
72 StatusCode::UNAUTHORIZED,
73 Json(ClientError::from(ClientErrorCode::UnauthorizedClient)),
74 ),
75 };
76
77 (sentry_event_id, response).into_response()
78 }
79}
80
81#[tracing::instrument(
82 name = "handlers.oauth2.device.request.post",
83 fields(client.id = client_authorization.client_id()),
84 skip_all,
85)]
86pub(crate) async fn post(
87 mut rng: BoxRng,
88 clock: BoxClock,
89 mut repo: BoxRepository,
90 user_agent: Option<TypedHeader<headers::UserAgent>>,
91 activity_tracker: BoundActivityTracker,
92 State(url_builder): State<UrlBuilder>,
93 State(http_client): State<reqwest::Client>,
94 State(encrypter): State<Encrypter>,
95 client_authorization: ClientAuthorization<DeviceAuthorizationRequest>,
96) -> Result<impl IntoResponse, RouteError> {
97 let client = client_authorization
98 .credentials
99 .fetch(&mut repo)
100 .await?
101 .ok_or(RouteError::ClientNotFound)?;
102
103 let method = client
105 .token_endpoint_auth_method
106 .as_ref()
107 .ok_or(RouteError::ClientNotAllowed(client.id))?;
108
109 client_authorization
110 .credentials
111 .verify(&http_client, &encrypter, method, &client)
112 .await
113 .map_err(|err| {
114 if err.is_internal() {
115 RouteError::ClientCredentialsVerification {
116 client_id: client.id,
117 source: err,
118 }
119 } else {
120 RouteError::InvalidClientCredentials {
121 client_id: client.id,
122 source: err,
123 }
124 }
125 })?;
126
127 if !client.grant_types.contains(&GrantType::DeviceCode) {
128 return Err(RouteError::ClientNotAllowed(client.id));
129 }
130
131 let scope = client_authorization
132 .form
133 .and_then(|f| f.scope)
134 .unwrap_or(std::iter::empty::<ScopeToken>().collect());
136
137 let expires_in = Duration::microseconds(20 * 60 * 1000 * 1000);
138
139 let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
140 let ip_address = activity_tracker.ip();
141
142 let device_code = Alphanumeric.sample_string(&mut rng, 32);
143 let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase();
144
145 let device_code = repo
146 .oauth2_device_code_grant()
147 .add(
148 &mut rng,
149 &clock,
150 OAuth2DeviceCodeGrantParams {
151 client: &client,
152 scope,
153 device_code,
154 user_code,
155 expires_in,
156 user_agent,
157 ip_address,
158 },
159 )
160 .await?;
161
162 repo.save().await?;
163
164 let response = DeviceAuthorizationResponse {
165 device_code: device_code.device_code,
166 user_code: device_code.user_code.clone(),
167 verification_uri: url_builder.device_code_link(),
168 verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)),
169 expires_in,
170 interval: Some(Duration::microseconds(5 * 1000 * 1000)),
171 };
172
173 Ok((
174 StatusCode::OK,
175 TypedHeader(CacheControl::new().with_no_store()),
176 TypedHeader(Pragma::no_cache()),
177 Json(response),
178 ))
179}
180
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse,
    };
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// Registers a client that is allowed to use the device code grant, then
    /// sends a device authorization request for it and checks the lengths of
    /// the returned device code and user code.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_device_code_request(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Register a client with the device code grant type.
        let registration_body = serde_json::json!({
            "client_uri": "https://example.com/",
            "token_endpoint_auth_method": "none",
            "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"],
            "response_types": [],
        });
        let registration_request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(registration_body);

        let registration_response = state.request(registration_request).await;
        registration_response.assert_status(StatusCode::CREATED);

        let registration: ClientRegistrationResponse = registration_response.json();
        let client_id = registration.client_id;

        // Ask for a device code on behalf of the freshly registered client.
        let device_form = serde_json::json!({
            "client_id": client_id,
            "scope": "openid",
        });
        let device_request =
            Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form(device_form);

        let device_response = state.request(device_request).await;
        device_response.assert_status(StatusCode::OK);

        let authorization: DeviceAuthorizationResponse = device_response.json();
        assert_eq!(authorization.device_code.len(), 32);
        assert_eq!(authorization.user_code.len(), 6);
    }
}