1use ruma::{
18 events::{
19 relation::Thread,
20 room::{
21 encrypted::Relation as EncryptedRelation,
22 message::{
23 AddMentions, ForwardThread, OriginalRoomMessageEvent, Relation, ReplyWithinThread,
24 RoomMessageEventContent, RoomMessageEventContentWithoutRelation,
25 },
26 },
27 AnySyncMessageLikeEvent, AnySyncTimelineEvent, SyncMessageLikeEvent,
28 },
29 serde::Raw,
30 EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedUserId, RoomId, UserId,
31};
32use serde::Deserialize;
33use thiserror::Error;
34use tracing::{error, instrument};
35
36use super::{EventSource, Room};
37
38#[derive(Debug)]
40pub struct Reply {
41 pub event_id: OwnedEventId,
43 pub enforce_thread: EnforceThread,
45}
46
47#[derive(Debug, Clone)]
49struct RepliedToInfo {
50 event_id: OwnedEventId,
52 sender: OwnedUserId,
54 timestamp: MilliSecondsSinceUnixEpoch,
56 content: ReplyContent,
58}
59
60#[derive(Debug, Clone)]
62enum ReplyContent {
63 Message(Box<RoomMessageEventContent>),
65 Raw(Raw<AnySyncTimelineEvent>),
67}
68
69#[derive(Debug, Error)]
71pub enum ReplyError {
72 #[error("Couldn't fetch the remote event: {0}")]
74 Fetch(Box<crate::Error>),
75 #[error("failed to deserialize event to reply to")]
77 Deserialization,
78 #[error("tried to reply to a state event")]
80 StateEvent,
81}
82
83#[derive(Clone, Copy, Debug, PartialEq, Eq)]
85pub enum EnforceThread {
86 Threaded(ReplyWithinThread),
89
90 MaybeThreaded,
93
94 Unthreaded,
97}
98
99impl Room {
100 #[instrument(skip(self, content), fields(room = %self.room_id()))]
112 pub async fn make_reply_event(
113 &self,
114 content: RoomMessageEventContentWithoutRelation,
115 reply: Reply,
116 ) -> Result<RoomMessageEventContent, ReplyError> {
117 make_reply_event(self, self.room_id(), self.own_user_id(), content, reply).await
118 }
119}
120
121async fn make_reply_event<S: EventSource>(
122 source: S,
123 room_id: &RoomId,
124 own_user_id: &UserId,
125 content: RoomMessageEventContentWithoutRelation,
126 reply: Reply,
127) -> Result<RoomMessageEventContent, ReplyError> {
128 let replied_to_info = replied_to_info_from_event_id(source, &reply.event_id).await?;
129
130 let mention_the_sender =
138 if own_user_id == replied_to_info.sender { AddMentions::No } else { AddMentions::Yes };
139
140 let content = match replied_to_info.content {
141 ReplyContent::Message(replied_to_content) => {
142 let event = OriginalRoomMessageEvent {
143 event_id: replied_to_info.event_id,
144 sender: replied_to_info.sender,
145 origin_server_ts: replied_to_info.timestamp,
146 room_id: room_id.to_owned(),
147 content: *replied_to_content,
148 unsigned: Default::default(),
149 };
150
151 match reply.enforce_thread {
152 EnforceThread::Threaded(is_reply) => {
153 content.make_for_thread(&event, is_reply, mention_the_sender)
154 }
155 EnforceThread::MaybeThreaded => {
156 content.make_reply_to(&event, ForwardThread::Yes, mention_the_sender)
157 }
158 EnforceThread::Unthreaded => {
159 content.make_reply_to(&event, ForwardThread::No, mention_the_sender)
160 }
161 }
162 }
163
164 ReplyContent::Raw(raw_event) => {
165 match reply.enforce_thread {
166 EnforceThread::Threaded(is_reply) => {
167 #[derive(Deserialize)]
172 struct ContentDeHelper {
173 #[serde(rename = "m.relates_to")]
174 relates_to: Option<EncryptedRelation>,
175 }
176
177 let previous_content =
178 raw_event.get_field::<ContentDeHelper>("content").ok().flatten();
179
180 let mut content = if is_reply == ReplyWithinThread::Yes {
181 content.make_reply_to_raw(
182 &raw_event,
183 replied_to_info.event_id.to_owned(),
184 room_id,
185 ForwardThread::No,
186 mention_the_sender,
187 )
188 } else {
189 content.into()
190 };
191
192 let thread_root = if let Some(EncryptedRelation::Thread(thread)) =
193 previous_content.as_ref().and_then(|c| c.relates_to.as_ref())
194 {
195 thread.event_id.to_owned()
196 } else {
197 replied_to_info.event_id.to_owned()
198 };
199
200 let thread = if is_reply == ReplyWithinThread::Yes {
201 Thread::reply(thread_root, replied_to_info.event_id)
202 } else {
203 Thread::plain(thread_root, replied_to_info.event_id)
204 };
205
206 content.relates_to = Some(Relation::Thread(thread));
207 content
208 }
209
210 EnforceThread::MaybeThreaded => content.make_reply_to_raw(
211 &raw_event,
212 replied_to_info.event_id,
213 room_id,
214 ForwardThread::Yes,
215 mention_the_sender,
216 ),
217
218 EnforceThread::Unthreaded => content.make_reply_to_raw(
219 &raw_event,
220 replied_to_info.event_id,
221 room_id,
222 ForwardThread::No,
223 mention_the_sender,
224 ),
225 }
226 }
227 };
228
229 Ok(content)
230}
231
232async fn replied_to_info_from_event_id<S: EventSource>(
233 source: S,
234 event_id: &EventId,
235) -> Result<RepliedToInfo, ReplyError> {
236 let event = source.get_event(event_id).await.map_err(|err| ReplyError::Fetch(Box::new(err)))?;
237
238 let raw_event = event.into_raw();
239 let event = raw_event.deserialize().map_err(|_| ReplyError::Deserialization)?;
240
241 let reply_content = match &event {
242 AnySyncTimelineEvent::MessageLike(event) => {
243 if let AnySyncMessageLikeEvent::RoomMessage(SyncMessageLikeEvent::Original(
244 original_event,
245 )) = event
246 {
247 ReplyContent::Message(Box::new(original_event.content.clone()))
248 } else {
249 ReplyContent::Raw(raw_event)
250 }
251 }
252 AnySyncTimelineEvent::State(_) => return Err(ReplyError::StateEvent),
253 };
254
255 Ok(RepliedToInfo {
256 event_id: event_id.to_owned(),
257 sender: event.sender().to_owned(),
258 timestamp: event.origin_server_ts(),
259 content: reply_content,
260 })
261}
262
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use assert_matches2::{assert_let, assert_matches};
    use matrix_sdk_base::deserialized_responses::TimelineEvent;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{
        event_id,
        events::{
            room::message::{
                Relation, ReplyWithinThread, RoomMessageEventContent,
                RoomMessageEventContentWithoutRelation,
            },
            AnySyncTimelineEvent,
        },
        room_id,
        serde::Raw,
        user_id, EventId, OwnedEventId, UserId,
    };
    use serde_json::json;

    use super::{make_reply_event, EnforceThread, EventSource, Reply, ReplyError};
    use crate::{event_cache::EventCacheError, Error};

    /// An in-memory [`EventSource`]; looking up an unknown event fails.
    #[derive(Default)]
    struct TestEventCache {
        events: BTreeMap<OwnedEventId, TimelineEvent>,
    }

    impl TestEventCache {
        /// Stores `event` under `event_id` and returns the cache, builder-style.
        fn with_event(mut self, event_id: &EventId, event: impl Into<TimelineEvent>) -> Self {
            self.events.insert(event_id.to_owned(), event.into());
            self
        }
    }

    impl EventSource for TestEventCache {
        async fn get_event(&self, event_id: &EventId) -> Result<TimelineEvent, Error> {
            match self.events.get(event_id) {
                Some(event) => Ok(event.clone()),
                None => Err(Error::EventCache(Box::new(EventCacheError::ClientDropped))),
            }
        }
    }

    /// The user authoring every event and every reply in these tests.
    fn own_user_id() -> &'static UserId {
        user_id!("@me:saucisse.bzh")
    }

    /// Builds a "the reply" text reply to `event_id`, looking events up in `cache`.
    async fn reply_to(
        cache: TestEventCache,
        event_id: &EventId,
        enforce_thread: EnforceThread,
    ) -> Result<RoomMessageEventContent, ReplyError> {
        make_reply_event(
            cache,
            room_id!("!galette:saucisse.bzh"),
            own_user_id(),
            RoomMessageEventContentWithoutRelation::text_plain("the reply"),
            Reply { event_id: event_id.to_owned(), enforce_thread },
        )
        .await
    }

    /// A cache with a thread root and a message `event_id` inside that thread.
    fn cache_with_thread(thread_root: &EventId, event_id: &EventId) -> TestEventCache {
        let f = EventFactory::new();
        TestEventCache::default()
            .with_event(thread_root, f.text_msg("hi").event_id(thread_root).sender(own_user_id()))
            .with_event(
                event_id,
                f.text_msg("ho")
                    .in_thread(thread_root, thread_root)
                    .event_id(event_id)
                    .sender(own_user_id()),
            )
    }

    #[async_test]
    async fn test_cannot_reply_to_unknown_event() {
        let event_id = event_id!("$1");
        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(own_user_id()));

        // Only `$1` is known, so fetching `$2` must fail.
        assert_matches!(
            reply_to(cache, event_id!("$2"), EnforceThread::Unthreaded).await,
            Err(ReplyError::Fetch(_))
        );
    }

    #[async_test]
    async fn test_cannot_reply_to_invalid_event() {
        let event_id = event_id!("$1");

        // This event has no `sender`, so it isn't a valid timeline event.
        let invalid = Raw::<AnySyncTimelineEvent>::from_json_string(
            json!({
                "content": {
                    "body": "hi"
                },
                "event_id": event_id,
                "origin_server_ts": 1,
                "type": "m.room.message",
            })
            .to_string(),
        )
        .unwrap();
        let cache =
            TestEventCache::default().with_event(event_id, TimelineEvent::from_plaintext(invalid));

        assert_matches!(
            reply_to(cache, event_id, EnforceThread::Unthreaded).await,
            Err(ReplyError::Deserialization)
        );
    }

    #[async_test]
    async fn test_cannot_reply_to_state_event() {
        let event_id = event_id!("$1");
        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.room_name("lobby").event_id(event_id).sender(own_user_id()));

        assert_matches!(
            reply_to(cache, event_id, EnforceThread::Unthreaded).await,
            Err(ReplyError::StateEvent)
        );
    }

    #[async_test]
    async fn test_reply_unthreaded() {
        let event_id = event_id!("$1");
        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(own_user_id()));

        let reply_event = reply_to(cache, event_id, EnforceThread::Unthreaded).await.unwrap();

        assert_let!(Some(Relation::Reply { in_reply_to }) = &reply_event.relates_to);
        assert_eq!(in_reply_to.event_id, event_id);
    }

    #[async_test]
    async fn test_start_thread() {
        let event_id = event_id!("$1");
        let f = EventFactory::new();
        let cache = TestEventCache::default()
            .with_event(event_id, f.text_msg("hi").event_id(event_id).sender(own_user_id()));

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::No))
                .await
                .unwrap();

        // The replied-to event becomes the root of a brand new thread.
        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);
        assert_eq!(thread.event_id, event_id);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let cache = cache_with_thread(thread_root, event_id);

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::No))
                .await
                .unwrap();

        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);
        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_on_thread_as_reply() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let cache = cache_with_thread(thread_root, event_id);

        let reply_event =
            reply_to(cache, event_id, EnforceThread::Threaded(ReplyWithinThread::Yes))
                .await
                .unwrap();

        // An explicit in-thread reply is not a mere fallback.
        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);
        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(!thread.is_falling_back);
    }

    #[async_test]
    async fn test_reply_forwarding_thread() {
        let thread_root = event_id!("$1");
        let event_id = event_id!("$2");
        let cache = cache_with_thread(thread_root, event_id);

        let reply_event = reply_to(cache, event_id, EnforceThread::MaybeThreaded).await.unwrap();

        // The thread of the replied-to event is carried over to the reply.
        assert_let!(Some(Relation::Thread(thread)) = &reply_event.relates_to);
        assert_eq!(thread.event_id, thread_root);
        assert_eq!(thread.in_reply_to.as_ref().unwrap().event_id, event_id);
        assert!(thread.is_falling_back);
    }
}