authenticator/models/provider.rs

1use std::{
2    string::ToString,
3    time::{SystemTime, UNIX_EPOCH},
4};
5
6use anyhow::Result;
7use diesel::prelude::*;
8use gtk::{
9    gdk_pixbuf, gio,
10    glib::{self, clone},
11    prelude::*,
12    subclass::prelude::*,
13};
14use url::Url;
15
16use crate::{
17    models::{Account, AccountsModel, Algorithm, FAVICONS_PATH, Method, OTP, database},
18    schema::providers,
19};
20
/// New values to apply to an existing [`Provider`] through [`Provider::update`].
///
/// `algorithm` and `method` carry the textual form stored in the `providers`
/// table; `update` parses them back into [`Algorithm`] / [`Method`].
pub struct ProviderPatch {
    /// Display name of the provider.
    pub name: String,
    /// Website of the provider, if any.
    pub website: Option<String>,
    /// Help page URL, if any.
    pub help_url: Option<String>,
    /// Cached icon name/URI, if any.
    pub image_uri: Option<String>,
    /// Refresh period in seconds used by time-based methods.
    pub period: i32,
    /// Number of digits in generated codes.
    pub digits: i32,
    /// Default counter value for new accounts.
    pub default_counter: i32,
    /// Hash algorithm, as stored in the database.
    pub algorithm: String,
    /// OTP method, as stored in the database.
    pub method: String,
    /// When `true`, `update` leaves `image_uri`, `website` and `help_url`
    /// untouched, both in the database and on the object.
    pub is_backup_restore: bool,
}
33
34#[derive(Insertable)]
35#[diesel(table_name = providers)]
36struct NewProvider {
37    pub name: String,
38    pub website: Option<String>,
39    pub help_url: Option<String>,
40    pub image_uri: Option<String>,
41    pub period: i32,
42    pub digits: i32,
43    pub default_counter: i32,
44    pub algorithm: String,
45    pub method: String,
46}
47
48#[derive(Identifiable, Queryable)]
49#[diesel(table_name = providers)]
50pub struct DieselProvider {
51    pub id: i32,
52    pub name: String,
53    pub website: Option<String>,
54    pub help_url: Option<String>,
55    pub image_uri: Option<String>,
56    pub period: i32,
57    pub digits: i32,
58    pub default_counter: i32,
59    pub algorithm: String,
60    pub method: String,
61}
62
63mod imp {
64    use std::cell::{Cell, RefCell};
65
66    use super::*;
67
68    #[derive(glib::Properties)]
69    #[properties(wrapper_type = super::Provider)]
70    pub struct Provider {
71        #[property(get, set, construct_only)]
72        pub id: Cell<u32>,
73        #[property(get, set)]
74        pub name: RefCell<String>,
75        #[property(get, set, maximum = 1000, default = OTP::DEFAULT_PERIOD)]
76        pub period: Cell<u32>,
77        #[property(get, set, builder(Method::default()))]
78        pub method: Cell<Method>,
79        #[property(get, set, default = OTP::DEFAULT_COUNTER)]
80        pub default_counter: Cell<u32>,
81        #[property(get, set, builder(Algorithm::default()))]
82        pub algorithm: Cell<Algorithm>,
83        #[property(get, set, maximum = 1000, default = OTP::DEFAULT_DIGITS)]
84        pub digits: Cell<u32>,
85        #[property(get, set)]
86        pub website: RefCell<Option<String>>,
87        #[property(get, set)]
88        pub help_url: RefCell<Option<String>>,
89        #[property(get, set = Self::set_image_uri, explicit_notify)]
90        pub image_uri: RefCell<Option<String>>,
91        #[property(get, set)]
92        pub remaining_time: Cell<u64>,
93        #[property(get)]
94        pub accounts_model: AccountsModel,
95        pub filter_model: gtk::FilterListModel,
96        pub tick_callback: RefCell<Option<glib::SourceId>>,
97    }
98
99    #[glib::object_subclass]
100    impl ObjectSubclass for Provider {
101        const NAME: &'static str = "Provider";
102        type Type = super::Provider;
103
104        fn new() -> Self {
105            let model = AccountsModel::default();
106            Self {
107                id: Cell::default(),
108                default_counter: Cell::new(OTP::DEFAULT_COUNTER),
109                algorithm: Cell::new(Algorithm::default()),
110                digits: Cell::new(OTP::DEFAULT_DIGITS),
111                name: RefCell::default(),
112                website: RefCell::default(),
113                help_url: RefCell::default(),
114                image_uri: RefCell::default(),
115                method: Cell::new(Method::default()),
116                period: Cell::new(OTP::DEFAULT_PERIOD),
117                filter_model: gtk::FilterListModel::new(Some(model.clone()), None::<gtk::Filter>),
118                accounts_model: model,
119                tick_callback: RefCell::default(),
120                remaining_time: Cell::default(),
121            }
122        }
123    }
124
125    #[glib::derived_properties]
126    impl ObjectImpl for Provider {
127        fn dispose(&self) {
128            // Stop ticking
129            if let Some(source_id) = self.tick_callback.borrow_mut().take() {
130                source_id.remove();
131            }
132        }
133    }
134
135    impl Provider {
136        fn set_image_uri_inner(&self, id: i32, uri: Option<&str>) -> anyhow::Result<()> {
137            let db = database::connection();
138            let mut conn = db.get()?;
139
140            let target = providers::table.filter(providers::columns::id.eq(id));
141            diesel::update(target)
142                .set(providers::columns::image_uri.eq(uri))
143                .execute(&mut conn)?;
144
145            Ok(())
146        }
147
148        fn set_image_uri(&self, uri: Option<&str>) {
149            let obj = self.obj();
150            if let Err(err) = self.set_image_uri_inner(obj.id() as i32, uri) {
151                tracing::warn!("Failed to update provider image {}", err);
152            }
153            self.image_uri.replace(uri.map(ToOwned::to_owned));
154            obj.notify_image_uri();
155        }
156    }
157}
158
159glib::wrapper! {
160    pub struct Provider(ObjectSubclass<imp::Provider>);
161}
162
163impl Provider {
164    #[allow(clippy::too_many_arguments)]
165    pub fn create(
166        name: &str,
167        period: u32,
168        algorithm: Algorithm,
169        website: Option<String>,
170        method: Method,
171        digits: u32,
172        default_counter: u32,
173        help_url: Option<String>,
174        image_uri: Option<String>,
175    ) -> Result<Self> {
176        let db = database::connection();
177        let mut conn = db.get()?;
178
179        diesel::insert_into(providers::table)
180            .values(NewProvider {
181                name: name.to_string(),
182                period: period as i32,
183                method: method.to_string(),
184                website,
185                algorithm: algorithm.to_string(),
186                digits: digits as i32,
187                default_counter: default_counter as i32,
188                help_url,
189                image_uri,
190            })
191            .execute(&mut conn)?;
192
193        providers::table
194            .order(providers::columns::id.desc())
195            .first::<DieselProvider>(&mut conn)
196            .map_err(From::from)
197            .map(From::from)
198    }
199
200    pub fn load() -> Result<impl Iterator<Item = Self>> {
201        use crate::schema::providers::dsl::*;
202        let db = database::connection();
203        let mut conn = db.get()?;
204
205        let results = providers
206            .load::<DieselProvider>(&mut conn)?
207            .into_iter()
208            .map(From::from)
209            .inspect(|p: &Provider| {
210                let accounts = Account::load(p).unwrap().collect::<Vec<_>>();
211                p.add_accounts(&accounts);
212            });
213        Ok(results)
214    }
215
216    #[allow(clippy::too_many_arguments)]
217    pub fn new(
218        id: u32,
219        name: &str,
220        period: u32,
221        method: Method,
222        algorithm: Algorithm,
223        digits: u32,
224        default_counter: u32,
225        website: Option<String>,
226        help_url: Option<String>,
227        image_uri: Option<String>,
228    ) -> Provider {
229        glib::Object::builder()
230            .property("id", id)
231            .property("name", name)
232            .property("website", website)
233            .property("help-url", help_url)
234            .property("image-uri", image_uri)
235            .property("period", period)
236            .property("method", method)
237            .property("algorithm", algorithm)
238            .property("digits", digits)
239            .property("default-counter", default_counter)
240            .build()
241    }
242
243    pub async fn favicon(
244        website: String,
245        name: String,
246        id: u32,
247    ) -> Result<String, Box<dyn std::error::Error>> {
248        let website_url = Url::parse(&website)?;
249        let favicon = favicon_scrapper::Scrapper::from_url(&website_url).await?;
250        tracing::debug!("Found the following icons {:#?} for {}", favicon, name);
251
252        let icon_name = format!("{id}_{}", name.replace(' ', "_"));
253        let icon_name = glib::base64_encode(icon_name.as_bytes());
254        let small_icon_name = format!("{icon_name}_32x32");
255        let large_icon_name = format!("{icon_name}_96x96");
256        // TODO: figure out why trying to grab icons at specific size causes stack size
257        // errors We need two sizes:
258        // - 32x32 for the accounts lists
259        // - 96x96 elsewhere
260        if let Some(best_favicon) = favicon.find_best().await {
261            tracing::debug!("Largest favicon found is {:#?}", best_favicon);
262            let cache_path = FAVICONS_PATH.join(&*icon_name);
263            best_favicon.save(cache_path.clone()).await?;
264            // Don't try to scale down svg variants
265            if !best_favicon.metadata().format().is_svg() {
266                tracing::debug!("Creating scaled down variants for {:#?}", cache_path);
267                {
268                    let pixbuf = gdk_pixbuf::Pixbuf::from_file(cache_path.clone())?;
269                    tracing::debug!("Creating a 32x32 variant of the favicon");
270                    let small_pixbuf = pixbuf
271                        .scale_simple(32, 32, gdk_pixbuf::InterpType::Bilinear)
272                        .unwrap();
273
274                    let mut small_cache = cache_path.clone();
275                    small_cache.set_file_name(small_icon_name);
276                    small_pixbuf.savev(small_cache.clone(), "png", &[])?;
277
278                    tracing::debug!("Creating a 96x96 variant of the favicon");
279                    let large_pixbuf = pixbuf
280                        .scale_simple(96, 96, gdk_pixbuf::InterpType::Bilinear)
281                        .unwrap();
282                    let mut large_cache = cache_path.clone();
283                    large_cache.set_file_name(large_icon_name);
284                    large_pixbuf.savev(large_cache.clone(), "png", &[])?;
285                };
286                tokio::fs::remove_file(cache_path).await?;
287            } else {
288                let mut small_cache = cache_path.clone();
289                small_cache.set_file_name(small_icon_name);
290                tokio::fs::symlink(&cache_path, small_cache).await?;
291
292                let mut large_cache = cache_path.clone();
293                large_cache.set_file_name(large_icon_name);
294                tokio::fs::symlink(&cache_path, large_cache).await?;
295            }
296            Ok(icon_name.to_string())
297        } else {
298            Err(Box::new(favicon_scrapper::Error::NoResults))
299        }
300    }
301
302    pub fn delete(&self) -> Result<()> {
303        let db = database::connection();
304        let mut conn = db.get()?;
305        diesel::delete(providers::table.filter(providers::columns::id.eq(self.id() as i32)))
306            .execute(&mut conn)?;
307        Ok(())
308    }
309
310    pub fn update(&self, patch: &ProviderPatch) -> Result<()> {
311        // Can't implement PartialEq because of how GObject works
312        if patch.name == self.name()
313            && patch.website == self.website()
314            && patch.help_url == self.help_url()
315            && patch.image_uri == self.image_uri()
316            && patch.period == self.period() as i32
317            && patch.digits == self.digits() as i32
318            && patch.default_counter == self.default_counter() as i32
319            && patch.algorithm == self.algorithm().to_string()
320            && patch.method == self.method().to_string()
321        {
322            return Ok(());
323        }
324
325        let db = database::connection();
326        let mut conn = db.get()?;
327
328        let target = providers::table.filter(providers::columns::id.eq(self.id() as i32));
329        diesel::update(target)
330            .set((
331                providers::columns::algorithm.eq(&patch.algorithm),
332                providers::columns::method.eq(&patch.method),
333                providers::columns::digits.eq(&patch.digits),
334                providers::columns::period.eq(&patch.period),
335                providers::columns::default_counter.eq(&patch.default_counter),
336                providers::columns::name.eq(&patch.name),
337            ))
338            .execute(&mut conn)?;
339        if !patch.is_backup_restore {
340            diesel::update(target)
341                .set((
342                    providers::columns::image_uri.eq(&patch.image_uri),
343                    providers::columns::website.eq(&patch.website),
344                    providers::columns::help_url.eq(&patch.help_url),
345                ))
346                .execute(&mut conn)?;
347        };
348
349        self.set_properties(&[
350            ("name", &patch.name),
351            ("period", &(patch.period as u32)),
352            ("method", &patch.method.parse::<Method>()?),
353            ("digits", &(patch.digits as u32)),
354            ("algorithm", &patch.algorithm.parse::<Algorithm>()?),
355            ("default-counter", &(patch.default_counter as u32)),
356        ]);
357
358        if !patch.is_backup_restore {
359            self.set_properties(&[
360                ("image-uri", &patch.image_uri),
361                ("website", &patch.website),
362                ("help-url", &patch.help_url),
363            ]);
364        }
365        Ok(())
366    }
367
368    pub fn open_help(&self) {
369        if let Some(ref url) = self.help_url() {
370            gio::AppInfo::launch_default_for_uri(url, None::<&gio::AppLaunchContext>).unwrap();
371        }
372    }
373
374    fn tick(&self) {
375        let period = self.period() as u64;
376        let remaining_time: u64 = period
377            - SystemTime::now()
378                .duration_since(UNIX_EPOCH)
379                .unwrap()
380                .as_secs()
381                % period;
382        if period == remaining_time {
383            self.regenerate_otp();
384        }
385        self.set_remaining_time(remaining_time);
386    }
387
388    fn setup_tick_callback(&self) {
389        if self.imp().tick_callback.borrow().is_some() || self.method().is_event_based() {
390            return;
391        }
392        self.set_remaining_time(self.period() as u64);
393
394        match self.method() {
395            Method::TOTP | Method::Steam => {
396                let source_id = glib::timeout_add_seconds_local(
397                    1,
398                    clone!(
399                        #[weak(rename_to = provider)]
400                        self,
401                        #[upgrade_or]
402                        glib::ControlFlow::Break,
403                        move || {
404                            provider.tick();
405                            glib::ControlFlow::Continue
406                        }
407                    ),
408                );
409                self.imp().tick_callback.replace(Some(source_id));
410            }
411            _ => (),
412        };
413    }
414
415    fn regenerate_otp(&self) {
416        let accounts = self.accounts();
417        for i in 0..accounts.n_items() {
418            let item = accounts.item(i).unwrap();
419            let account = item.downcast_ref::<Account>().unwrap();
420            account.generate_otp();
421        }
422    }
423
424    pub fn has_accounts(&self) -> bool {
425        self.accounts_model().n_items() != 0
426    }
427
428    fn add_accounts(&self, accounts: &[Account]) {
429        self.accounts_model().splice(accounts);
430        self.setup_tick_callback();
431    }
432
433    pub fn add_account(&self, account: &Account) {
434        self.accounts_model().append(account);
435        self.setup_tick_callback();
436    }
437
438    fn tokenize_search(account_name: &str, provider_name: &str, term: &str) -> bool {
439        let term = term.to_ascii_lowercase();
440        let provider_name = provider_name.to_ascii_lowercase();
441        let account_name = account_name.to_ascii_lowercase();
442
443        account_name.split_ascii_whitespace().any(|x| x == term)
444            || provider_name.split_ascii_whitespace().any(|x| x == term)
445            || account_name.contains(term.as_str())
446            || provider_name.contains(term.as_str())
447    }
448
449    pub fn find_accounts(&self, terms: &[String]) -> Vec<Account> {
450        let mut results = vec![];
451        let model = self.accounts_model();
452        let provider_name = self.name();
453        for pos in 0..model.n_items() {
454            let account = model.item(pos).and_downcast::<Account>().unwrap();
455            let account_name = account.name();
456
457            if terms
458                .iter()
459                .any(|term| Self::tokenize_search(&account_name, &provider_name, term))
460            {
461                results.push(account);
462            }
463        }
464        results
465    }
466
467    pub fn accounts(&self) -> &gtk::FilterListModel {
468        &self.imp().filter_model
469    }
470
471    pub fn filter(&self, text: String) {
472        let filter = gtk::CustomFilter::new(glib::clone!(
473            #[weak(rename_to = provider)]
474            self,
475            #[upgrade_or]
476            false,
477            move |obj| {
478                let account = obj.downcast_ref::<Account>().unwrap();
479                let account_name = account.name();
480                let provider_name = provider.name();
481
482                Self::tokenize_search(&account_name, &provider_name, &text)
483            }
484        ));
485        self.imp().filter_model.set_filter(Some(&filter));
486    }
487
488    pub fn remove_account(&self, account: &Account) {
489        let imp = self.imp();
490        let model = self.accounts_model();
491        if let Some(pos) = model.find_position_by_id(account.id()) {
492            model.remove(pos);
493            if !self.has_accounts() && self.method().is_time_based() {
494                // Stop ticking
495                if let Some(source_id) = imp.tick_callback.borrow_mut().take() {
496                    source_id.remove();
497                }
498            }
499        }
500    }
501}
502
503impl From<DieselProvider> for Provider {
504    fn from(p: DieselProvider) -> Self {
505        Self::new(
506            p.id as u32,
507            &p.name,
508            p.period as u32,
509            p.method.parse::<Method>().unwrap(),
510            p.algorithm.parse::<Algorithm>().unwrap(),
511            p.digits as u32,
512            p.default_counter as u32,
513            p.website,
514            p.help_url,
515            p.image_uri,
516        )
517    }
518}
519
520impl From<&Provider> for DieselProvider {
521    fn from(p: &Provider) -> Self {
522        Self {
523            id: p.id() as i32,
524            name: p.name(),
525            period: p.period() as i32,
526            method: p.method().to_string(),
527            algorithm: p.algorithm().to_string(),
528            digits: p.digits() as i32,
529            default_counter: p.default_counter() as i32,
530            website: p.website(),
531            help_url: p.help_url(),
532            image_uri: p.image_uri(),
533        }
534    }
535}