Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions openfe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
# silence pymbar logging warnings
import logging
# We need to do this first so that we can set up our
# log control since some modules have warnings on import
from openfe.utils.logging_control import LogControl

LogControl.silence_message(
msg=[
"****** PyMBAR will use 64-bit JAX! *******",
],
logger_names=[
"pymbar.mbar_solvers",
],
)

def _mute_timeseries(record):
return not "Warning on use of the timeseries module:" in record.msg

LogControl.silence_message(
msg=[
"Warning on use of the timeseries module:",
],
logger_names=[
"pymbar.timeseries",
],
)

def _mute_jax(record):
return not "****** PyMBAR will use 64-bit JAX! *******" in record.msg
LogControl.append_logger(
suffix="\n \n[OPENFE]: See this url for more information about the warning above\n",
logger_names="jax._src.xla_bridge",
)

# These two lines are just to test the append_logger and will be removed before
# the PR is merged
from jax._src.xla_bridge import backends

_mbar_log = logging.getLogger("pymbar.timeseries")
_mbar_log.addFilter(_mute_timeseries)
_mbar_log = logging.getLogger("pymbar.mbar_solvers")
_mbar_log.addFilter(_mute_jax)
backends()

from importlib.metadata import version

Expand Down
10 changes: 7 additions & 3 deletions openfe/protocols/openmm_utils/multistate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from openff.units import Quantity, unit
from openff.units.openmm import from_openmm
from openmmtools import multistate
from pymbar import MBAR
from pymbar.utils import ParameterError

