@@ -318,7 +318,11 @@ def do_local_jwk(self, filename):
318318 Load a JWKS from a local file
319319
320320 :param filename: Name of the file from which the JWKS should be loaded
321+ :return: True if load was successful or False if file hasn't been modified
321322 """
323+ if not self ._local_update_required ():
324+ return False
325+
322326 LOGGER .info ("Reading local JWKS from %s" , filename )
323327 with open (filename ) as input_file :
324328 _info = json .load (input_file )
@@ -328,6 +332,7 @@ def do_local_jwk(self, filename):
328332 self .do_keys ([_info ])
329333 self .last_local = time .time ()
330334 self .time_out = self .last_local + self .cache_time
335+ return True
331336
332337 def do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
333338 """
@@ -336,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
336341 :param filename: Name of the file
337342 :param keytype: Presently 'rsa' and 'ec' supported
338343 :param keyusage: encryption ('enc') or signing ('sig') or both
344+ :return: True if load was successful or False if file hasn't been modified
339345 """
346+ if not self ._local_update_required ():
347+ return False
348+
340349 LOGGER .info ("Reading local DER from %s" , filename )
341350 key_args = {}
342351 _kty = keytype .lower ()
@@ -359,12 +368,13 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
359368 self .do_keys ([key_args ])
360369 self .last_local = time .time ()
361370 self .time_out = self .last_local + self .cache_time
371+ return True
362372
363373 def do_remote (self ):
364374 """
365375 Load a JWKS from a webpage.
366376
367- :return: True or False if load was successful
377+ :return: True if load was successful or False if remote hasn't been modified
368378 """
369379 # if self.verify_ssl is not None:
370380 # self.httpc_params["verify"] = self.verify_ssl
@@ -390,7 +400,10 @@ def do_remote(self):
390400 LOGGER .error (err )
391401 raise UpdateFailed (REMOTE_FAILED .format (self .source , str (err )))
392402
393- if _http_resp .status_code == 200 : # New content
403+ load_successful = _http_resp .status_code == 200
404+ not_modified = _http_resp .status_code == 304
405+
406+ if load_successful :
394407 self .time_out = time .time () + self .cache_time
395408
396409 self .imp_jwks = self ._parse_remote_response (_http_resp )
@@ -408,11 +421,9 @@ def do_remote(self):
408421 if hasattr (_http_resp , "headers" ):
409422 headers = getattr (_http_resp , "headers" )
410423 self .last_remote = headers .get ("last-modified" ) or headers .get ("date" )
411-
412- elif _http_resp .status_code == 304 : # Not modified
424+ elif not_modified :
413425 LOGGER .debug ("%s not modified since %s" , self .source , self .last_remote )
414426 self .time_out = time .time () + self .cache_time
415-
416427 else :
417428 LOGGER .warning (
418429 "HTTP status %d reading remote JWKS from %s" ,
@@ -424,7 +435,7 @@ def do_remote(self):
424435
425436 self .last_updated = time .time ()
426437 self .ignore_errors_until = None
427- return True
438+ return load_successful
428439
429440 def _parse_remote_response (self , response ):
430441 """
@@ -449,23 +460,20 @@ def _parse_remote_response(self, response):
449460 return None
450461
451462 def _uptodate (self ):
452- res = False
453463 if self .remote or self .local :
454464 if time .time () > self .time_out :
455- if self .local and not self ._local_update_required ():
456- res = True
457- elif self .update ():
458- res = True
459- return res
465+ return self .update ()
466+ return False
460467
461468 def update (self ):
462469 """
463470 Reload the keys if necessary.
464471
465472 This is a forced update, will happen even if cache time has not elapsed.
466473 Replaced keys will be marked as inactive and not removed.
474+
475+ :return: True if update was ok or False if we encountered an error during update.
467476 """
468- res = True # An update was successful
469477 if self .source :
470478 _old_keys = self ._keys # just in case
471479
@@ -475,24 +483,27 @@ def update(self):
475483 try :
476484 if self .local :
477485 if self .fileformat in ["jwks" , "jwk" ]:
478- self .do_local_jwk (self .source )
486+ updated = self .do_local_jwk (self .source )
479487 elif self .fileformat == "der" :
480- self .do_local_der (self .source , self .keytype , self .keyusage )
488+ updated = self .do_local_der (self .source , self .keytype , self .keyusage )
481489 elif self .remote :
482- res = self .do_remote ()
490+ updated = self .do_remote ()
483491 except Exception as err :
484492 LOGGER .error ("Key bundle update failed: %s" , err )
485493 self ._keys = _old_keys # restore
486494 return False
487495
488- now = time .time ()
489- for _key in _old_keys :
490- if _key not in self ._keys :
491- if not _key .inactive_since : # If already marked don't mess
492- _key .inactive_since = now
493- self ._keys .append (_key )
496+ if updated :
497+ now = time .time ()
498+ for _key in _old_keys :
499+ if _key not in self ._keys :
500+ if not _key .inactive_since : # If already marked don't mess
501+ _key .inactive_since = now
502+ self ._keys .append (_key )
503+ else :
504+ self ._keys = _old_keys
494505
495- return res
506+ return True
496507
497508 def get (self , typ = "" , only_active = True ):
498509 """
0 commit comments