matrix_sdk_base/room/
state.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bitflags::bitflags;
16use ruma::events::room::member::MembershipState;
17use serde::{Deserialize, Serialize};
18
19use super::Room;
20
21impl Room {
22    /// Get the state of the room.
23    pub fn state(&self) -> RoomState {
24        self.inner.read().room_state
25    }
26}
27
28/// Enum keeping track in which state the room is, e.g. if our own user is
29/// joined, RoomState::Invited, or has left the room.
30#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
31pub enum RoomState {
32    /// The room is in a joined state.
33    Joined,
34    /// The room is in a left state.
35    Left,
36    /// The room is in an invited state.
37    Invited,
38    /// The room is in a knocked state.
39    Knocked,
40    /// The room is in a banned state.
41    Banned,
42}
43
44impl From<&MembershipState> for RoomState {
45    fn from(membership_state: &MembershipState) -> Self {
46        match membership_state {
47            MembershipState::Ban => Self::Banned,
48            MembershipState::Invite => Self::Invited,
49            MembershipState::Join => Self::Joined,
50            MembershipState::Knock => Self::Knocked,
51            MembershipState::Leave => Self::Left,
52            _ => panic!("Unexpected MembershipState: {membership_state}"),
53        }
54    }
55}
56
57bitflags! {
58    /// Room state filter as a bitset.
59    ///
60    /// Note that [`RoomStateFilter::empty()`] doesn't filter the results and
61    /// is equivalent to [`RoomStateFilter::all()`].
62    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
63    pub struct RoomStateFilter: u16 {
64        /// The room is in a joined state.
65        const JOINED   = 0b00000001;
66        /// The room is in an invited state.
67        const INVITED  = 0b00000010;
68        /// The room is in a left state.
69        const LEFT     = 0b00000100;
70        /// The room is in a knocked state.
71        const KNOCKED  = 0b00001000;
72        /// The room is in a banned state.
73        const BANNED   = 0b00010000;
74    }
75}
76
77impl RoomStateFilter {
78    /// Whether the given room state matches this `RoomStateFilter`.
79    pub fn matches(&self, state: RoomState) -> bool {
80        if self.is_empty() {
81            return true;
82        }
83
84        let bit_state = match state {
85            RoomState::Joined => Self::JOINED,
86            RoomState::Left => Self::LEFT,
87            RoomState::Invited => Self::INVITED,
88            RoomState::Knocked => Self::KNOCKED,
89            RoomState::Banned => Self::BANNED,
90        };
91
92        self.contains(bit_state)
93    }
94
95    /// Get this `RoomStateFilter` as a list of matching [`RoomState`]s.
96    pub fn as_vec(&self) -> Vec<RoomState> {
97        let mut states = Vec::new();
98
99        if self.contains(Self::JOINED) {
100            states.push(RoomState::Joined);
101        }
102        if self.contains(Self::LEFT) {
103            states.push(RoomState::Left);
104        }
105        if self.contains(Self::INVITED) {
106            states.push(RoomState::Invited);
107        }
108        if self.contains(Self::KNOCKED) {
109            states.push(RoomState::Knocked);
110        }
111        if self.contains(Self::BANNED) {
112            states.push(RoomState::Banned);
113        }
114
115        states
116    }
117}
118
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use ruma::owned_room_id;

    use super::{RoomState, RoomStateFilter};
    use crate::test_utils::logged_in_base_client;

    /// Every `RoomState` variant, in the order `RoomStateFilter::as_vec`
    /// reports them.
    const ALL_STATES: [RoomState; 5] = [
        RoomState::Joined,
        RoomState::Left,
        RoomState::Invited,
        RoomState::Knocked,
        RoomState::Banned,
    ];

    #[async_test]
    async fn test_room_state_filters() {
        let client = logged_in_base_client(None).await;

        // One room per state, paired with the filter that should select it.
        let cases = [
            (owned_room_id!("!joined:example.org"), RoomState::Joined, RoomStateFilter::JOINED),
            (owned_room_id!("!invited:example.org"), RoomState::Invited, RoomStateFilter::INVITED),
            (owned_room_id!("!left:example.org"), RoomState::Left, RoomStateFilter::LEFT),
            (owned_room_id!("!knocked:example.org"), RoomState::Knocked, RoomStateFilter::KNOCKED),
            (owned_room_id!("!banned:example.org"), RoomState::Banned, RoomStateFilter::BANNED),
        ];

        // Create every room before querying, so each filter has to pick its
        // room out of all of them.
        for (room_id, state, _) in &cases {
            client.get_or_create_room(room_id, *state);
        }

        for (room_id, state, filter) in &cases {
            let rooms = client.rooms_filtered(*filter);
            assert_eq!(rooms.len(), 1, "filter {filter:?} should select exactly one room");
            assert_eq!(rooms[0].state(), *state);
            assert_eq!(rooms[0].room_id, *room_id);
        }
    }

    #[test]
    fn test_room_state_filter_matches() {
        // Both the empty filter and the full filter accept every state.
        for state in ALL_STATES {
            assert!(RoomStateFilter::empty().matches(state), "empty() rejected {state:?}");
            assert!(RoomStateFilter::all().matches(state), "all() rejected {state:?}");
        }

        // A combined filter accepts exactly the states it names.
        let filter = RoomStateFilter::JOINED | RoomStateFilter::INVITED;
        for state in ALL_STATES {
            let expected = matches!(state, RoomState::Joined | RoomState::Invited);
            assert_eq!(filter.matches(state), expected, "unexpected result for {state:?}");
        }
    }

    #[test]
    fn test_room_state_filters_as_vec() {
        assert_eq!(RoomStateFilter::JOINED.as_vec(), vec![RoomState::Joined]);
        assert_eq!(RoomStateFilter::LEFT.as_vec(), vec![RoomState::Left]);
        assert_eq!(RoomStateFilter::INVITED.as_vec(), vec![RoomState::Invited]);
        assert_eq!(RoomStateFilter::KNOCKED.as_vec(), vec![RoomState::Knocked]);
        assert_eq!(RoomStateFilter::BANNED.as_vec(), vec![RoomState::Banned]);

        // Check all filters are taken into account.
        assert_eq!(RoomStateFilter::all().as_vec(), ALL_STATES.to_vec());

        // Combined filters are reported in the canonical order, regardless of
        // the order the flags were combined in.
        assert_eq!(
            (RoomStateFilter::BANNED | RoomStateFilter::JOINED).as_vec(),
            vec![RoomState::Joined, RoomState::Banned]
        );

        // Unlike `matches`, `as_vec` does not expand an empty filter.
        assert!(RoomStateFilter::empty().as_vec().is_empty());
    }
}