from openfe.analysis import plotting
from openfe.due import Doi, due
Expand Down Expand Up @@ -236,8 +234,10 @@ def _get_free_energy(
* Allow folks to pass in extra options for bootstrapping etc..
* Add standard test against analyzer.get_free_energy()
"""
# pymbar has some side effects when imported so we only import it right when we
# need it
from pymbar import MBAR

# pymbar 4
mbar = MBAR(
u_ln,
N_l,
Expand Down Expand Up @@ -312,6 +312,10 @@ def get_forward_and_reverse_analysis(
issues with the solver when using low amounts of data points. All
uncertainties are MBAR analytical errors.
"""
# pymbar has some side effects from being imported, so we only want to import
# it right when we need it
from pymbar.utils import ParameterError

try:
u_ln = self.analyzer._unbiased_decorrelated_u_ln
N_l = self.analyzer._unbiased_decorrelated_N_l
Expand Down
1 change: 1 addition & 0 deletions openfe/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/openfe

from . import custom_typing
from .logging_control import LogControl
from .optional_imports import requires_package
from .remove_oechem import without_oechem_backend
from .system_probe import log_system_probe
175 changes: 175 additions & 0 deletions openfe/utils/logging_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import logging
from abc import ABC, abstractmethod


class BaseLogFilter(ABC):
"""Base class for log filters that handle string or list of strings.

Parameters
----------
strings : str or list of str
String(s) to use in the filter logic
"""

def __init__(self, strings: str | list[str]) -> None:
if isinstance(strings, str):
strings = [strings]
self.strings: list[str] = strings

@abstractmethod
def filter(self, record: logging.LogRecord) -> bool:
"""Filter method to be implemented by subclasses.

Parameters
----------
record : logging.LogRecord
Log record to filter/modify

Returns
-------
bool
True to allow the record, False to block it
"""
...


class MsgIncludesStringFilter(BaseLogFilter):
"""Logging filter to silence specific log messages.

See https://docs.python.org/3/library/logging.html#filter-objects

Parameters
----------
strings : str or list of str
If string(s) match in log messages (substring match) then the log record
is suppressed
"""

def filter(self, record: logging.LogRecord) -> bool:
"""Filter log records that contain any of the specified strings.

Parameters
----------
record : logging.LogRecord
Log record to filter

Returns
-------
bool
False if the record should be blocked, True if it should be logged
"""
for string in self.strings:
if string in record.msg:
return False
return True


class AppendMsgFilter(BaseLogFilter):
"""Logging filter to append a message to a specific log message.

See https://docs.python.org/3/library/logging.html#filter-objects

Parameters
----------
strings : str or list of str
Suffix text(s) to append to log messages
"""

def __init__(self, strings: str | list[str]) -> None:
super().__init__(strings)
# Rename for clarity in this context
self.suffixes = self.strings

def filter(self, record: logging.LogRecord) -> bool:
"""Append suffix to log record message.

Parameters
----------
record : logging.LogRecord
Log record to modify

Returns
-------
bool
Always True to allow the record to be logged
"""
for suffix in self.suffixes:
# Only modify if not already appended (idempotent)
if not record.msg.endswith(suffix):
record.msg = f"{record.msg}{suffix}"
return True


class LogControl:
"""Easy-to-use logging control for third-party packages."""

@staticmethod
def silence_message(msg: str | list[str], logger_names: str | list[str]) -> None:
"""Silence specific log messages from one or more loggers.

Parameters
----------
msg : str or list of str
String(s) to match in log messages (substring match)
logger_names : str or list of str
Logger name(s) to apply the filter to

Examples
--------
>>> LogControl.silence_message(
... msg="****** PyMBAR will use 64-bit JAX! *******",
... logger_names=["pymbar.timeseries", "pymbar.mbar_solvers"]
... )
"""
if isinstance(logger_names, str):
logger_names = [logger_names]

filter_obj = MsgIncludesStringFilter(msg)
for name in logger_names:
logging.getLogger(name).addFilter(filter_obj)

@staticmethod
def silence_logger(logger_names: str | list[str], level: int = logging.CRITICAL) -> None:
"""Completely silence one or more loggers.

Parameters
----------
logger_names : str or list of str
Logger name(s) to silence
level : int
Set logger level (default: CRITICAL to silence everything)

Examples
--------
>>> LogControl.silence_logger(logger_names=["urllib3", "requests"])
"""
if isinstance(logger_names, str):
logger_names = [logger_names]

for name in logger_names:
logging.getLogger(name).setLevel(level)

@staticmethod
def append_logger(suffix: str | list[str], logger_names: str | list[str]) -> None:
"""Append text to logger messages.

Parameters
----------
suffix : str or list of str
Suffix text to append to log messages
logger_names : str or list of str
Logger name(s) to modify

Examples
--------
>>> LogControl.append_logger(
... suffix=" [DEPRECATED]",
... logger_names="myapp"
... )
"""
if isinstance(logger_names, str):
logger_names = [logger_names]

filter_obj = AppendMsgFilter(suffix)
for name in logger_names:
logging.getLogger(name).addFilter(filter_obj)
20 changes: 0 additions & 20 deletions openfe/utils/logging_filter.py

This file was deleted.

19 changes: 10 additions & 9 deletions openfecli/commands/quickrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def quickrun(transformation, work_dir, output):
from gufe.tokenization import JSON_HANDLER
from gufe.transformations.transformation import Transformation

from openfe.utils.logging_filter import MsgIncludesStringFilter
from openfe.utils import LogControl

# avoid problems with output not showing if queueing system kills a job
sys.stdout.reconfigure(line_buffering=True)
Expand All @@ -68,15 +68,16 @@ def quickrun(transformation, work_dir, output):
configure_logger("openfe", handler=stdout_handler)

# silence the openmmtools.multistate API warning
stfu = MsgIncludesStringFilter(
"The openmmtools.multistate API is experimental and may change in future releases"
LogControl.silence_message(
msg=[
"The openmmtools.multistate API is experimental and may change in future releases",
],
logger_names=[
"openmmtools.multistate.multistatereporter",
"openmmtools.multistate.multistateanalyzer",
"openmmtools.multistate.multistatesampler",
],
)
omm_multistate = "openmmtools.multistate"
modules = ["multistatereporter", "multistateanalyzer", "multistatesampler"]
for module in modules:
ms_log = logging.getLogger(omm_multistate + "." + module)
ms_log.addFilter(stfu)

# turn warnings into log message (don't show stack trace)
logging.captureWarnings(True)

Expand Down
Loading