mas_handlers/admin/v1/upstream_oauth_links/add.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use aide::{NoApi, OperationIo, transform::TransformOperation};
7use axum::{Json, response::IntoResponse};
8use hyper::StatusCode;
9use mas_axum_utils::record_error;
10use mas_storage::BoxRng;
11use schemars::JsonSchema;
12use serde::Deserialize;
13use ulid::Ulid;
14
15use crate::{
16    admin::{
17        call_context::CallContext,
18        model::{Resource, UpstreamOAuthLink},
19        response::{ErrorResponse, SingleResponse},
20    },
21    impl_from_error_for_route,
22};
23
24#[derive(Debug, thiserror::Error, OperationIo)]
25#[aide(output_with = "Json<ErrorResponse>")]
26pub enum RouteError {
27    #[error(transparent)]
28    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
29
30    #[error("Upstream Oauth 2.0 Provider ID {0} with subject {1} is already linked to a user")]
31    LinkAlreadyExists(Ulid, String),
32
33    #[error("User ID {0} not found")]
34    UserNotFound(Ulid),
35
36    #[error("Upstream OAuth 2.0 Provider ID {0} not found")]
37    ProviderNotFound(Ulid),
38}
39
40impl_from_error_for_route!(mas_storage::RepositoryError);
41
42impl IntoResponse for RouteError {
43    fn into_response(self) -> axum::response::Response {
44        let error = ErrorResponse::from_error(&self);
45        let sentry_event_id = record_error!(self, Self::Internal(_));
46        let status = match self {
47            Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
48            Self::LinkAlreadyExists(_, _) => StatusCode::CONFLICT,
49            Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
50        };
51        (status, sentry_event_id, Json(error)).into_response()
52    }
53}
54
55/// # JSON payload for the `POST /api/admin/v1/upstream-oauth-links`
56#[derive(Deserialize, JsonSchema)]
57#[serde(rename = "AddUpstreamOauthLinkRequest")]
58pub struct Request {
59    /// The ID of the user to which the link should be added.
60    #[schemars(with = "crate::admin::schema::Ulid")]
61    user_id: Ulid,
62
63    /// The ID of the upstream provider to which the link is for.
64    #[schemars(with = "crate::admin::schema::Ulid")]
65    provider_id: Ulid,
66
67    /// The subject (sub) claim of the user on the provider.
68    subject: String,
69
70    /// A human readable account name.
71    human_account_name: Option<String>,
72}
73
74pub fn doc(operation: TransformOperation) -> TransformOperation {
75    operation
76        .id("addUpstreamOAuthLink")
77        .summary("Add an upstream OAuth 2.0 link")
78        .tag("upstream-oauth-link")
79        .response_with::<200, Json<SingleResponse<UpstreamOAuthLink>>, _>(|t| {
80            let [sample, ..] = UpstreamOAuthLink::samples();
81            let response = SingleResponse::new_canonical(sample);
82            t.description("An existing Upstream OAuth 2.0 link was associated to a user")
83                .example(response)
84        })
85        .response_with::<201, Json<SingleResponse<UpstreamOAuthLink>>, _>(|t| {
86            let [sample, ..] = UpstreamOAuthLink::samples();
87            let response = SingleResponse::new_canonical(sample);
88            t.description("A new Upstream OAuth 2.0 link was created")
89                .example(response)
90        })
91        .response_with::<409, RouteError, _>(|t| {
92            let [provider_sample, ..] = UpstreamOAuthLink::samples();
93            let response = ErrorResponse::from_error(&RouteError::LinkAlreadyExists(
94                provider_sample.id(),
95                String::from("subject1"),
96            ));
97            t.description("The subject from the provider is already linked to another user")
98                .example(response)
99        })
100        .response_with::<404, RouteError, _>(|t| {
101            let response = ErrorResponse::from_error(&RouteError::UserNotFound(Ulid::nil()));
102            t.description("User or provider was not found")
103                .example(response)
104        })
105}
106
107#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all)]
108pub async fn handler(
109    CallContext {
110        mut repo, clock, ..
111    }: CallContext,
112    NoApi(mut rng): NoApi<BoxRng>,
113    Json(params): Json<Request>,
114) -> Result<(StatusCode, Json<SingleResponse<UpstreamOAuthLink>>), RouteError> {
115    // Find the user
116    let user = repo
117        .user()
118        .lookup(params.user_id)
119        .await?
120        .ok_or(RouteError::UserNotFound(params.user_id))?;
121
122    // Find the provider
123    let provider = repo
124        .upstream_oauth_provider()
125        .lookup(params.provider_id)
126        .await?
127        .ok_or(RouteError::ProviderNotFound(params.provider_id))?;
128
129    let maybe_link = repo
130        .upstream_oauth_link()
131        .find_by_subject(&provider, &params.subject)
132        .await?;
133    if let Some(mut link) = maybe_link {
134        if link.user_id.is_some() {
135            return Err(RouteError::LinkAlreadyExists(
136                link.provider_id,
137                link.subject,
138            ));
139        }
140
141        repo.upstream_oauth_link()
142            .associate_to_user(&link, &user)
143            .await?;
144        link.user_id = Some(user.id);
145
146        repo.save().await?;
147
148        return Ok((
149            StatusCode::OK,
150            Json(SingleResponse::new_canonical(link.into())),
151        ));
152    }
153
154    let mut link = repo
155        .upstream_oauth_link()
156        .add(
157            &mut rng,
158            &clock,
159            &provider,
160            params.subject,
161            params.human_account_name,
162        )
163        .await?;
164
165    repo.upstream_oauth_link()
166        .associate_to_user(&link, &user)
167        .await?;
168    link.user_id = Some(user.id);
169
170    repo.save().await?;
171
172    Ok((
173        StatusCode::CREATED,
174        Json(SingleResponse::new_canonical(link.into())),
175    ))
176}
177
// Integration tests for `POST /api/admin/v1/upstream-oauth-links`.
//
// The inline snapshots pin exact ULIDs and timestamps. They rely on the
// `TestState` RNG and clock being deterministic (presumably seeded) and on the
// order in which fixtures are created: adding, removing or reordering a
// fixture changes every ID generated after it.
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use insta::assert_json_snapshot;
    use sqlx::PgPool;
    use ulid::Ulid;

    use super::super::test_utils;
    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    // No link exists yet for (provider1, "subject1"): the handler creates one,
    // attaches it to alice and answers 201 Created.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();

        // Persist the fixtures (presumably committing the transaction) so the
        // request below can see them.
        repo.save().await.unwrap();

        let request = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": alice.id,
                "provider_id": provider.id,
                "subject": "subject1"
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let body: serde_json::Value = response.json();
        assert_json_snapshot!(body, @r###"
        {
          "data": {
            "type": "upstream-oauth-link",
            "id": "01FSHN9AG07HNEZXNQM2KNBNF6",
            "attributes": {
              "created_at": "2022-01-16T14:40:00Z",
              "provider_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
              "subject": "subject1",
              "user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
              "human_account_name": null
            },
            "links": {
              "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG07HNEZXNQM2KNBNF6"
            }
          },
          "links": {
            "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG07HNEZXNQM2KNBNF6"
          }
        }
        "###);
    }

    // A link for the subject already exists but has no user: the handler
    // reuses that link (the response carries its ID) and answers 200 OK.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_association(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();

        // Existing unfinished link
        repo.upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                String::from("subject1"),
                None,
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        let request = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": alice.id,
                "provider_id": provider.id,
                "subject": "subject1"
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::OK);
        let body: serde_json::Value = response.json();
        assert_json_snapshot!(body, @r###"
        {
          "data": {
            "type": "upstream-oauth-link",
            "id": "01FSHN9AG09NMZYX8MFYH578R9",
            "attributes": {
              "created_at": "2022-01-16T14:40:00Z",
              "provider_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
              "subject": "subject1",
              "user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
              "human_account_name": null
            },
            "links": {
              "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG09NMZYX8MFYH578R9"
            }
          },
          "links": {
            "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG09NMZYX8MFYH578R9"
          }
        }
        "###);
    }

    // The subject is already linked to alice, so linking it to bob is rejected
    // with 409 Conflict.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_link_already_exists(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();

        let bob = repo
            .user()
            .add(&mut rng, &state.clock, "bob".to_owned())
            .await
            .unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();

        let link = repo
            .upstream_oauth_link()
            .add(
                &mut rng,
                &state.clock,
                &provider,
                String::from("subject1"),
                None,
            )
            .await
            .unwrap();

        // Give the link an owner so the request below hits the conflict path.
        repo.upstream_oauth_link()
            .associate_to_user(&link, &alice)
            .await
            .unwrap();

        repo.save().await.unwrap();

        let request = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": bob.id,
                "provider_id": provider.id,
                "subject": "subject1"
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CONFLICT);
        let body: serde_json::Value = response.json();
        assert_json_snapshot!(body, @r###"
        {
          "errors": [
            {
              "title": "Upstream Oauth 2.0 Provider ID 01FSHN9AG09NMZYX8MFYH578R9 with subject subject1 is already linked to a user"
            }
          ]
        }
        "###);
    }

    // The nil ULID matches no user, so the request fails with 404 before the
    // provider or link are looked at.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_user_not_found(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &state.clock,
                test_utils::oidc_provider_params("provider1"),
            )
            .await
            .unwrap();

        repo.save().await.unwrap();

        let request = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": Ulid::nil(),
                "provider_id": provider.id,
                "subject": "subject1"
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::NOT_FOUND);
        let body: serde_json::Value = response.json();
        assert_json_snapshot!(body, @r###"
        {
          "errors": [
            {
              "title": "User ID 00000000000000000000000000 not found"
            }
          ]
        }
        "###);
    }

    // The user exists but the nil ULID matches no provider, so the request
    // fails with 404.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_provider_not_found(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;
        let mut rng = state.rng();
        let mut repo = state.repository().await.unwrap();

        let alice = repo
            .user()
            .add(&mut rng, &state.clock, "alice".to_owned())
            .await
            .unwrap();

        repo.save().await.unwrap();

        let request = Request::post("/api/admin/v1/upstream-oauth-links")
            .bearer(&token)
            .json(serde_json::json!({
                "user_id": alice.id,
                "provider_id": Ulid::nil(),
                "subject": "subject1"
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::NOT_FOUND);
        let body: serde_json::Value = response.json();
        assert_json_snapshot!(body, @r###"
        {
          "errors": [
            {
              "title": "Upstream OAuth 2.0 Provider ID 00000000000000000000000000 not found"
            }
          ]
        }
        "###);
    }
}