33import json
44import logging
55import os
6- import threading
76import time
87from datetime import datetime
98from functools import cmp_to_key
109from typing import List
1110from typing import Optional
1211
1312import requests
13+ from readerwriterlock import rwlock
1414
1515from cryptojwt .jwk .ec import NIST2SEC
1616from cryptojwt .jwk .hmac import new_sym_key
4747
4848MAP = {"dec" : "enc" , "enc" : "enc" , "ver" : "sig" , "sig" : "sig" }
4949
50- update_lock = threading .Lock ()
51-
5250
5351def harmonize_usage (use ):
5452 """
@@ -153,6 +151,14 @@ def ec_init(spec):
153151 return _kb
154152
155153
154+ def keys_writer (func ):
155+ def wrapper (self , * args , ** kwargs ):
156+ with self ._lock_writer :
157+ return func (self , * args , ** kwargs )
158+
159+ return wrapper
160+
161+
156162class KeyBundle :
157163 """The Key Bundle"""
158164
@@ -230,6 +236,10 @@ def __init__(
230236 self .source = None
231237 self .time_out = 0
232238
239+ self ._lock = rwlock .RWLockFairD ()
240+ self ._lock_reader = self ._lock .gen_rlock ()
241+ self ._lock_writer = self ._lock .gen_wlock ()
242+
233243 if httpc :
234244 self .httpc = httpc
235245 else :
@@ -500,6 +510,7 @@ def _uptodate(self):
500510 return self .update ()
501511 return False
502512
513+ @keys_writer
503514 def update (self ):
504515 """
505516 Reload the keys if necessary.
@@ -510,35 +521,34 @@ def update(self):
510521 :return: True if update was ok or False if we encountered an error during update.
511522 """
512523 if self .source :
513- with update_lock :
514- _old_keys = self ._keys # just in case
524+ _old_keys = self ._keys # just in case
515525
516- # reread everything
517- self ._keys = []
518- updated = None
526+ # reread everything
527+ self ._keys = []
528+ updated = None
519529
520- try :
521- if self .local :
522- if self .fileformat in ["jwks" , "jwk" ]:
523- updated = self .do_local_jwk (self .source )
524- elif self .fileformat == "der" :
525- updated = self .do_local_der (self .source , self .keytype , self .keyusage )
526- elif self .remote :
527- updated = self .do_remote ()
528- except Exception as err :
529- LOGGER .error ("Key bundle update failed: %s" , err )
530- self ._keys = _old_keys # restore
531- return False
532-
533- if updated :
534- now = time .time ()
535- for _key in _old_keys :
536- if _key not in self ._keys :
537- if not _key .inactive_since : # If already marked don't mess
538- _key .inactive_since = now
539- self ._keys .append (_key )
540- else :
541- self ._keys = _old_keys
530+ try :
531+ if self .local :
532+ if self .fileformat in ["jwks" , "jwk" ]:
533+ updated = self .do_local_jwk (self .source )
534+ elif self .fileformat == "der" :
535+ updated = self .do_local_der (self .source , self .keytype , self .keyusage )
536+ elif self .remote :
537+ updated = self .do_remote ()
538+ except Exception as err :
539+ LOGGER .error ("Key bundle update failed: %s" , err )
540+ self ._keys = _old_keys # restore
541+ return False
542+
543+ if updated :
544+ now = time .time ()
545+ for _key in _old_keys :
546+ if _key not in self ._keys :
547+ if not _key .inactive_since : # If already marked don't mess
548+ _key .inactive_since = now
549+ self ._keys .append (_key )
550+ else :
551+ self ._keys = _old_keys
542552
543553 return True
544554
@@ -551,32 +561,34 @@ def get(self, typ="", only_active=True):
551561 otherwise the appropriate keys in a list
552562 """
553563 self ._uptodate ()
554- _typs = [typ .lower (), typ .upper ()]
555564
556- if typ :
557- _keys = [k for k in self ._keys if k .kty in _typs ]
558- else :
559- _keys = self ._keys
565+ with self ._lock_reader :
566+ if typ :
567+ _typs = [typ .lower (), typ .upper ()]
568+ _keys = [k for k in self ._keys if k .kty in _typs ]
569+ else :
570+ _keys = self ._keys
560571
561572 if only_active :
562573 return [k for k in _keys if not k .inactive_since ]
563574
564575 return _keys
565576
566- def keys (self ):
577+ def keys (self , update : bool = True ):
567578 """
568579 Return all keys after having updated them
569580
570581 :return: List of all keys
571582 """
572- self ._uptodate ()
573-
574- return self ._keys
583+ if update :
584+ self ._uptodate ()
585+ with self ._lock_reader :
586+ return self ._keys
575587
576588 def active_keys (self ):
577589 """Return the set of active keys."""
578590 _res = []
579- for k in self ._keys :
591+ for k in self .keys () :
580592 try :
581593 ias = k .inactive_since
582594 except ValueError :
@@ -586,6 +598,7 @@ def active_keys(self):
586598 _res .append (k )
587599 return _res
588600
601+ @keys_writer
589602 def remove_keys_by_type (self , typ ):
590603 """
591604 Remove keys that are of a specific type.
@@ -605,9 +618,8 @@ def jwks(self, private=False):
605618 :param private: Whether private key information should be included.
606619 :return: A JWKS JSON representation of the keys in this bundle
607620 """
608- self ._uptodate ()
609621 keys = list ()
610- for k in self ._keys :
622+ for k in self .keys () :
611623 if private :
612624 key = k .serialize (private )
613625 else :
@@ -617,6 +629,7 @@ def jwks(self, private=False):
617629 keys .append (key )
618630 return json .dumps ({"keys" : keys })
619631
632+ @keys_writer
620633 def append (self , key ):
621634 """
622635 Add a key to list of keys in this bundle
@@ -625,10 +638,12 @@ def append(self, key):
625638 """
626639 self ._keys .append (key )
627640
641+ @keys_writer
628642 def extend (self , keys ):
629643 """Add a key to the list of keys."""
630644 self ._keys .extend (keys )
631645
646+ @keys_writer
632647 def remove (self , key ):
633648 """
634649 Remove a specific key from this bundle
@@ -648,6 +663,7 @@ def __len__(self):
648663 """
649664 return len (self ._keys )
650665
666+ @keys_writer
651667 def set (self , keys ):
652668 """Set the keys to the set provided."""
653669 self ._keys = keys
@@ -659,13 +675,15 @@ def get_key_with_kid(self, kid):
659675 :param kid: The Key ID
660676 :return: The key or None
661677 """
678+ self ._uptodate ()
679+ with self ._lock_reader :
680+ return self ._get_key_with_kid (kid )
681+
682+ def _get_key_with_kid (self , kid ):
662683 for key in self ._keys :
663684 if key .kid == kid :
664685 return key
665686
666- # Try updating since there might have been an update to the key file
667- self .update ()
668-
669687 for key in self ._keys :
670688 if key .kid == kid :
671689 return key
@@ -680,16 +698,16 @@ def kids(self):
680698 The reason might be that there are some keys with no key ID.
681699 :return: A list of all the key IDs that exists in this bundle
682700 """
683- self ._uptodate ()
684- return [key .kid for key in self ._keys if key .kid != "" ]
701+ return [key .kid for key in self .keys () if key .kid != "" ]
685702
703+ @keys_writer
686704 def mark_as_inactive (self , kid ):
687705 """
688706 Mark a specific key as inactive based on the keys KeyID.
689707
690708 :param kid: The Key Identifier
691709 """
692- k = self .get_key_with_kid (kid )
710+ k = self ._get_key_with_kid (kid )
693711 if k :
694712 self ._keys .remove (k )
695713 k .inactive_since = time .time ()
@@ -698,17 +716,19 @@ def mark_as_inactive(self, kid):
698716 else :
699717 return False
700718
719+ @keys_writer
701720 def mark_all_as_inactive (self ):
702721 """
703722 Mark a specific key as inactive based on the keys KeyID.
704723 """
705- _keys = self .keys ()
724+ _keys = self ._keys
706725 _updated = []
707726 for k in _keys :
708727 k .inactive_since = time .time ()
709728 _updated .append (k )
710729 self ._keys = _updated
711730
731+ @keys_writer
712732 def remove_outdated (self , after , when = 0 ):
713733 """
714734 Remove keys that should not be available any more.
@@ -775,7 +795,7 @@ def difference(self, bundle):
775795 if not isinstance (bundle , KeyBundle ):
776796 return ValueError ("Not a KeyBundle instance" )
777797
778- return [k for k in self ._keys if k not in bundle ]
798+ return [k for k in self .keys () if k not in bundle ]
779799
780800 def dump (self , exclude_attributes : Optional [List [str ]] = None ):
781801 if exclude_attributes is None :
@@ -785,7 +805,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None):
785805
786806 if "keys" not in exclude_attributes :
787807 _keys = []
788- for _k in self ._keys :
808+ for _k in self .keys ( update = False ) :
789809 _ser = _k .to_dict ()
790810 if _k .inactive_since :
791811 _ser ["inactive_since" ] = _k .inactive_since
@@ -819,6 +839,7 @@ def load(self, spec):
819839
820840 return self
821841
842+ @keys_writer
822843 def flush (self ):
823844 self ._keys = []
824845 self .cache_time = (300 ,)
0 commit comments