// mas_storage_pg/personal/mod.rs

// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! A module containing the PostgreSQL implementations of the
//! Personal Access Token / Personal Session repositories.

9mod access_token;
10mod session;
11
12pub use access_token::PgPersonalAccessTokenRepository;
13pub use session::PgPersonalSessionRepository;
14
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        Clock, Device, clock::MockClock, personal::session::PersonalSessionOwner,
    };
    use mas_storage::{
        Pagination, RepositoryAccess,
        personal::{
            PersonalAccessTokenRepository, PersonalSessionFilter, PersonalSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, PROFILE, Scope};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Walks a single personal session through its lifecycle: creation,
    /// counting and listing under the all/active/finished filters, lookup by
    /// ID, and revocation.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create two users: `john` owns the session, `marvin` is the actor
        // the session acts as
        let admin_user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let bot_user = repo
            .user()
            .add(&mut rng, &clock, "marvin".to_owned())
            .await
            .unwrap();

        // Every filter below is scoped to sessions acting as the bot user
        let all = PersonalSessionFilter::new().for_actor_user(&bot_user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        assert_eq!(repo.personal_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.personal_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.personal_session().count(finished).await.unwrap(), 0);

        // We start off with no sessions
        let full_list = repo.personal_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo
            .personal_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .personal_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a personal session, owned by the admin and acting as the bot
        let device = Device::generate(&mut rng);
        let scope: Scope = [OPENID, PROFILE]
            .into_iter()
            .chain(device.to_scope_token().unwrap())
            .collect();
        let session = repo
            .personal_session()
            .add(
                &mut rng,
                &clock,
                (&admin_user).into(),
                &bot_user,
                "Test Personal Session".to_owned(),
                scope.clone(),
            )
            .await
            .unwrap();
        assert_eq!(session.owner, PersonalSessionOwner::User(admin_user.id));
        assert_eq!(session.actor_user_id, bot_user.id);
        assert!(session.is_valid());
        assert!(!session.is_revoked());
        assert_eq!(session.scope, scope);

        // The new session is counted as active, not finished
        assert_eq!(repo.personal_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.personal_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.personal_session().count(finished).await.unwrap(), 0);

        let full_list = repo.personal_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].node.0.id, session.id);
        assert!(full_list.edges[0].node.0.is_valid());
        let active_list = repo
            .personal_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(active_list.edges[0].node.0.id, session.id);
        assert!(active_list.edges[0].node.0.is_valid());
        let finished_list = repo
            .personal_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert!(finished_list.edges.is_empty());

        // Lookup the session and check it didn't change
        let session_lookup = repo
            .personal_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("personal session not found");
        assert_eq!(session_lookup.id, session.id);
        assert_eq!(
            session_lookup.owner,
            PersonalSessionOwner::User(admin_user.id)
        );
        assert_eq!(session_lookup.actor_user_id, bot_user.id);
        assert_eq!(session_lookup.scope, scope);
        assert!(session_lookup.is_valid());
        assert!(!session_lookup.is_revoked());

        // Revoke the session
        let session = repo
            .personal_session()
            .revoke(&clock, session)
            .await
            .unwrap();
        assert!(!session.is_valid());
        assert!(session.is_revoked());

        // The revoked session moves from active to finished
        assert_eq!(repo.personal_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.personal_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.personal_session().count(finished).await.unwrap(), 1);

        let full_list = repo.personal_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(full_list.edges[0].node.0.id, session.id);
        let active_list = repo
            .personal_session()
            .list(active, pagination)
            .await
            .unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo
            .personal_session()
            .list(finished, pagination)
            .await
            .unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(finished_list.edges[0].node.0.id, session.id);
        assert!(finished_list.edges[0].node.0.is_revoked());

        // Reload the session and check again
        let session_lookup = repo
            .personal_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("personal session not found");
        assert!(!session_lookup.is_valid());
        assert!(session_lookup.is_revoked());
    }

    /// Checks that `revoke_bulk` revokes exactly the sessions matching its
    /// filter, distinguishing sessions by whether their access token expires.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_revoke_bulk(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let alice_user = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();
        let bob_user = repo
            .user()
            .add(&mut rng, &clock, "bob".to_owned())
            .await
            .unwrap();

        // First session: owned by alice, acting as bob, with a token that
        // expires after 42 days
        let session1 = repo
            .personal_session()
            .add(
                &mut rng,
                &clock,
                (&alice_user).into(),
                &bob_user,
                "Test Personal Session".to_owned(),
                "openid".parse().unwrap(),
            )
            .await
            .unwrap();
        repo.personal_access_token()
            .add(
                &mut rng,
                &clock,
                &session1,
                "mpt_hiss",
                Some(Duration::days(42)),
            )
            .await
            .unwrap();

        // Second session: owned by and acting as bob, with a token that
        // never expires
        let session2 = repo
            .personal_session()
            .add(
                &mut rng,
                &clock,
                (&bob_user).into(),
                &bob_user,
                "Test Personal Session".to_owned(),
                "openid".parse().unwrap(),
            )
            .await
            .unwrap();
        repo.personal_access_token()
            .add(&mut rng, &clock, &session2, "mpt_meow", None)
            .await
            .unwrap();

        // Just one session without a token expiry time...
        assert_eq!(
            repo.personal_session()
                .revoke_bulk(
                    &clock,
                    PersonalSessionFilter::new()
                        .active_only()
                        .with_expires(false)
                )
                .await
                .unwrap(),
            1
        );

        // ...and it is the second session, whose token never expires; the
        // first session must be left untouched
        let session1_lookup = repo
            .personal_session()
            .lookup(session1.id)
            .await
            .unwrap()
            .expect("personal session not found");
        assert!(!session1_lookup.is_revoked());
        let session2_lookup = repo
            .personal_session()
            .lookup(session2.id)
            .await
            .unwrap()
            .expect("personal session not found");
        assert!(session2_lookup.is_revoked());

        // Just one session with a token expiry time: the first one
        assert_eq!(
            repo.personal_session()
                .revoke_bulk(
                    &clock,
                    PersonalSessionFilter::new()
                        .active_only()
                        .with_expires(true)
                )
                .await
                .unwrap(),
            1
        );
        let session1_lookup = repo
            .personal_session()
            .lookup(session1.id)
            .await
            .unwrap()
            .expect("personal session not found");
        assert!(session1_lookup.is_revoked());

        // No active sessions left
        assert_eq!(
            repo.personal_session()
                .revoke_bulk(&clock, PersonalSessionFilter::new().active_only())
                .await
                .unwrap(),
            0
        );
    }

    /// Covers personal access tokens: creation, rejection of a duplicate
    /// token value in a later transaction, lookup by ID and by value, expiry,
    /// and revocation.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_access_token_repository(pool: PgPool) {
        const FIRST_TOKEN: &str = "first_access_token";
        const SECOND_TOKEN: &str = "second_access_token";
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Create two users: `john` owns the session, `marvin` is its actor
        let admin_user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let bot_user = repo
            .user()
            .add(&mut rng, &clock, "marvin".to_owned())
            .await
            .unwrap();

        // Start a personal session acting as the bot user
        let device = Device::generate(&mut rng);
        let scope: Scope = [OPENID, PROFILE]
            .into_iter()
            .chain(device.to_scope_token().unwrap())
            .collect();
        let session = repo
            .personal_session()
            .add(
                &mut rng,
                &clock,
                (&admin_user).into(),
                &bot_user,
                "Test Personal Session".to_owned(),
                scope,
            )
            .await
            .unwrap();

        // Add an access token to that session, expiring after one minute
        let token = repo
            .personal_access_token()
            .add(
                &mut rng,
                &clock,
                &session,
                FIRST_TOKEN,
                Some(Duration::try_minutes(1).unwrap()),
            )
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);

        // Commit the transaction so the conflict check below runs in a fresh
        // transaction against the committed token
        repo.save().await.unwrap();

        {
            let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
            // Adding the same token value a second time should conflict
            assert!(
                repo.personal_access_token()
                    .add(
                        &mut rng,
                        &clock,
                        &session,
                        FIRST_TOKEN,
                        Some(Duration::try_minutes(1).unwrap()),
                    )
                    .await
                    .is_err()
            );
            // Discard this transaction after the failed insert
            repo.cancel().await.unwrap();
        }

        // Grab a new repo
        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        // Looking up via ID works
        let token_lookup = repo
            .personal_access_token()
            .lookup(token.id)
            .await
            .unwrap()
            .expect("personal access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Looking up via the token value works
        let token_lookup = repo
            .personal_access_token()
            .find_by_token(FIRST_TOKEN)
            .await
            .unwrap()
            .expect("personal access token not found");
        assert_eq!(token.id, token_lookup.id);
        assert_eq!(token_lookup.session_id, session.id);

        // Token is currently valid
        assert!(token.is_valid(clock.now()));

        // Advance the clock by the token's one-minute lifetime
        clock.advance(Duration::try_minutes(1).unwrap());
        // Token should have expired
        assert!(!token.is_valid(clock.now()));

        // Add a second access token, this time without expiration
        let token = repo
            .personal_access_token()
            .add(&mut rng, &clock, &session, SECOND_TOKEN, None)
            .await
            .unwrap();
        assert_eq!(token.session_id, session.id);

        // Token is currently valid
        assert!(token.is_valid(clock.now()));

        // Revoke it; the returned token should already reflect the revocation
        let revoked = repo
            .personal_access_token()
            .revoke(&clock, token)
            .await
            .unwrap();
        assert!(!revoked.is_valid(clock.now()));

        // Reload it to check the revocation was persisted
        let token = repo
            .personal_access_token()
            .find_by_token(SECOND_TOKEN)
            .await
            .unwrap()
            .expect("personal access token not found");

        // Token is not valid anymore
        assert!(!token.is_valid(clock.now()));

        repo.save().await.unwrap();
    }
}