1use std::collections::BTreeMap;
18
19use matrix_sdk_common::deserialized_responses::WithheldCode;
20use ruma::{OwnedDeviceId, OwnedRoomId};
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use vodozemac::Curve25519PublicKey;
24
25use super::{EventType, ToDeviceEvent};
26use crate::types::{deserialize_curve_key, serialize_curve_key, EventEncryptionAlgorithm};
27
28pub type RoomKeyWithheldEvent = ToDeviceEvent<RoomKeyWithheldContent>;
30
31impl Clone for RoomKeyWithheldEvent {
32 fn clone(&self) -> Self {
33 Self {
34 sender: self.sender.clone(),
35 content: self.content.clone(),
36 other: self.other.clone(),
37 }
38 }
39}
40
41#[derive(Clone, Debug, Deserialize)]
49#[serde(try_from = "WithheldHelper")]
50pub enum RoomKeyWithheldContent {
51 MegolmV1AesSha2(MegolmV1AesSha2WithheldContent),
53 #[cfg(feature = "experimental-algorithms")]
55 MegolmV2AesSha2(MegolmV2AesSha2WithheldContent),
56 Unknown(UnknownRoomKeyWithHeld),
58}
59
/// Expands to an expression building a [`RoomKeyWithheldContent`] in the
/// variant named by `$algorithm`.
///
/// All arguments after `$algorithm` are identifiers of local bindings. They
/// are used with struct field-init shorthand, so the bindings must be named
/// exactly like the fields they populate (`room_id`, `session_id`,
/// `sender_key`, `from_device`).
///
/// The expansion panics (`unreachable!`) if `$code` is neither `NoOlm` nor one
/// of the four common withheld codes.
macro_rules! construct_withheld_content {
    ($algorithm:ident, $code:ident, $room_id:ident, $session_id:ident, $sender_key:ident, $from_device:ident) => {
        match $code {
            // The no-olm content carries neither a room ID nor a session ID.
            WithheldCode::NoOlm => {
                let no_olm =
                    NoOlmWithheldContent { $sender_key, $from_device, other: Default::default() };

                RoomKeyWithheldContent::$algorithm(MegolmV1AesSha2WithheldContent::NoOlm(
                    Box::new(no_olm),
                ))
            }
            WithheldCode::Blacklisted
            | WithheldCode::Unverified
            | WithheldCode::Unauthorised
            | WithheldCode::Unavailable => {
                let common = CommonWithheldCodeContent {
                    $room_id,
                    $session_id,
                    $sender_key,
                    $from_device,
                    other: Default::default(),
                };

                RoomKeyWithheldContent::$algorithm(
                    MegolmV1AesSha2WithheldContent::from_code_and_content($code, common),
                )
            }
            _ => unreachable!("Can't create an unknown withheld code content"),
        }
    };
}
89
90impl RoomKeyWithheldContent {
91 pub fn new(
98 algorithm: EventEncryptionAlgorithm,
99 code: WithheldCode,
100 room_id: OwnedRoomId,
101 session_id: String,
102 sender_key: Curve25519PublicKey,
103 from_device: OwnedDeviceId,
104 ) -> Self {
105 let from_device = Some(from_device);
106
107 match algorithm {
108 EventEncryptionAlgorithm::MegolmV1AesSha2 => {
109 construct_withheld_content!(
110 MegolmV1AesSha2,
111 code,
112 room_id,
113 session_id,
114 sender_key,
115 from_device
116 )
117 }
118 #[cfg(feature = "experimental-algorithms")]
119 EventEncryptionAlgorithm::MegolmV2AesSha2 => {
120 construct_withheld_content!(
121 MegolmV2AesSha2,
122 code,
123 room_id,
124 session_id,
125 sender_key,
126 from_device
127 )
128 }
129 _ => unreachable!("Unsupported algorithm {algorithm}"),
130 }
131 }
132
133 pub fn withheld_code(&self) -> WithheldCode {
135 match self {
136 RoomKeyWithheldContent::MegolmV1AesSha2(c) => c.withheld_code(),
137 #[cfg(feature = "experimental-algorithms")]
138 RoomKeyWithheldContent::MegolmV2AesSha2(c) => c.withheld_code(),
139 RoomKeyWithheldContent::Unknown(c) => c.code.to_owned(),
140 }
141 }
142
143 pub fn algorithm(&self) -> EventEncryptionAlgorithm {
145 match &self {
146 RoomKeyWithheldContent::MegolmV1AesSha2(_) => EventEncryptionAlgorithm::MegolmV1AesSha2,
147 #[cfg(feature = "experimental-algorithms")]
148 RoomKeyWithheldContent::MegolmV2AesSha2(_) => EventEncryptionAlgorithm::MegolmV2AesSha2,
149 RoomKeyWithheldContent::Unknown(c) => c.algorithm.to_owned(),
150 }
151 }
152}
153
154impl EventType for RoomKeyWithheldContent {
155 const EVENT_TYPE: &'static str = "m.room_key.withheld";
156}
157
158#[derive(Debug, Deserialize, Serialize)]
159struct WithheldHelper {
160 pub algorithm: EventEncryptionAlgorithm,
161 pub reason: Option<String>,
162 pub code: WithheldCode,
163 #[serde(flatten)]
164 other: Value,
165}
166
167#[derive(Clone, Debug)]
170pub enum MegolmV1AesSha2WithheldContent {
171 BlackListed(Box<CommonWithheldCodeContent>),
173 Unverified(Box<CommonWithheldCodeContent>),
175 Unauthorised(Box<CommonWithheldCodeContent>),
177 Unavailable(Box<CommonWithheldCodeContent>),
179 NoOlm(Box<NoOlmWithheldContent>),
181}
182
183#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)]
185pub struct CommonWithheldCodeContent {
186 pub room_id: OwnedRoomId,
188
189 pub session_id: String,
191
192 #[serde(deserialize_with = "deserialize_curve_key", serialize_with = "serialize_curve_key")]
194 pub sender_key: Curve25519PublicKey,
195
196 #[serde(skip_serializing_if = "Option::is_none")]
199 pub from_device: Option<OwnedDeviceId>,
200
201 #[serde(flatten)]
202 other: BTreeMap<String, Value>,
203}
204
205impl CommonWithheldCodeContent {
206 pub fn new(
208 room_id: OwnedRoomId,
209 session_id: String,
210 sender_key: Curve25519PublicKey,
211 device_id: OwnedDeviceId,
212 ) -> Self {
213 Self {
214 room_id,
215 session_id,
216 sender_key,
217 from_device: Some(device_id),
218 other: Default::default(),
219 }
220 }
221}
222
223impl MegolmV1AesSha2WithheldContent {
224 pub fn withheld_code(&self) -> WithheldCode {
226 match self {
227 MegolmV1AesSha2WithheldContent::BlackListed(_) => WithheldCode::Blacklisted,
228 MegolmV1AesSha2WithheldContent::Unverified(_) => WithheldCode::Unverified,
229 MegolmV1AesSha2WithheldContent::Unauthorised(_) => WithheldCode::Unauthorised,
230 MegolmV1AesSha2WithheldContent::Unavailable(_) => WithheldCode::Unavailable,
231 MegolmV1AesSha2WithheldContent::NoOlm(_) => WithheldCode::NoOlm,
232 }
233 }
234
235 fn from_code_and_content(code: WithheldCode, content: CommonWithheldCodeContent) -> Self {
236 let content = content.into();
237
238 match code {
239 WithheldCode::Blacklisted => Self::BlackListed(content),
240 WithheldCode::Unverified => Self::Unverified(content),
241 WithheldCode::Unauthorised => Self::Unauthorised(content),
242 WithheldCode::Unavailable => Self::Unavailable(content),
243 _ => unreachable!("This constructor requires one of the common withheld codes"),
244 }
245 }
246}
247
248#[derive(Clone, Deserialize, Serialize)]
250pub struct NoOlmWithheldContent {
251 #[serde(deserialize_with = "deserialize_curve_key", serialize_with = "serialize_curve_key")]
252 pub sender_key: Curve25519PublicKey,
254
255 #[serde(skip_serializing_if = "Option::is_none")]
258 pub from_device: Option<OwnedDeviceId>,
259
260 #[serde(flatten)]
261 other: BTreeMap<String, Value>,
262}
263
264#[cfg(not(tarpaulin_include))]
265impl std::fmt::Debug for CommonWithheldCodeContent {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 f.debug_struct("CommonWithheldCodeContent")
268 .field("room_id", &self.room_id)
269 .field("session_id", &self.session_id)
270 .field("sender_key", &self.sender_key)
271 .field("from_device", &self.from_device)
272 .finish_non_exhaustive()
273 }
274}
275
276#[cfg(not(tarpaulin_include))]
277impl std::fmt::Debug for NoOlmWithheldContent {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 f.debug_struct("NoOlmWithheldContent")
280 .field("sender_key", &self.sender_key)
281 .field("from_device", &self.from_device)
282 .finish_non_exhaustive()
283 }
284}
285
/// Withheld content for the Megolm v2 algorithm.
///
/// This is an alias: the v2 content reuses the Megolm v1 withheld content
/// type unchanged.
pub type MegolmV2AesSha2WithheldContent = MegolmV1AesSha2WithheldContent;
288
289#[derive(Clone, Debug, Serialize, Deserialize)]
291pub struct UnknownRoomKeyWithHeld {
292 pub algorithm: EventEncryptionAlgorithm,
294 pub code: WithheldCode,
296 #[serde(skip_serializing_if = "Option::is_none")]
298 pub reason: Option<String>,
299 #[serde(flatten)]
301 other: BTreeMap<String, Value>,
302}
303
304impl TryFrom<WithheldHelper> for RoomKeyWithheldContent {
305 type Error = serde_json::Error;
306
307 fn try_from(value: WithheldHelper) -> Result<Self, Self::Error> {
308 let unknown = |value: WithheldHelper| -> Result<RoomKeyWithheldContent, _> {
309 Ok(Self::Unknown(UnknownRoomKeyWithHeld {
310 algorithm: value.algorithm,
311 code: value.code,
312 reason: value.reason,
313 other: serde_json::from_value(value.other)?,
314 }))
315 };
316
317 Ok(match value.algorithm {
318 EventEncryptionAlgorithm::MegolmV1AesSha2 => match value.code {
319 WithheldCode::NoOlm => {
320 let content: NoOlmWithheldContent = serde_json::from_value(value.other)?;
321 Self::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::NoOlm(content.into()))
322 }
323 WithheldCode::Blacklisted
324 | WithheldCode::Unverified
325 | WithheldCode::Unauthorised
326 | WithheldCode::Unavailable => {
327 let content: CommonWithheldCodeContent = serde_json::from_value(value.other)?;
328
329 Self::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::from_code_and_content(
330 value.code, content,
331 ))
332 }
333 _ => unknown(value)?,
334 },
335 #[cfg(feature = "experimental-algorithms")]
336 EventEncryptionAlgorithm::MegolmV2AesSha2 => match value.code {
337 WithheldCode::NoOlm => {
338 let content: NoOlmWithheldContent = serde_json::from_value(value.other)?;
339 Self::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::NoOlm(content.into()))
340 }
341 WithheldCode::Blacklisted
342 | WithheldCode::Unverified
343 | WithheldCode::Unauthorised
344 | WithheldCode::Unavailable => {
345 let content: CommonWithheldCodeContent = serde_json::from_value(value.other)?;
346
347 Self::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::from_code_and_content(
348 value.code, content,
349 ))
350 }
351 _ => unknown(value)?,
352 },
353 _ => unknown(value)?,
354 })
355 }
356}
357
358impl Serialize for RoomKeyWithheldContent {
359 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
360 where
361 S: serde::Serializer,
362 {
363 let algorithm = self.algorithm();
364
365 let helper = match self {
366 Self::MegolmV1AesSha2(r) => {
367 let code = r.withheld_code();
368 let reason = Some(code.to_string());
369
370 match r {
371 MegolmV1AesSha2WithheldContent::BlackListed(content)
372 | MegolmV1AesSha2WithheldContent::Unverified(content)
373 | MegolmV1AesSha2WithheldContent::Unauthorised(content)
374 | MegolmV1AesSha2WithheldContent::Unavailable(content) => WithheldHelper {
375 algorithm,
376 code,
377 reason,
378 other: serde_json::to_value(content).map_err(serde::ser::Error::custom)?,
379 },
380 MegolmV1AesSha2WithheldContent::NoOlm(content) => WithheldHelper {
381 algorithm,
382 code,
383 reason,
384 other: serde_json::to_value(content).map_err(serde::ser::Error::custom)?,
385 },
386 }
387 }
388 #[cfg(feature = "experimental-algorithms")]
389 Self::MegolmV2AesSha2(r) => {
390 let code = r.withheld_code();
391 let reason = Some(code.to_string());
392
393 match r {
394 MegolmV1AesSha2WithheldContent::BlackListed(content)
395 | MegolmV1AesSha2WithheldContent::Unverified(content)
396 | MegolmV1AesSha2WithheldContent::Unauthorised(content)
397 | MegolmV1AesSha2WithheldContent::Unavailable(content) => WithheldHelper {
398 algorithm,
399 code,
400 reason,
401 other: serde_json::to_value(content).map_err(serde::ser::Error::custom)?,
402 },
403 MegolmV1AesSha2WithheldContent::NoOlm(content) => WithheldHelper {
404 algorithm,
405 code,
406 reason,
407 other: serde_json::to_value(content).map_err(serde::ser::Error::custom)?,
408 },
409 }
410 }
411 Self::Unknown(r) => WithheldHelper {
412 algorithm: r.algorithm.to_owned(),
413 code: r.code.to_owned(),
414 reason: r.reason.to_owned(),
415 other: serde_json::to_value(r.other.clone()).map_err(serde::ser::Error::custom)?,
416 },
417 };
418
419 helper.serialize(serializer)
420 }
421}
422
423#[cfg(test)]
424pub(super) mod tests {
425 use std::collections::BTreeMap;
426
427 use assert_matches::assert_matches;
428 use assert_matches2::assert_let;
429 use matrix_sdk_common::deserialized_responses::WithheldCode;
430 use ruma::{device_id, room_id, serde::Raw, to_device::DeviceIdOrAllDevices, user_id};
431 use serde_json::{json, Value};
432 use vodozemac::Curve25519PublicKey;
433
434 use super::RoomKeyWithheldEvent;
435 use crate::types::{
436 events::room_key_withheld::{MegolmV1AesSha2WithheldContent, RoomKeyWithheldContent},
437 EventEncryptionAlgorithm,
438 };
439
440 pub fn json(code: &WithheldCode) -> Value {
441 json!({
442 "sender": "@alice:example.org",
443 "content": {
444 "room_id": "!DwLygpkclUAfQNnfva:localhost:8481",
445 "session_id": "0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs",
446 "algorithm": "m.megolm.v1.aes-sha2",
447 "sender_key": "9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA",
448 "code": code.to_owned(),
449 "reason": code.to_string(),
450 "org.matrix.msgid": "8836f2f0-635d-4f0e-9228-446c63ba3ea3"
451 },
452 "type": "m.room_key.withheld",
453 "m.custom.top": "something custom in the top",
454 })
455 }
456
457 pub fn no_olm_json() -> Value {
458 json!({
459 "sender": "@alice:example.org",
460 "content": {
461 "algorithm": "m.megolm.v1.aes-sha2",
462 "sender_key": "9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA",
463 "code": "m.no_olm",
464 "reason": "Unable to establish a secure channel.",
465 "org.matrix.msgid": "8836f2f0-635d-4f0e-9228-446c63ba3ea3"
466 },
467 "type": "m.room_key.withheld",
468 "m.custom.top": "something custom in the top",
469 })
470 }
471
472 pub fn unknown_alg_json() -> Value {
473 json!({
474 "sender": "@alice:example.org",
475 "content": {
476 "algorithm": "caesar.cipher",
477 "sender_key": "9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA",
478 "code": "m.brutus",
479 "reason": "Tu quoque fili",
480 "org.matrix.msgid": "8836f2f0-635d-4f0e-9228-446c63ba3ea3"
481 },
482 "type": "m.room_key.withheld",
483 "m.custom.top": "something custom in the top",
484 })
485 }
486
487 pub fn unknown_code_json() -> Value {
488 json!({
489 "sender": "@alice:example.org",
490 "content": {
491 "room_id": "!DwLygpkclUAfQNnfva:localhost:8481",
492 "session_id": "0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs",
493 "algorithm": "m.megolm.v1.aes-sha2",
494 "sender_key": "9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA",
495 "code": "org.mscXXX.new_code",
496 "reason": "Unable to establish a secure channel.",
497 "org.matrix.msgid": "8836f2f0-635d-4f0e-9228-446c63ba3ea3"
498 },
499 "type": "m.room_key.withheld",
500 "m.custom.top": "something custom in the top",
501 })
502 }
503
504 #[test]
505 fn deserialization() -> Result<(), serde_json::Error> {
506 let codes = [
507 WithheldCode::Unverified,
508 WithheldCode::Blacklisted,
509 WithheldCode::Unauthorised,
510 WithheldCode::Unavailable,
511 ];
512 for code in codes {
513 let json = json(&code);
514 let event: RoomKeyWithheldEvent = serde_json::from_value(json.clone())?;
515
516 assert_let!(RoomKeyWithheldContent::MegolmV1AesSha2(content) = &event.content);
517 assert_eq!(code, content.withheld_code());
518
519 assert_eq!(event.content.algorithm(), EventEncryptionAlgorithm::MegolmV1AesSha2);
520 let serialized = serde_json::to_value(event)?;
521 assert_eq!(json, serialized);
522 }
523 Ok(())
524 }
525
526 #[test]
527 fn deserialization_no_olm() -> Result<(), serde_json::Error> {
528 let json = no_olm_json();
529 let event: RoomKeyWithheldEvent = serde_json::from_value(json.clone())?;
530 assert_matches!(
531 event.content,
532 RoomKeyWithheldContent::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::NoOlm(_))
533 );
534 let serialized = serde_json::to_value(event)?;
535 assert_eq!(json, serialized);
536
537 Ok(())
538 }
539
540 #[test]
541 fn deserialization_unknown_code() -> Result<(), serde_json::Error> {
542 let json = unknown_code_json();
543 let event: RoomKeyWithheldEvent = serde_json::from_value(json.clone())?;
544 assert_matches!(event.content, RoomKeyWithheldContent::Unknown(_));
545
546 assert_let!(RoomKeyWithheldContent::Unknown(content) = &event.content);
547 assert_eq!(content.code.as_str(), "org.mscXXX.new_code");
548
549 let serialized = serde_json::to_value(event)?;
550 assert_eq!(json, serialized);
551
552 Ok(())
553 }
554
555 #[test]
556 fn deserialization_unknown_alg() -> Result<(), serde_json::Error> {
557 let json = unknown_alg_json();
558 let event: RoomKeyWithheldEvent = serde_json::from_value(json.clone())?;
559 assert_matches!(event.content, RoomKeyWithheldContent::Unknown(_));
560
561 assert_let!(RoomKeyWithheldContent::Unknown(content) = &event.content);
562 assert_matches!(&content.code, WithheldCode::_Custom(_));
563 let serialized = serde_json::to_value(event)?;
564 assert_eq!(json, serialized);
565
566 Ok(())
567 }
568
569 #[test]
570 fn serialization_to_device() {
571 let mut messages = BTreeMap::new();
572
573 let room_id = room_id!("!DwLygpkclUAfQNnfva:localhost:8481");
574 let user_id = user_id!("@alice:example.org");
575 let device_id = device_id!("DEV001");
576 let sender_key =
577 Curve25519PublicKey::from_base64("9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA")
578 .unwrap();
579
580 let content = RoomKeyWithheldContent::new(
581 EventEncryptionAlgorithm::MegolmV1AesSha2,
582 WithheldCode::Unverified,
583 room_id.to_owned(),
584 "0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs".to_owned(),
585 sender_key,
586 device_id.to_owned(),
587 );
588 let content: Raw<RoomKeyWithheldContent> =
589 Raw::new(&content).expect("We can always serialize a withheld content info").cast();
590
591 messages
592 .entry(user_id.to_owned())
593 .or_insert_with(BTreeMap::new)
594 .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
595
596 let serialized = serde_json::to_value(messages).unwrap();
597
598 let expected: Value = json!({
599 "@alice:example.org":{
600 "DEV001":{
601 "algorithm":"m.megolm.v1.aes-sha2",
602 "code":"m.unverified",
603 "from_device":"DEV001",
604 "reason":"The sender has disabled encrypting to unverified devices.",
605 "room_id":"!DwLygpkclUAfQNnfva:localhost:8481",
606 "sender_key":"9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA",
607 "session_id":"0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs"
608 }
609 }
610 });
611 assert_eq!(serialized, expected);
612 }
613
614 #[test]
615 fn no_olm_should_not_have_room_and_session() {
616 let room_id = room_id!("!DwLygpkclUAfQNnfva:localhost:8481");
617 let device_id = device_id!("DEV001");
618 let sender_key =
619 Curve25519PublicKey::from_base64("9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA")
620 .unwrap();
621
622 let content = RoomKeyWithheldContent::new(
623 EventEncryptionAlgorithm::MegolmV1AesSha2,
624 WithheldCode::NoOlm,
625 room_id.to_owned(),
626 "0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs".to_owned(),
627 sender_key,
628 device_id.to_owned(),
629 );
630
631 assert_let!(
632 RoomKeyWithheldContent::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::NoOlm(
633 content
634 )) = content
635 );
636 assert_eq!(content.sender_key, sender_key);
637
638 let content = RoomKeyWithheldContent::new(
639 EventEncryptionAlgorithm::MegolmV1AesSha2,
640 WithheldCode::Unverified,
641 room_id.to_owned(),
642 "0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs".to_owned(),
643 sender_key,
644 device_id.to_owned(),
645 );
646
647 assert_let!(
648 RoomKeyWithheldContent::MegolmV1AesSha2(MegolmV1AesSha2WithheldContent::Unverified(
649 content
650 )) = content
651 );
652 assert_eq!(content.session_id, "0ZcULv8j1nqVWx6orFjD6OW9JQHydDPXfaanA+uRyfs");
653 }
654}