mas_handlers/compat/
refresh.rs1use axum::{Json, extract::State, response::IntoResponse};
8use chrono::Duration;
9use hyper::StatusCode;
10use mas_axum_utils::record_error;
11use mas_data_model::{SiteConfig, TokenFormatError, TokenType};
12use mas_storage::{
13 BoxClock, BoxRepository, BoxRng, Clock,
14 compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
15};
16use serde::{Deserialize, Serialize};
17use serde_with::{DurationMilliSeconds, serde_as};
18use thiserror::Error;
19use ulid::Ulid;
20
21use super::MatrixError;
22use crate::{BoundActivityTracker, impl_from_error_for_route};
23
/// JSON body accepted by the compat refresh handler ([`post`]).
#[derive(Debug, Deserialize)]
pub struct RequestBody {
    /// The refresh token presented by the client; the handler checks its
    /// format and then looks it up in the compat refresh token repository.
    refresh_token: String,
}
28
29#[derive(Debug, Error)]
30pub enum RouteError {
31 #[error(transparent)]
32 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
33
34 #[error("invalid token")]
35 InvalidToken(#[from] TokenFormatError),
36
37 #[error("unknown token")]
38 UnknownToken,
39
40 #[error("invalid token type {0}, expected a compat refresh token")]
41 InvalidTokenType(TokenType),
42
43 #[error("refresh token already consumed {0}")]
44 RefreshTokenConsumed(Ulid),
45
46 #[error("invalid compat session {0}")]
47 InvalidSession(Ulid),
48
49 #[error("unknown comapt session {0}")]
50 UnknownSession(Ulid),
51}
52
53impl IntoResponse for RouteError {
54 fn into_response(self) -> axum::response::Response {
55 let sentry_event_id = record_error!(self, Self::Internal(_) | Self::UnknownSession(_));
56 let response = match self {
57 Self::Internal(_) | Self::UnknownSession(_) => MatrixError {
58 errcode: "M_UNKNOWN",
59 error: "Internal error",
60 status: StatusCode::INTERNAL_SERVER_ERROR,
61 },
62 Self::InvalidToken(_)
63 | Self::UnknownToken
64 | Self::InvalidTokenType(_)
65 | Self::InvalidSession(_)
66 | Self::RefreshTokenConsumed(_) => MatrixError {
67 errcode: "M_UNKNOWN_TOKEN",
68 error: "Invalid refresh token",
69 status: StatusCode::UNAUTHORIZED,
70 },
71 };
72
73 (sentry_event_id, response).into_response()
74 }
75}
76
// NOTE(review): `impl_from_error_for_route!` is defined elsewhere in the crate.
// Presumably it provides the `From<mas_storage::RepositoryError>` conversion
// into `RouteError` that the `?` operator relies on for repository calls in
// this file — confirm against the macro definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
78
/// JSON body returned by the compat refresh handler ([`post`]) on success.
#[serde_as]
#[derive(Debug, Serialize)]
pub struct ResponseBody {
    /// The newly issued compat access token.
    access_token: String,
    /// The newly issued compat refresh token, replacing the one presented.
    refresh_token: String,
    /// Lifetime of the new access token, serialized through `serde_with`'s
    /// `DurationMilliSeconds<i64>` adapter.
    #[serde_as(as = "DurationMilliSeconds<i64>")]
    expires_in_ms: Duration,
}
87
88#[tracing::instrument(name = "handlers.compat.refresh.post", skip_all)]
89pub(crate) async fn post(
90 mut rng: BoxRng,
91 clock: BoxClock,
92 mut repo: BoxRepository,
93 activity_tracker: BoundActivityTracker,
94 State(site_config): State<SiteConfig>,
95 Json(input): Json<RequestBody>,
96) -> Result<impl IntoResponse, RouteError> {
97 let token_type = TokenType::check(&input.refresh_token)?;
98
99 if token_type != TokenType::CompatRefreshToken {
100 return Err(RouteError::InvalidTokenType(token_type));
101 }
102
103 let refresh_token = repo
104 .compat_refresh_token()
105 .find_by_token(&input.refresh_token)
106 .await?
107 .ok_or(RouteError::UnknownToken)?;
108
109 if !refresh_token.is_valid() {
110 return Err(RouteError::RefreshTokenConsumed(refresh_token.id));
111 }
112
113 let session = repo
114 .compat_session()
115 .lookup(refresh_token.session_id)
116 .await?
117 .ok_or(RouteError::UnknownSession(refresh_token.session_id))?;
118
119 if !session.is_valid() {
120 return Err(RouteError::InvalidSession(refresh_token.session_id));
121 }
122
123 activity_tracker
124 .record_compat_session(&clock, &session)
125 .await;
126
127 let access_token = repo
128 .compat_access_token()
129 .lookup(refresh_token.access_token_id)
130 .await?
131 .filter(|t| t.is_valid(clock.now()));
132
133 let new_refresh_token_str = TokenType::CompatRefreshToken.generate(&mut rng);
134 let new_access_token_str = TokenType::CompatAccessToken.generate(&mut rng);
135
136 let expires_in = site_config.compat_token_ttl;
137 let new_access_token = repo
138 .compat_access_token()
139 .add(
140 &mut rng,
141 &clock,
142 &session,
143 new_access_token_str,
144 Some(expires_in),
145 )
146 .await?;
147 let new_refresh_token = repo
148 .compat_refresh_token()
149 .add(
150 &mut rng,
151 &clock,
152 &session,
153 &new_access_token,
154 new_refresh_token_str,
155 )
156 .await?;
157
158 repo.compat_refresh_token()
159 .consume(&clock, refresh_token)
160 .await?;
161
162 if let Some(access_token) = access_token {
163 repo.compat_access_token()
164 .expire(&clock, access_token)
165 .await?;
166 }
167
168 repo.save().await?;
169
170 Ok(Json(ResponseBody {
171 access_token: new_access_token.token,
172 refresh_token: new_refresh_token.token,
173 expires_in_ms: expires_in,
174 }))
175}