diff --git a/docs/source/concurrency.rst b/docs/source/concurrency.rst index 3e04e246..6a5becf1 100644 --- a/docs/source/concurrency.rst +++ b/docs/source/concurrency.rst @@ -11,7 +11,7 @@ In the case of properties, the HTTP response is only returned once the `.Thing` Many of the functions that handle HTTP requests are asynchronous, running in an :mod:`anyio` event loop. This enables many HTTP connections to be handled at once with good efficiency. The `anyio documentation`_ describes the functions that link between async and threaded code. When the LabThings server is started, we create an :class:`anyio.from_thread.BlockingPortal`, which allows threaded code to run code asynchronously in the event loop. -An action can obtain the blocking portal using the `~labthings_fastapi.dependencies.blocking_portal.BlockingPortal` dependency, i.e. by declaring an argument of that type. This avoids referring to the blocking portal through a global variable, which could lead to confusion if there are multiple event loops, e.g. during testing. +An action can run async code using its server interface. See `.ThingServerInterface.start_async_task_soon` for details. There are relatively few occasions when `.Thing` code will need to consider this explicitly: more usually the blocking portal will be obtained by a LabThings function, for example the `.MJPEGStream` class. diff --git a/docs/source/dependencies/example.py b/docs/source/dependencies/example.py index d170178b..315611e2 100644 --- a/docs/source/dependencies/example.py +++ b/docs/source/dependencies/example.py @@ -5,7 +5,7 @@ import labthings_fastapi as lt from labthings_fastapi.example_things import MyThing -MyThingClient = lt.deps.direct_thing_client_class(MyThing, "/mything/") +MyThingClient = lt.deps.direct_thing_client_class(MyThing, "mything") MyThingDep = Annotated[MyThingClient, Depends()] @@ -19,8 +19,8 @@ def increment_counter(self, my_thing: MyThingDep) -> None: server = lt.ThingServer() -server.add_thing(MyThing(), "/mything/") -server.add_thing(TestThing(), "/testthing/") +server.add_thing("mything", MyThing) +server.add_thing("testthing", TestThing) if __name__ == "__main__": import uvicorn diff --git a/docs/source/quickstart/counter.py b/docs/source/quickstart/counter.py index 8a2c84f1..8e4b566d 100644 --- a/docs/source/quickstart/counter.py +++ b/docs/source/quickstart/counter.py @@ -34,7 +34,7 @@ def slowly_increase_counter(self) -> None: server = lt.ThingServer() # The line below creates a TestThing instance and adds it to the server - server.add_thing(TestThing(), "/counter/") + server.add_thing("counter", TestThing) # We run the server using `uvicorn`: uvicorn.run(server.app, port=5000) diff --git a/docs/source/tutorial/writing_a_thing.rst b/docs/source/tutorial/writing_a_thing.rst index 9bec6327..defe6516 100644 --- a/docs/source/tutorial/writing_a_thing.rst +++ b/docs/source/tutorial/writing_a_thing.rst @@ -30,9 +30,8 @@ Our first Thing will pretend to be a light: we can set its brightness and turn i self.is_on = not self.is_on - light = Light() server = lt.ThingServer() - server.add_thing("/light", light) + server.add_thing("light", Light) if __name__ == "__main__": import uvicorn diff --git a/pyproject.toml b/pyproject.toml index 918152c2..47cbd943 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,9 @@ addopts = [ "--cov-report=html:htmlcov", "--cov-report=lcov", ] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", +] [tool.ruff] target-version = "py310" diff --git a/src/labthings_fastapi/__init__.py b/src/labthings_fastapi/__init__.py index a77559ef..a7bbabe4 100644 --- a/src/labthings_fastapi/__init__.py +++ b/src/labthings_fastapi/__init__.py @@ -20,6 +20,7 @@ """ from .thing import Thing +from .thing_server_interface import ThingServerInterface from .properties import property, setting, DataProperty, DataSetting from .decorators import ( thing_action, @@ -30,7 +31,6 @@ from .outputs import blob from .server import ThingServer, cli from .client import ThingClient -from .utilities import get_blocking_portal # The symbols in __all__ are part of our public API. # They are imported when using `import labthings_fastapi as lt`. @@ -40,6 +40,7 @@ # re-export style, we may switch in the future. __all__ = [ "Thing", + "ThingServerInterface", "property", "setting", "DataProperty", @@ -52,5 +53,4 @@ "ThingServer", "cli", "ThingClient", - "get_blocking_portal", ] diff --git a/src/labthings_fastapi/actions/__init__.py b/src/labthings_fastapi/actions/__init__.py index d1ca3844..098a92fc 100644 --- a/src/labthings_fastapi/actions/__init__.py +++ b/src/labthings_fastapi/actions/__init__.py @@ -270,14 +270,14 @@ def run(self) -> None: # self.action evaluates to an ActionDescriptor. This confuses mypy, # which thinks we are calling ActionDescriptor.__get__. action: ActionDescriptor = self.action # type: ignore[call-overload] + # Create a logger just for this invocation, keyed to the invocation id + # Logs that go to this logger will be copied into `self._log` + handler = DequeLogHandler(dest=self._log) + logger = invocation_logger(self.id) + logger.addHandler(handler) try: action.emit_changed_event(self.thing, self._status.value) - # Capture just this thread's log messages - handler = DequeLogHandler(dest=self._log) - logger = invocation_logger(self.id) - logger.addHandler(handler) - thing = self.thing kwargs = model_to_dict(self.input) if thing is None: # pragma: no cover diff --git a/src/labthings_fastapi/client/in_server.py b/src/labthings_fastapi/client/in_server.py index d81de108..bfccb4c5 100644 --- a/src/labthings_fastapi/client/in_server.py +++ b/src/labthings_fastapi/client/in_server.py @@ -51,8 +51,8 @@ class DirectThingClient: __globals__ = globals() # "bake in" globals so dependency injection works thing_class: type[Thing] """The class of the underlying `.Thing` we are wrapping.""" - thing_path: str - """The path to the Thing on the server. Relative to the server's base URL.""" + thing_name: str + """The name of the Thing on the server.""" def __init__(self, request: Request, **dependencies: Mapping[str, Any]) -> None: r"""Wrap a `.Thing` so it works like a `.ThingClient`. @@ -70,7 +70,7 @@ def __init__(self, request: Request, **dependencies: Mapping[str, Any]) -> None: such as access to other `.Things`. """ server = find_thing_server(request.app) - self._wrapped_thing = server.things[self.thing_path] + self._wrapped_thing = server.things[self.thing_name] self._request = request self._dependencies = dependencies @@ -254,7 +254,7 @@ def add_property( def direct_thing_client_class( thing_class: type[Thing], - thing_path: str, + thing_name: str, actions: Optional[list[str]] = None, ) -> type[DirectThingClient]: r"""Create a DirectThingClient from a Thing class and a path. @@ -262,7 +262,7 @@ def direct_thing_client_class( This is a class, not an instance: it's designed to be a FastAPI dependency. :param thing_class: The `.Thing` subclass that will be wrapped. - :param thing_path: The path where the `.Thing` is found on the server. + :param thing_name: The name of the `.Thing` on the server. :param actions: An optional list giving a subset of actions that will be accessed. If this is specified, it may reduce the number of FastAPI dependencies we need. @@ -291,15 +291,15 @@ def init_proxy( # of `DirectThingClient` with bad results. DirectThingClient.__init__(self, request, **dependencies) - init_proxy.__doc__ = f"""Initialise a client for {thing_class} at {thing_path}""" + init_proxy.__doc__ = f"""Initialise a client for {thing_class}""" # Using a class definition gets confused by the scope of the function # arguments - this is equivalent to a class definition but all the # arguments are evaluated in the right scope. client_attrs = { "thing_class": thing_class, - "thing_path": thing_path, - "__doc__": f"A client for {thing_class} at {thing_path}", + "thing_name": thing_name, + "__doc__": f"A client for {thing_class} named {thing_name}", "__init__": init_proxy, } dependencies: list[inspect.Parameter] = [] diff --git a/src/labthings_fastapi/dependencies/metadata.py b/src/labthings_fastapi/dependencies/metadata.py index 5f87d732..aafb8118 100644 --- a/src/labthings_fastapi/dependencies/metadata.py +++ b/src/labthings_fastapi/dependencies/metadata.py @@ -17,6 +17,12 @@ def thing_states_getter(request: Request) -> Callable[[], Mapping[str, Any]]: """Generate a function to retrieve metadata from all Things in this server. + .. warning:: + + This function is deprecated in favour of the `.ThingServerInterface`, which + is available as a property of every Thing. + See `.ThingServerInterface.get_thing_states` for more information. + This is intended to make it easy for a `.Thing` to summarise the other `.Things` in the same server, as is often appropriate when embedding metadata in data files. For example, it's used to populate the ``UserComment`` @@ -64,7 +70,11 @@ def get_metadata() -> dict[str, Any]: GetThingStates = Annotated[ Callable[[], Mapping[str, Any]], Depends(thing_states_getter) ] -"""A ready-made FastAPI dependency, returning a function to collect metadata. +r"""A ready-made FastAPI dependency, returning a function to collect metadata. + +.. warning:: + + This dependency is deprecated in favour of the `.ThingServerInterface`\ . This calls `.thing_states_getter` to provide a function that supplies a dictionary of metadata. It describes the state of all `.Thing` instances on diff --git a/src/labthings_fastapi/descriptors/action.py b/src/labthings_fastapi/descriptors/action.py index 79e61904..1c3fa332 100644 --- a/src/labthings_fastapi/descriptors/action.py +++ b/src/labthings_fastapi/descriptors/action.py @@ -33,7 +33,7 @@ from ..outputs.blob import BlobIOContextDep from ..thing_description import type_to_dataschema from ..thing_description._model import ActionAffordance, ActionOp, Form -from ..utilities import labthings_data, get_blocking_portal +from ..utilities import labthings_data from ..exceptions import NotConnectedToServerError if TYPE_CHECKING: @@ -200,19 +200,8 @@ def emit_changed_event(self, obj: Thing, status: str) -> None: :param obj: The `.Thing` on which the action is being observed. :param status: The status of the action, to be sent to observers. - - :raise NotConnectedToServerError: if the Thing calling the action is not - connected to a server with a running event loop. """ - runner = get_blocking_portal(obj) - if not runner: - thing_name = obj.__class__.__name__ - msg = ( - f"Cannot emit action changed event. Is {thing_name} connected to " - "a running server?" - ) - raise NotConnectedToServerError(msg) - runner.start_task_soon( + obj._thing_server_interface.start_async_task_soon( self.emit_changed_event_async, obj, status, diff --git a/src/labthings_fastapi/example_things/__init__.py b/src/labthings_fastapi/example_things/__init__.py index 9d6daba3..7a52f241 100644 --- a/src/labthings_fastapi/example_things/__init__.py +++ b/src/labthings_fastapi/example_things/__init__.py @@ -137,11 +137,14 @@ def broken_property(self) -> None: class ThingThatCantInstantiate(Thing): """A Thing that raises an exception in __init__.""" - def __init__(self) -> None: - """Fail to initialise. + def __init__(self, **kwargs: Any) -> None: + r"""Fail to initialise. + + :param \**kwargs: keyword arguments passed to Thing.__init__ :raise RuntimeError: every time. """ + super().__init__(**kwargs) raise RuntimeError("This thing can't be instantiated") diff --git a/src/labthings_fastapi/outputs/mjpeg_stream.py b/src/labthings_fastapi/outputs/mjpeg_stream.py index 740c037e..2aaf05bd 100644 --- a/src/labthings_fastapi/outputs/mjpeg_stream.py +++ b/src/labthings_fastapi/outputs/mjpeg_stream.py @@ -24,11 +24,11 @@ from contextlib import asynccontextmanager import threading import anyio -from anyio.from_thread import BlockingPortal import logging if TYPE_CHECKING: from ..thing import Thing + from ..thing_server_interface import ThingServerInterface @dataclass @@ -126,12 +126,17 @@ class MJPEGStream: of new frames, and then retrieving the frame (shortly) afterwards. """ - def __init__(self, ringbuffer_size: int = 10) -> None: + def __init__( + self, thing_server_interface: ThingServerInterface, ringbuffer_size: int = 10 + ) -> None: """Initialise an MJPEG stream. See the class docstring for `.MJPEGStream`. Note that it will often be initialised by `.MJPEGStreamDescriptor`. + :param thing_server_interface: the `.ThingServerInterface` of the + `.Thing` associated with this stream. It's used to run the async + code that relays frames to open connections. :param ringbuffer_size: The number of frames to retain in memory, to allow retrieval after the frame has been sent. """ @@ -139,6 +144,7 @@ def __init__(self, ringbuffer_size: int = 10) -> None: self.condition = anyio.Condition() self._streaming = False self._ringbuffer: list[RingbufferEntry] = [] + self._thing_server_interface = thing_server_interface self.reset(ringbuffer_size=ringbuffer_size) def reset(self, ringbuffer_size: Optional[int] = None) -> None: @@ -161,18 +167,16 @@ def reset(self, ringbuffer_size: Optional[int] = None) -> None: ] self.last_frame_i = -1 - def stop(self, portal: BlockingPortal) -> None: + def stop(self) -> None: """Stop the stream. Stop the stream and cause all clients to disconnect. - - :param portal: an `anyio.from_thread.BlockingPortal` that allows - this function to use the event loop to notify that the stream - should stop. """ with self._lock: self._streaming = False - portal.start_task_soon(self.notify_stream_stopped) + self._thing_server_interface.start_async_task_soon( + self.notify_stream_stopped + ) async def ringbuffer_entry(self, i: int) -> RingbufferEntry: """Return the ith frame acquired by the camera. @@ -308,7 +312,7 @@ async def mjpeg_stream_response(self) -> MJPEGStreamResponse: """ return MJPEGStreamResponse(self.frame_async_generator()) - def add_frame(self, frame: bytes, portal: BlockingPortal) -> None: + def add_frame(self, frame: bytes) -> None: """Add a JPEG to the MJPEG stream. This function adds a frame to the stream. It may be called from @@ -317,10 +321,6 @@ def add_frame(self, frame: bytes, portal: BlockingPortal) -> None: are handled. :param frame: The frame to add - :param portal: The blocking portal to use for scheduling tasks. - This is necessary because tasks are handled asynchronously. - The blocking portal may be obtained with a dependency, in - `labthings_fastapi.dependencies.blocking_portal.BlockingPortal`. :raise ValueError: if the supplied frame does not start with the JPEG start bytes and end with the end bytes. @@ -337,7 +337,9 @@ def add_frame(self, frame: bytes, portal: BlockingPortal) -> None: entry.timestamp = datetime.now() entry.frame = frame entry.index = self.last_frame_i + 1 - portal.start_task_soon(self.notify_new_frame, entry.index) + self._thing_server_interface.start_async_task_soon( + self.notify_new_frame, entry.index + ) async def notify_new_frame(self, i: int) -> None: """Notify any waiting tasks that a new frame is available. @@ -420,7 +422,10 @@ def __get__( try: return obj.__dict__[self.name] except KeyError: - obj.__dict__[self.name] = MJPEGStream(**self._kwargs) + obj.__dict__[self.name] = MJPEGStream( + **self._kwargs, + thing_server_interface=obj._thing_server_interface, + ) return obj.__dict__[self.name] async def viewer_page(self, url: str) -> HTMLResponse: @@ -452,7 +457,7 @@ class Camera(lt.Thing): server = lt.ThingServer() - server.add_thing(Camera(), "/camera") + server.add_thing("camera", Camera) :param app: the `fastapi.FastAPI` application to which we are being added. :param thing: the host `.Thing` instance. diff --git a/src/labthings_fastapi/properties.py b/src/labthings_fastapi/properties.py index 8601a21e..c24d3699 100644 --- a/src/labthings_fastapi/properties.py +++ b/src/labthings_fastapi/properties.py @@ -85,14 +85,15 @@ class attribute. Documentation is in strings immediately following the # Note on ignored linter codes: # -# D103 refers to missing docstrings. I have ignored this on @overload definitions -# because they shouldn't have docstrings - the docstring belongs only on the -# function they overload. -# D105 is the same as D103, but for __init__ (i.e. magic methods). -# DOC101 and DOC103 are also a result of overloads not having docstrings -# DOC201 is ignored on properties. Because we are overriding the -# builtin `property`, we are using `@builtins.property` which is not recognised -# by pydoclint as a property. I've therefore ignored those codes manually. +# DOC101 and DOC103 are a result of overloads not having docstrings. While +# the related D codes (checked by Ruff) don't flag overloads, pydoclint +# doesn't ignore overloads. This is most likely a pydoclint bug that +# we are working around. +# DOC201 is ignored on properties. +# Because we are overriding the +# builtin `property`, we are using `@builtins.property` which is not +# recognised by pydoclint as a property. I've therefore ignored those +# codes manually. # pydocstyle ("D" codes) is run in Ruff and correctly recognises # builtins.property as a property decorator. @@ -207,19 +208,17 @@ def default_factory() -> Value: # See comment at the top of the file regarding ignored linter rules. @overload # use as a decorator @property -def property( # noqa: D103 +def property( getter: Callable[[Any], Value], ) -> FunctionalProperty[Value]: ... @overload # use as `field: int = property(default=0)` -def property( # noqa: D103 - *, default: Value, readonly: bool = False -) -> Value: ... +def property(*, default: Value, readonly: bool = False) -> Value: ... @overload # use as `field: int = property(default_factory=lambda: 0)` -def property( # noqa: D103 +def property( *, default_factory: Callable[[], Value], readonly: bool = False ) -> Value: ... @@ -480,12 +479,12 @@ class DataProperty(BaseProperty[Value], Generic[Value]): """ @overload - def __init__( # noqa: D105,D107,DOC101,DOC103 + def __init__( # noqa: DOC101,DOC103 self, default: Value, *, readonly: bool = False ) -> None: ... @overload - def __init__( # noqa: D105,D107,DOC101,DOC103 + def __init__( # noqa: DOC101,DOC103 self, *, default_factory: ValueFactory, readonly: bool = False ) -> None: ... @@ -654,19 +653,8 @@ def emit_changed_event(self, obj: Thing, value: Value) -> None: :param obj: the `.Thing` to which we are attached. :param value: the new property value, to be sent to observers. - - :raise NotConnectedToServerError: if the Thing that is calling the property - update is not connected to a server with a running event loop. """ - runner = obj._labthings_blocking_portal - if not runner: - thing_name = obj.__class__.__name__ - msg = ( - f"Cannot emit property updated changed event. Is {thing_name} " - "connected to a running server?" - ) - raise NotConnectedToServerError(msg) - runner.start_task_soon( + obj._thing_server_interface.start_async_task_soon( self.emit_changed_event_async, obj, value, @@ -752,7 +740,8 @@ def setter(self, fset: ValueSetter) -> Self: .. code-block:: python class MyThing(lt.Thing): - def __init__(self): + def __init__(self, thing_server_interface): + super().__init__(thing_server_interface=thing_server_interface) self._myprop: int = 0 @lt.property @@ -833,19 +822,17 @@ def __set__(self, obj: Thing, value: Value) -> None: @overload # use as a decorator @setting -def setting( # noqa: D103 +def setting( getter: Callable[[Any], Value], ) -> FunctionalSetting[Value]: ... @overload # use as `field: int = setting(default=0)`` -def setting( # noqa: D103 - *, default: Value, readonly: bool = False -) -> Value: ... +def setting(*, default: Value, readonly: bool = False) -> Value: ... @overload # use as `field: int = setting(default_factory=lambda: 0)` -def setting( # noqa: D103 +def setting( *, default_factory: Callable[[], Value], readonly: bool = False ) -> Value: ... diff --git a/src/labthings_fastapi/server/__init__.py b/src/labthings_fastapi/server/__init__.py index 34f8add7..8efc3967 100644 --- a/src/labthings_fastapi/server/__init__.py +++ b/src/labthings_fastapi/server/__init__.py @@ -7,7 +7,7 @@ """ from __future__ import annotations -from typing import AsyncGenerator, Optional, Sequence, TypeVar +from typing import Any, AsyncGenerator, Optional, Sequence, TypeVar import os.path import re @@ -23,6 +23,7 @@ ) from ..actions import ActionManager from ..thing import Thing +from ..thing_server_interface import ThingServerInterface from ..thing_description._model import ThingDescription from ..dependencies.thing_server import _thing_servers # noqa: F401 @@ -32,7 +33,10 @@ # A path should be made up of names separated by / as a path separator. # Each name should be made of alphanumeric characters, hyphen, or underscore. # This regex enforces a trailing / -PATH_REGEX = re.compile(r"^/([a-zA-Z0-9\-_]+\/)+$") +PATH_REGEX = re.compile(r"^([a-zA-Z0-9\-_]+)$") + + +ThingSubclass = TypeVar("ThingSubclass", bound=Thing) class ThingServer: @@ -143,35 +147,83 @@ def thing_by_class(self, cls: type[ThingInstance]) -> ThingInstance: f"There are {len(instances)} Things of class {cls}, expected 1." ) - def add_thing(self, thing: Thing, path: str) -> None: - """Add a thing to the server. - - :param thing: The `.Thing` instance to add to the server. - :param path: the relative path to access the thing on the server. Must only - contain alphanumeric characters, hyphens, or underscores. + def add_thing( + self, + name: str, + thing_subclass: type[ThingSubclass], + args: Sequence[Any] | None = None, + kwargs: Mapping[str, Any] | None = None, + ) -> ThingSubclass: + r"""Add a thing to the server. + + This function will create an instance of ``thing_subclass`` and supply + the ``args`` and ``kwargs`` arguments to its ``__init__`` method. That + instance will then be added to the server with the given name. + + :param name: The name to use for the thing. This will be part of the URL + used to access the thing, and must only contain alphanumeric characters, + hyphens and underscores. + :param thing_subclass: The `.Thing` subclass to add to the server. + :param args: positional arguments to pass to the constructor of + ``thing_subclass``\ . + :param kwargs: keyword arguments to pass to the constructor of + ``thing_subclass``\ . + + :returns: the instance of ``thing_subclass`` that was created and added + to the server. There is no need to retain a reference to this, as it + is stored in the server's dictionary of `.Thing` instances. :raise ValueError: if ``path`` contains invalid characters. - :raise KeyError: if a `.Thing` has already been added at ``path``. + :raise KeyError: if a `.Thing` has already been added at ``path``\ . + :raise TypeError: if ``thing_subclass`` is not a subclass of `.Thing` + or if ``name`` is not string-like. This usually means arguments + are being passed the wrong way round. """ - # Ensure leading and trailing / - if not path.endswith("/"): - path += "/" - if not path.startswith("/"): - path = "/" + path - if PATH_REGEX.match(path) is None: + if not isinstance(name, str): + raise TypeError("Thing names must be strings.") + if PATH_REGEX.match(name) is None: msg = ( - f"{path} contains unsafe characters. Use only alphanumeric " + f"'{name}' contains unsafe characters. Use only alphanumeric " "characters, hyphens and underscores" ) raise ValueError(msg) - if path in self._things: - raise KeyError(f"{path} has already been added to this thing server.") - self._things[path] = thing - settings_folder = os.path.join(self.settings_folder, path.lstrip("/")) - os.makedirs(settings_folder, exist_ok=True) + if name in self._things: + raise KeyError(f"{name} has already been added to this thing server.") + if not issubclass(thing_subclass, Thing): + raise TypeError(f"{thing_subclass} is not a Thing subclass.") + if args is None: + args = [] + if kwargs is None: + kwargs = {} + interface = ThingServerInterface(name=name, server=self) + os.makedirs(interface.settings_folder, exist_ok=True) + # This is where we instantiate the Thing + # I've had to ignore this line because the *args causes an error. + # Given that *args and **kwargs are very loosely typed anyway, this + # doesn't lose us much. + thing = thing_subclass( + *args, + **kwargs, + thing_server_interface=interface, + ) # type: ignore[misc] + self._things[name] = thing thing.attach_to_server( - self, path, os.path.join(settings_folder, "settings.json") + server=self, ) + return thing + + def path_for_thing(self, name: str) -> str: + """Return the path for a thing with the given name. + + :param name: The name of the thing, as passed to `.add_thing`. + + :return: The path at which the thing is served. + + :raise KeyError: if no thing with the given name has been added. + """ + if name not in self._things: + raise KeyError(f"No thing named {name} has been added to this server.") + return f"/{name}/" @asynccontextmanager async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]: @@ -192,19 +244,11 @@ async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]: :param app: The FastAPI application wrapped by the server. :yield: no value. The FastAPI application will serve requests while this function yields. - - :raises RuntimeError: if a `.Thing` already has a blocking portal attached. - This should never happen, and suggests the server is being used to - serve a `.Thing` that is already being served elsewhere. """ async with BlockingPortal() as portal: + # We create a blocking portal to allow threaded code to call async code + # in the event loop. self.blocking_portal = portal - # We attach a blocking portal to each thing, so that threaded code can - # make callbacks to async code (needed for events etc.) - for thing in self.things.values(): - if thing._labthings_blocking_portal is not None: - raise RuntimeError("Things may only ever have one blocking portal") - thing._labthings_blocking_portal = portal # we __aenter__ and __aexit__ each Thing, which will in turn call the # synchronous __enter__ and __exit__ methods if they exist, to initialise # and shut down the hardware. NB we must make sure the blocking portal @@ -213,9 +257,6 @@ async def lifespan(self, app: FastAPI) -> AsyncGenerator[None]: for thing in self.things.values(): await stack.enter_async_context(thing) yield - for _name, thing in self.things.items(): - # Remove the blocking portal - the event loop is about to stop. - thing._labthings_blocking_portal = None self.blocking_portal = None @@ -277,10 +318,9 @@ def server_from_config(config: dict) -> ThingServer: :raise ImportError: if a Thing could not be loaded from the specified object reference. - :raise TypeError: if a class is specified that does not subclass `.Thing`\ . """ server = ThingServer(config.get("settings_folder", None)) - for path, thing in config.get("things", {}).items(): + for name, thing in config.get("things", {}).items(): if isinstance(thing, str): thing = {"class": thing} try: @@ -288,10 +328,12 @@ def server_from_config(config: dict) -> ThingServer: except ImportError as e: raise ImportError( f"Could not import {thing['class']}, which was " - f"specified as the class for {path}." + f"specified as the class for {name}." ) from e - instance = cls(*thing.get("args", {}), **thing.get("kwargs", {})) - if not isinstance(instance, Thing): - raise TypeError(f"{thing['class']} is not a Thing") - server.add_thing(instance, path) + server.add_thing( + name=name, + thing_subclass=cls, + args=thing.get("args", ()), + kwargs=thing.get("kwargs", {}), + ) return server diff --git a/src/labthings_fastapi/thing.py b/src/labthings_fastapi/thing.py index c2383179..fd49fcdc 100644 --- a/src/labthings_fastapi/thing.py +++ b/src/labthings_fastapi/thing.py @@ -16,12 +16,10 @@ from fastapi.encoders import jsonable_encoder from fastapi import Request, WebSocket from anyio.abc import ObjectSendStream -from anyio.from_thread import BlockingPortal from anyio.to_thread import run_sync from pydantic import BaseModel -from .exceptions import NotConnectedToServerError from .properties import BaseProperty, DataProperty, BaseSetting from .descriptors import ActionDescriptor from .thing_description._model import ThingDescription, NoSecurityScheme @@ -35,6 +33,7 @@ if TYPE_CHECKING: from .server import ThingServer from .actions import ActionManager + from .thing_server_interface import ThingServerInterface _LOGGER = logging.getLogger(__name__) @@ -51,7 +50,8 @@ class Thing: * ``__init__``: You should accept any arguments you need to configure the Thing in ``__init__``. Don't initialise any hardware at this time, as your Thing may - be instantiated quite early, or even at import time. + be instantiated quite early, or even at import time. You must make sure to + call ``super().__init__(thing_server_interface)``\ . * ``__enter__(self)`` and ``__exit__(self, exc_t, exc_v, exc_tb)`` are where you should start and stop communications with the hardware. This is Python's "context manager" protocol. The arguments of ``__exit__`` will be ``None`` @@ -70,17 +70,36 @@ class Thing: so it makes sense to set this in a subclass. There are various LabThings methods that you should avoid overriding unless you - know what you are doing: anything not mentioned above that's defined in `Thing` is + know what you are doing: anything not mentioned above that's defined in `.Thing` is probably best left alone. They may in time be collected together into a single object to avoid namespace clashes. """ title: str """A human-readable description of the Thing""" - _labthings_blocking_portal: Optional[BlockingPortal] = None - """See :ref:`concurrency` for why blocking portal is needed.""" - path: Optional[str] = None - """The path at which the `.Thing` is exposed over HTTP.""" + + def __init__(self, thing_server_interface: ThingServerInterface) -> None: + """Initialise a Thing. + + The most important function of ``__init__`` is attaching the + thing_server_interface, and setting the path. Note that `.Thing` + instances are usually created by a `.ThingServer` and not instantiated + directly: if you do make a `.Thing` directly, you will need to supply + a `.ThingServerInterface` that is connected to a `.ThingServer` or a + suitable mock object. + + :param thing_server_interface: The interface to the server that + is hosting this Thing. It will be supplied when the `.Thing` is + instantiated by the `.ThingServer` or by + `.create_thing_without_server` which generates a mock interface. + """ + self._thing_server_interface = thing_server_interface + self._disable_saving_settings: bool = False + + @property + def path(self) -> str: + """The path at which the `.Thing` is exposed over HTTP.""" + return self._thing_server_interface.path async def __aenter__(self) -> Self: """Context management is used to set up/close the thing. @@ -110,18 +129,12 @@ async def __aexit__( if hasattr(self, "__exit__"): await run_sync(self.__exit__, exc_t, exc_v, exc_tb) - def attach_to_server( - self, server: ThingServer, path: str, setting_storage_path: str - ) -> None: + def attach_to_server(self, server: ThingServer) -> None: """Attach this thing to the server. Things need to be attached to a server before use to function correctly. :param server: The server to attach this Thing to. - :param path: The root URL for the Thing. - :param setting_storage_path: The path on disk to save the any Thing Settings - to. This should be the path to a json file. If it does not exist it will be - created. Attaching the `.Thing` to a `.ThingServer` allows the `.Thing` to start actions, load its settings from the correct place, and create HTTP endpoints @@ -130,9 +143,8 @@ def attach_to_server( We create HTTP endpoints for all :ref:`wot_affordances` on the `.Thing`, as well as any `.EndpointDescriptor` descriptors. """ - self.path = path self.action_manager: ActionManager = server.action_manager - self.load_settings(setting_storage_path) + self.load_settings() for _name, item in class_attributes(self): try: @@ -179,21 +191,7 @@ def _settings(self) -> dict[str, BaseSetting]: self._settings_store[name] = attr return self._settings_store - _setting_storage_path: Optional[str] = None - - @property - def setting_storage_path(self) -> Optional[str]: - """The storage path for settings. - - .. note:: - - This is set in `.Thing.attach_to_server`. It is ``None`` during the - ``__init__`` method, so it is best to avoid using settings until the - `.Thing` is set up in ``__enter__``. - """ - return self._setting_storage_path - - def load_settings(self, setting_storage_path: str) -> None: + def load_settings(self) -> None: """Load settings from json. Read the JSON file and use it to populate settings. @@ -205,14 +203,11 @@ def load_settings(self, setting_storage_path: str) -> None: Note that no notifications will be triggered when the settings are set, so if action is needed (e.g. updating hardware with the loaded settings) it should be taken in ``__enter__``. - - :param setting_storage_path: The path where the settings should be stored. """ - # Ensure that the settings path isn't set during loading or saving will be - # triggered - self._setting_storage_path = None + setting_storage_path = self._thing_server_interface.settings_file_path thing_name = type(self).__name__ if os.path.exists(setting_storage_path): + self._disable_saving_settings = True try: with open(setting_storage_path, "r", encoding="utf-8") as file_obj: setting_dict = json.load(file_obj) @@ -230,24 +225,18 @@ def load_settings(self, setting_storage_path: str) -> None: ) except (FileNotFoundError, JSONDecodeError, PermissionError): _LOGGER.warning("Error loading settings for %s", thing_name) - self._setting_storage_path = setting_storage_path + finally: + self._disable_saving_settings = False def save_settings(self) -> None: """Save settings to JSON. This is called whenever a setting is updated. All settings are written to the settings file every time. - - :raises NotConnectedToServerError: if there is no settings file path set. - This is set when the `.Thing` is connected to a `.ThingServer` so - most likely we are trying to save settings before we are attached - to a server. """ + if self._disable_saving_settings: + return if self._settings is not None: - if self._setting_storage_path is None: - raise NotConnectedToServerError( - "The path to the settings file is not defined yet." - ) setting_dict = {} for name in self._settings.keys(): value = getattr(self, name) @@ -256,7 +245,8 @@ def save_settings(self) -> None: setting_dict[name] = value # Dumpy to string before writing so if this fails the file isn't overwritten setting_json = json.dumps(setting_dict, indent=4) - with open(self._setting_storage_path, "w", encoding="utf-8") as file_obj: + path = self._thing_server_interface.settings_file_path + with open(path, "w", encoding="utf-8") as file_obj: file_obj.write(setting_json) _labthings_thing_state: Optional[dict] = None diff --git a/src/labthings_fastapi/thing_server_interface.py b/src/labthings_fastapi/thing_server_interface.py new file mode 100644 index 00000000..c11116f6 --- /dev/null +++ b/src/labthings_fastapi/thing_server_interface.py @@ -0,0 +1,243 @@ +r"""Interface between `.Thing` subclasses and the `.ThingServer`\ .""" + +from __future__ import annotations +from concurrent.futures import Future +import os +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Mapping, ParamSpec, TypeVar +from weakref import ref, ReferenceType + +from .exceptions import ServerNotRunningError + +if TYPE_CHECKING: + from .server import ThingServer + from .thing import Thing + + +Params = ParamSpec("Params") +ReturnType = TypeVar("ReturnType") + + +class ThingServerMissingError(RuntimeError): + """The error raised when a ThingServer is no longer available. + + This error indicates that a ThingServerInterface is still in use + even though its underlying ThingServer has been deleted. This is + unlikely to happen and usually indicates that the server has + been created in an odd way. + """ + + +class ThingServerInterface: + r"""An interface for Things to interact with their server. + + This is added to every `.Thing` during ``__init__`` and is available + as ``self._thing_server_interface``\ . + """ + + def __init__(self, server: ThingServer, name: str) -> None: + """Initialise a ThingServerInterface. + + The ThingServerInterface sits between a Thing and its ThingServer, + with the intention of providing a useful set of functions, without + exposing too much of the server to the Thing. + + One reason for using this intermediary class is to make it easier + to mock the server during testing: only functions provided here + need be mocked, not the whole functionality of the server. + + :param server: the `.ThingServer` instance we're connected to. + This will be retained as a weak reference. + :param name: the name of the `.Thing` instance this interface + is provided for. + """ + self._name: str = name + self._server: ReferenceType[ThingServer] = ref(server) + + def _get_server(self) -> ThingServer: + """Return a live reference to the ThingServer. + + This will evaluate the weak reference to the ThingServer, and will + raise an exception if the server has been garbage collected. + + The server is, in practice, not going to be finalized before the + Things, so this should not be a problem. + + :returns: the ThingServer. + + :raises ThingServerMissingError: if the `ThingServer` is no longer + available. + """ + server = self._server() + if server is None: + raise ThingServerMissingError() + return server + + def start_async_task_soon( + self, async_function: Callable[Params, Awaitable[ReturnType]], *args: Any + ) -> Future[ReturnType]: + r"""Run an asynchronous task in the server's event loop. + + This function wraps `anyio.from_thread.BlockingPortal.start_task_soon` to + provide a way of calling asynchronous code from threaded code. It will + call the provided async function in the server's event loop, without any + guarantee of exactly when it will happen. This means we will return + immediately, and the return value of this function will be a + `concurrent.futures.Future` object that may resolve to the async function's + return value. + + :param async_function: the asynchronous function to call. + :param \*args: positional arguments to be provided to the function. + + :returns: an `asyncio.Future` object wrapping the return value. + + :raises ServerNotRunningError: if the server is not running + (i.e. there is no event loop). + """ + portal = self._get_server().blocking_portal + if portal is None: + raise ServerNotRunningError("Can't run async code without an event loop.") + return portal.start_task_soon(async_function, *args) + + @property + def settings_folder(self) -> str: + """The path to a folder where persistent files may be saved.""" + server = self._get_server() + return os.path.join(server.settings_folder, self.name) + + @property + def settings_file_path(self) -> str: + """The path where settings should be loaded and saved as JSON.""" + return os.path.join(self.settings_folder, "settings.json") + + @property + def name(self) -> str: + """The name of the Thing attached to this interface.""" + return self._name + + @property + def path(self) -> str: + """The path, relative to the server's base URL, of the Thing. + + A ThingServerInterface is specific to one Thing, so this path points + to the base URL of the Thing, i.e. the Thing Description's endpoint. + """ + return self._get_server().path_for_thing(self.name) + + def get_thing_states(self) -> Mapping[str, Any]: + """Retrieve metadata from all Things on the server. + + This function will retrieve the `.Thing.thing_state` property from + each `.Thing` on the server, and return it as a dictionary. + It is intended to make it easy to add metadata to the results + of actions, for example to embed in an image. + + :return: a dictionary of metadata, with the `.Thing` names as keys. + """ + return {k: v.thing_state for k, v in self._get_server().things.items()} + + +class MockThingServerInterface(ThingServerInterface): + """A mock class that simulates a ThingServerInterface without the server.""" + + def __init__(self, name: str) -> None: + """Initialise a ThingServerInterface. + + :param name: The name of the Thing we're providing an interface to. + """ + # We deliberately don't call super().__init__(), as it won't work without + # a server. + self._name: str = name + self._settings_tempdir: TemporaryDirectory | None = None + + def start_async_task_soon( + self, async_function: Callable[Params, Awaitable[ReturnType]], *args: Any + ) -> Future[ReturnType]: + r"""Do nothing, as there's no event loop to use. + + This returns a `concurrent.futures.Future` object that is already cancelled, + in order to avoid accidental hangs in test code that attempts to wait for + the future object to resolve. Cancelling it may cause errors if you need + the return value. + + If you need the async code to run, it's best to add the `.Thing` to a + `lt.ThingServer` instead. Using a test client will start an event loop + in a background thread, and allow you to use a real `.ThingServerInterface` + without the overhead of actually starting an HTTP server. + + :param async_function: the asynchronous function to call. + :param \*args: positional arguments to be provided to the function. + + :returns: a `concurrent.futures.Future` object that has been cancelled. + """ + f: Future[ReturnType] = Future() + f.cancel() + return f + + @property + def settings_folder(self) -> str: + """The path to a folder where persistent files may be saved. + + This will create a temporary folder the first time it is called, + and return the same folder on subsequent calls. + + :returns: the path to a temporary folder. + """ + if not self._settings_tempdir: + self._settings_tempdir = TemporaryDirectory() + return self._settings_tempdir.name + + @property + def path(self) -> str: + """The path, relative to the server's base URL, of the Thing. + + A ThingServerInterface is specific to one Thing, so this path points + to the base URL of the Thing, i.e. the Thing Description's endpoint. + """ + return f"/{self.name}/" + + def get_thing_states(self) -> Mapping[str, Any]: + """Return an empty dictionary to mock the metadata dictionary. + + :returns: an empty dictionary. + """ + return {} + + +ThingSubclass = TypeVar("ThingSubclass", bound="Thing") + + +def create_thing_without_server( + cls: type[ThingSubclass], *args: Any, **kwargs: Any +) -> ThingSubclass: + r"""Create a `.Thing` and supply a mock ThingServerInterface. + + This function is intended for use in testing, where it will enable a `.Thing` + to be created without a server, by supplying a `.MockThingServerInterface` + instead of a real `.ThingServerInterface`\ . + + The name of the Thing will be taken from the class name, lowercased. + + :param cls: The `.Thing` subclass to instantiate. + :param \*args: positional arguments to ``__init__``. + :param \**kwargs: keyword arguments to ``__init__``. + + :returns: an instance of ``cls`` with a `.MockThingServerInterface` + so that it will function without a server. + + :raises ValueError: if a keyword argument called 'thing_server_interface' + is supplied, as this would conflict with the mock interface. + """ + name = cls.__name__.lower() + if "thing_server_interface" in kwargs: + msg = "You may not supply a keyword argument called 'thing_server_interface'." + raise ValueError(msg) + return cls( + *args, **kwargs, thing_server_interface=MockThingServerInterface(name=name) + ) # type: ignore[misc] + # Note: we must ignore misc typing errors above because mypy flags an error + # that `thing_server_interface` is multiply specified. + # This is a conflict with *args, if we had only **kwargs it would not flag + # any error. + # Given that args and kwargs are dynamically typed anyway, this does not + # lose us much. diff --git a/src/labthings_fastapi/utilities/__init__.py b/src/labthings_fastapi/utilities/__init__.py index fb94d786..f1964e94 100644 --- a/src/labthings_fastapi/utilities/__init__.py +++ b/src/labthings_fastapi/utilities/__init__.py @@ -5,7 +5,6 @@ from weakref import WeakSet from pydantic import BaseModel, ConfigDict, Field, RootModel, create_model from pydantic.dataclasses import dataclass -from anyio.from_thread import BlockingPortal from .introspection import EmptyObject if TYPE_CHECKING: @@ -82,25 +81,6 @@ def labthings_data(obj: Thing) -> LabThingsObjectData: return obj.__dict__[LABTHINGS_DICT_KEY] -def get_blocking_portal(obj: Thing) -> Optional[BlockingPortal]: - """Retrieve a blocking portal from a Thing. - - See :ref:`concurrency` for more details. - - When a `.Thing` is attached to a `.ThingServer` and the `.ThingServer` - is started, it sets an attribute on each `.Thing` to allow it to - access an `anyio.from_thread.BlockingPortal`. This allows threaded - code to call async code. - - This function retrieves the blocking portal from a `.Thing`. - - :param obj: the `.Thing` on which we are looking for the portal. - - :return: the blocking portal. - """ - return obj._labthings_blocking_portal - - def wrap_plain_types_in_rootmodel(model: type) -> type[BaseModel]: """Ensure a type is a subclass of BaseModel. diff --git a/tests/test_action_cancel.py b/tests/test_action_cancel.py index 5e7c5742..881e6bf1 100644 --- a/tests/test_action_cancel.py +++ b/tests/test_action_cancel.py @@ -3,6 +3,7 @@ """ import uuid +import pytest from fastapi.testclient import TestClient from .temp_client import poll_task, task_href import labthings_fastapi as lt @@ -68,136 +69,136 @@ def count_and_only_cancel_if_asked_twice( self.counter += counting_increment -def test_invocation_cancel(): +@pytest.fixture +def server(): + """Create a server with a CancellableCountingThing added.""" + server = lt.ThingServer() + server.add_thing("counting_thing", CancellableCountingThing) + return server + + +@pytest.fixture +def counting_thing(server): + """Retrieve the CancellableCountingThing from the server.""" + return server.things["counting_thing"] + + +@pytest.fixture +def client(server): + with TestClient(server.app) as client: + yield client + + +def test_invocation_cancel(counting_thing, client): """ Test that an invocation can be cancelled and the associated exception handled correctly. """ - server = lt.ThingServer() - counting_thing = CancellableCountingThing() - server.add_thing(counting_thing, "/counting_thing") - with TestClient(server.app) as client: - assert counting_thing.counter == 0 - assert not counting_thing.check - response = client.post("/counting_thing/count_slowly", json={}) - response.raise_for_status() - # Use `client.delete` to cancel the task! - cancel_response = client.delete(task_href(response.json())) - # Raise an exception is this isn't a 2xx response - cancel_response.raise_for_status() - invocation = poll_task(client, response.json()) - assert invocation["status"] == "cancelled" - assert counting_thing.counter < 9 - # Check that error handling worked - assert counting_thing.check - - -def test_invocation_that_refuses_to_cancel(): + assert counting_thing.counter == 0 + assert not counting_thing.check + response = client.post("/counting_thing/count_slowly", json={}) + response.raise_for_status() + # Use `client.delete` to cancel the task! + cancel_response = client.delete(task_href(response.json())) + # Raise an exception is this isn't a 2xx response + cancel_response.raise_for_status() + invocation = poll_task(client, response.json()) + assert invocation["status"] == "cancelled" + assert counting_thing.counter < 9 + # Check that error handling worked + assert counting_thing.check + + +def test_invocation_that_refuses_to_cancel(counting_thing, client): """ Test that an invocation can detect a cancel request but choose to modify behaviour. """ - server = lt.ThingServer() - counting_thing = CancellableCountingThing() - server.add_thing(counting_thing, "/counting_thing") - with TestClient(server.app) as client: - assert counting_thing.counter == 0 - response = client.post( - "/counting_thing/count_slowly_but_ignore_cancel", json={"n": 5} - ) - response.raise_for_status() - # Use `client.delete` to try to cancel the task! - cancel_response = client.delete(task_href(response.json())) - # Raise an exception is this isn't a 2xx response - cancel_response.raise_for_status() - invocation = poll_task(client, response.json()) - # As the task ignored the cancel. It should return completed - assert invocation["status"] == "completed" - # Counter should be greater than 5 as it counts faster if cancelled! - assert counting_thing.counter > 5 - - -def test_invocation_that_needs_cancel_twice(): + assert counting_thing.counter == 0 + response = client.post( + "/counting_thing/count_slowly_but_ignore_cancel", json={"n": 5} + ) + response.raise_for_status() + # Use `client.delete` to try to cancel the task! + cancel_response = client.delete(task_href(response.json())) + # Raise an exception is this isn't a 2xx response + cancel_response.raise_for_status() + invocation = poll_task(client, response.json()) + # As the task ignored the cancel. It should return completed + assert invocation["status"] == "completed" + # Counter should be greater than 5 as it counts faster if cancelled! + assert counting_thing.counter > 5 + + +def test_invocation_that_needs_cancel_twice(counting_thing, client): """ Test that an invocation can interpret cancel to change behaviour, but can really cancel if requested a second time """ - server = lt.ThingServer() - counting_thing = CancellableCountingThing() - server.add_thing(counting_thing, "/counting_thing") - with TestClient(server.app) as client: - # First cancel only once: - assert counting_thing.counter == 0 - response = client.post( - "/counting_thing/count_and_only_cancel_if_asked_twice", json={"n": 5} - ) - response.raise_for_status() - # Use `client.delete` to try to cancel the task! - cancel_response = client.delete(task_href(response.json())) - # Raise an exception is this isn't a 2xx response - cancel_response.raise_for_status() - invocation = poll_task(client, response.json()) - # As the task ignored the cancel. It should return completed - assert invocation["status"] == "completed" - # Counter should be less than 0 as it should started counting backwards - # almost immediately. - assert counting_thing.counter < 0 - - # Next cancel twice. - counting_thing.counter = 0 - assert counting_thing.counter == 0 - response = client.post( - "/counting_thing/count_and_only_cancel_if_asked_twice", json={"n": 5} - ) - response.raise_for_status() - # Use `client.delete` to try to cancel the task! - cancel_response = client.delete(task_href(response.json())) - # Raise an exception is this isn't a 2xx response - cancel_response.raise_for_status() - # Cancel again - cancel_response2 = client.delete(task_href(response.json())) - # Raise an exception is this isn't a 2xx response - cancel_response2.raise_for_status() - invocation = poll_task(client, response.json()) - # As the task ignored the cancel. It should return completed - assert invocation["status"] == "cancelled" - # Counter should be less than 0 as it should started counting backwards - # almost immediately. - assert counting_thing.counter < 0 - - -def test_late_invocation_cancel_responds_503(): + # First cancel only once: + assert counting_thing.counter == 0 + response = client.post( + "/counting_thing/count_and_only_cancel_if_asked_twice", json={"n": 5} + ) + response.raise_for_status() + # Use `client.delete` to try to cancel the task! + cancel_response = client.delete(task_href(response.json())) + # Raise an exception is this isn't a 2xx response + cancel_response.raise_for_status() + invocation = poll_task(client, response.json()) + # As the task ignored the cancel. It should return completed + assert invocation["status"] == "completed" + # Counter should be less than 0 as it should started counting backwards + # almost immediately. + assert counting_thing.counter < 0 + + # Next cancel twice. + counting_thing.counter = 0 + assert counting_thing.counter == 0 + response = client.post( + "/counting_thing/count_and_only_cancel_if_asked_twice", json={"n": 5} + ) + response.raise_for_status() + # Use `client.delete` to try to cancel the task! + cancel_response = client.delete(task_href(response.json())) + # Raise an exception is this isn't a 2xx response + cancel_response.raise_for_status() + # Cancel again + cancel_response2 = client.delete(task_href(response.json())) + # Raise an exception is this isn't a 2xx response + cancel_response2.raise_for_status() + invocation = poll_task(client, response.json()) + # As the task ignored the cancel. It should return completed + assert invocation["status"] == "cancelled" + # Counter should be less than 0 as it should started counting backwards + # almost immediately. + assert counting_thing.counter < 0 + + +def test_late_invocation_cancel_responds_503(counting_thing, client): """ Test that cancelling an invocation after it completes returns a 503 response. """ - server = lt.ThingServer() - counting_thing = CancellableCountingThing() - server.add_thing(counting_thing, "/counting_thing") - with TestClient(server.app) as client: - assert counting_thing.counter == 0 - assert not counting_thing.check - response = client.post("/counting_thing/count_slowly", json={"n": 1}) - response.raise_for_status() - # Sleep long enough that task completes. - time.sleep(0.3) - poll_task(client, response.json()) - # Use `client.delete` to cancel the task! - cancel_response = client.delete(task_href(response.json())) - # Check a 503 code is returned - assert cancel_response.status_code == 503 - # Check counter reached it's target - assert counting_thing.counter == 1 - # Check that error handling wasn't called - assert not counting_thing.check - - -def test_cancel_unknown_task(): + assert counting_thing.counter == 0 + assert not counting_thing.check + response = client.post("/counting_thing/count_slowly", json={"n": 1}) + response.raise_for_status() + # Sleep long enough that task completes. + time.sleep(0.3) + poll_task(client, response.json()) + # Use `client.delete` to cancel the task! + cancel_response = client.delete(task_href(response.json())) + # Check a 503 code is returned + assert cancel_response.status_code == 503 + # Check counter reached it's target + assert counting_thing.counter == 1 + # Check that error handling wasn't called + assert not counting_thing.check + + +def test_cancel_unknown_task(counting_thing, client): """ Test that cancelling an unknown invocation returns a 404 response """ - server = lt.ThingServer() - counting_thing = CancellableCountingThing() - server.add_thing(counting_thing, "/counting_thing") - with TestClient(server.app) as client: - cancel_response = client.delete(f"/invocations/{uuid.uuid4()}") - assert cancel_response.status_code == 404 + cancel_response = client.delete(f"/invocations/{uuid.uuid4()}") + assert cancel_response.status_code == 404 diff --git a/tests/test_action_logging.py b/tests/test_action_logging.py index 3231dd4e..03b9213a 100644 --- a/tests/test_action_logging.py +++ b/tests/test_action_logging.py @@ -4,6 +4,7 @@ import logging from fastapi.testclient import TestClient +import pytest from .temp_client import poll_task import labthings_fastapi as lt from labthings_fastapi.actions.invocation_model import LogRecordModel @@ -29,53 +30,54 @@ def action_with_invocation_error(self, logger: lt.deps.InvocationLogger): raise lt.exceptions.InvocationError("This is an error, but I handled it!") -def test_invocation_logging(caplog): +@pytest.fixture +def client(): + """Set up a Thing Server and yield a client to it.""" + server = lt.ThingServer() + server.add_thing("log_and_error_thing", ThingThatLogsAndErrors) + with TestClient(server.app) as client: + yield client + + +def test_invocation_logging(caplog, client): + """Check the expected items appear in the log when an action is invoked.""" with caplog.at_level(logging.INFO, logger="labthings.action"): - server = lt.ThingServer() - server.add_thing(ThingThatLogsAndErrors(), "/log_and_error_thing") - with TestClient(server.app) as client: - r = client.post("/log_and_error_thing/action_that_logs") - r.raise_for_status() - invocation = poll_task(client, r.json()) - assert invocation["status"] == "completed" - assert len(invocation["log"]) == len(ThingThatLogsAndErrors.LOG_MESSAGES) - assert len(invocation["log"]) == len(caplog.records) - for expected, entry in zip( - ThingThatLogsAndErrors.LOG_MESSAGES, invocation["log"], strict=True - ): - assert entry["message"] == expected - - -def test_unhandled_error_logs(caplog): + r = client.post("/log_and_error_thing/action_that_logs") + r.raise_for_status() + invocation = poll_task(client, r.json()) + assert invocation["status"] == "completed" + assert len(invocation["log"]) == len(ThingThatLogsAndErrors.LOG_MESSAGES) + assert len(invocation["log"]) == len(caplog.records) + for expected, entry in zip( + ThingThatLogsAndErrors.LOG_MESSAGES, invocation["log"], strict=True + ): + assert entry["message"] == expected + + +def test_unhandled_error_logs(caplog, client): """Check that a log with a traceback is raised if there is an unhandled error.""" with caplog.at_level(logging.INFO, logger="labthings.action"): - server = lt.ThingServer() - server.add_thing(ThingThatLogsAndErrors(), "/log_and_error_thing") - with TestClient(server.app) as client: - r = client.post("/log_and_error_thing/action_with_unhandled_error") - r.raise_for_status() - invocation = poll_task(client, r.json()) - assert invocation["status"] == "error" - assert len(invocation["log"]) == len(caplog.records) == 1 - assert caplog.records[0].levelname == "ERROR" - # There is a traceback - assert caplog.records[0].exc_info is not None - - -def test_invocation_error_logs(caplog): + r = client.post("/log_and_error_thing/action_with_unhandled_error") + r.raise_for_status() + invocation = poll_task(client, r.json()) + assert invocation["status"] == "error" + assert len(invocation["log"]) == len(caplog.records) == 1 + assert caplog.records[0].levelname == "ERROR" + # There is a traceback + assert caplog.records[0].exc_info is not None + + +def test_invocation_error_logs(caplog, client): """Check that a log with a traceback is raised if there is an unhandled error.""" with caplog.at_level(logging.INFO, logger="labthings.action"): - server = lt.ThingServer() - server.add_thing(ThingThatLogsAndErrors(), "/log_and_error_thing") - with TestClient(server.app) as client: - r = client.post("/log_and_error_thing/action_with_invocation_error") - r.raise_for_status() - invocation = poll_task(client, r.json()) - assert invocation["status"] == "error" - assert len(invocation["log"]) == len(caplog.records) == 1 - assert caplog.records[0].levelname == "ERROR" - # There is not a traceback - assert caplog.records[0].exc_info is None + r = client.post("/log_and_error_thing/action_with_invocation_error") + r.raise_for_status() + invocation = poll_task(client, r.json()) + assert invocation["status"] == "error" + assert len(invocation["log"]) == len(caplog.records) == 1 + assert caplog.records[0].levelname == "ERROR" + # There is not a traceback + assert caplog.records[0].exc_info is None def test_logrecordmodel(): diff --git a/tests/test_action_manager.py b/tests/test_action_manager.py index 2ad7385f..5da65ba5 100644 --- a/tests/test_action_manager.py +++ b/tests/test_action_manager.py @@ -25,12 +25,16 @@ def increment_counter_longlife(self): "A pointless counter" -thing = CounterThing() -server = lt.ThingServer() -server.add_thing(thing, "/thing") +@pytest.fixture +def client(): + """Yield a TestClient connected to a ThingServer.""" + server = lt.ThingServer() + server.add_thing("thing", CounterThing) + with TestClient(server.app) as client: + yield client -def test_action_expires(): +def test_action_expires(client): """Check the action is removed from the server We've set the retention period to be very short, so the action @@ -44,23 +48,22 @@ def test_action_expires(): This behaviour might change in the future, making the second run unnecessary. """ - with TestClient(server.app) as client: - before_value = client.get("/thing/counter").json() - r = client.post("/thing/increment_counter") - invocation = poll_task(client, r.json()) - time.sleep(0.02) - r2 = client.post("/thing/increment_counter") - poll_task(client, r2.json()) - after_value = client.get("/thing/counter").json() - assert after_value == before_value + 2 - invocation["status"] = "running" # Force an extra poll - # When the second action runs, the first one should expire - # so polling it again should give a 404. - with pytest.raises(httpx.HTTPStatusError): - poll_task(client, invocation) - - -def test_actions_list(): + before_value = client.get("/thing/counter").json() + r = client.post("/thing/increment_counter") + invocation = poll_task(client, r.json()) + time.sleep(0.02) + r2 = client.post("/thing/increment_counter") + poll_task(client, r2.json()) + after_value = client.get("/thing/counter").json() + assert after_value == before_value + 2 + invocation["status"] = "running" # Force an extra poll + # When the second action runs, the first one should expire + # so polling it again should give a 404. + with pytest.raises(httpx.HTTPStatusError): + poll_task(client, invocation) + + +def test_actions_list(client): """Check that the /action_invocations/ path works. The /action_invocations/ path should return a list of invocation @@ -68,10 +71,9 @@ def test_actions_list(): It's implemented in `ActionManager.list_all_invocations`. """ - with TestClient(server.app) as client: - r = client.post("/thing/increment_counter_longlife") - invocation = poll_task(client, r.json()) - r2 = client.get(ACTION_INVOCATIONS_PATH) - r2.raise_for_status() - invocations = r2.json() - assert invocations == [invocation] + r = client.post("/thing/increment_counter_longlife") + invocation = poll_task(client, r.json()) + r2 = client.get(ACTION_INVOCATIONS_PATH) + r2.raise_for_status() + invocations = r2.json() + assert invocations == [invocation] diff --git a/tests/test_actions.py b/tests/test_actions.py index 64533ddd..3227087d 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -3,14 +3,19 @@ import pytest import functools -from labthings_fastapi.exceptions import NotConnectedToServerError +from labthings_fastapi.thing_server_interface import create_thing_without_server from .temp_client import poll_task, get_link from labthings_fastapi.example_things import MyThing import labthings_fastapi as lt -thing = MyThing() -server = lt.ThingServer() -server.add_thing(thing, "/thing") + +@pytest.fixture +def client(): + """Yield a client connected to a ThingServer""" + server = lt.ThingServer() + server.add_thing("thing", MyThing) + with TestClient(server.app) as client: + yield client def action_partial(client: TestClient, url: str): @@ -22,54 +27,67 @@ def run(payload=None): return run -def test_get_action_invocations(): +def test_get_action_invocations(client): """Test that running "get" on an action returns a list of invocations.""" - with TestClient(server.app) as client: - # When we start the action has no invocations - invocations_before = client.get("/thing/increment_counter").json() - assert invocations_before == [] - # Start the action - r = client.post("/thing/increment_counter") - assert r.status_code in (200, 201) - # Now it is started, there is a list of 1 dictionary containing the - # invocation information. - invocations_after = client.get("/thing/increment_counter").json() - assert len(invocations_after) == 1 - assert isinstance(invocations_after, list) - assert isinstance(invocations_after[0], dict) - assert "status" in invocations_after[0] - assert "id" in invocations_after[0] - assert "action" in invocations_after[0] - assert "href" in invocations_after[0] - assert "timeStarted" in invocations_after[0] - # Let the task finish before ending the test - poll_task(client, r.json()) - - -def test_counter(): - with TestClient(server.app) as client: - before_value = client.get("/thing/counter").json() - r = client.post("/thing/increment_counter") - assert r.status_code in (200, 201) - poll_task(client, r.json()) - after_value = client.get("/thing/counter").json() - assert after_value == before_value + 1 + # When we start the action has no invocations + invocations_before = client.get("/thing/increment_counter").json() + assert invocations_before == [] + # Start the action + r = client.post("/thing/increment_counter") + assert r.status_code in (200, 201) + # Now it is started, there is a list of 1 dictionary containing the + # invocation information. + invocations_after = client.get("/thing/increment_counter").json() + assert len(invocations_after) == 1 + assert isinstance(invocations_after, list) + assert isinstance(invocations_after[0], dict) + assert "status" in invocations_after[0] + assert "id" in invocations_after[0] + assert "action" in invocations_after[0] + assert "href" in invocations_after[0] + assert "timeStarted" in invocations_after[0] + # Let the task finish before ending the test + poll_task(client, r.json()) + + +def test_counter(client): + """Test that the increment_counter action increments the property.""" + before_value = client.get("/thing/counter").json() + r = client.post("/thing/increment_counter") + assert r.status_code in (200, 201) + poll_task(client, r.json()) + after_value = client.get("/thing/counter").json() + assert after_value == before_value + 1 + + +def test_no_args(client): + """Test None and {} are both accepted as input. + + Actions that take no arguments will accept either an empty + dictionary or None as their input. + + Note that there is an assertion in `action_partial` so we + do check that the action runs. + """ + run = action_partial(client, "/thing/action_without_arguments") + run({}) # an empty dict should be OK + run(None) # it should also be OK to call it with None + # Calling with no payload is equivalent to None -def test_no_args(): - with TestClient(server.app) as client: - run = action_partial(client, "/thing/action_without_arguments") - run({}) # an empty dict should be OK - run(None) # it should also be OK to call it with None - # Calling with no payload is equivalent to None +def test_only_kwargs(client): + """Test an action that only has **kwargs works as expected. + It should be allowable to invoke such an action with no + input (see test above) or with arbitrary keyword arguments. -def test_only_kwargs(): - with TestClient(server.app) as client: - run = action_partial(client, "/thing/action_with_only_kwargs") - run({}) # an empty dict should be OK - run(None) # it should also be OK to call it with None - run({"foo": "bar"}) # it should be OK to call it with a payload + Note that there is an assertion in `action_partial` so we + do check that the action runs. + """ + run = action_partial(client, "/thing/action_with_only_kwargs") + run({}) # an empty dict should be OK + run(None) # it should also be OK to call it with None + run({"foo": "bar"}) # it should be OK to call it with a payload def test_varargs(): @@ -82,50 +100,48 @@ def action_with_varargs(self, *args) -> None: pass -def test_action_output(): +def test_action_output(client): """Test that an action's output may be retrieved directly. This tests the /action_invocation/{id}/output endpoint, including some error conditions (not found/output not available). """ - with TestClient(server.app) as client: - # Start an action and wait for it to complete - r = client.post("/thing/make_a_dict", json={}) - r.raise_for_status() - invocation = poll_task(client, r.json()) - assert invocation["status"] == "completed" - assert invocation["output"] == {"key": "value"} - # Retrieve the output directly and check it matches - r = client.get(get_link(invocation, "output")["href"]) - assert r.json() == {"key": "value"} - - # Test an action that doesn't have an output - r = client.post("/thing/action_without_arguments", json={}) - r.raise_for_status() - invocation = poll_task(client, r.json()) - assert invocation["status"] == "completed" - assert invocation["output"] is None - - # If the output is None, retrieving it directly should fail - r = client.get(get_link(invocation, "output")["href"]) - assert r.status_code == 503 - - # Repeat the last check, using a manually generated URL - # (mostly to check the manually generated URL is valid, - # so the next test can be trusted). - r = client.get(f"/action_invocation/{invocation['id']}/output") - assert r.status_code == 404 - - # Test an output on a non-existent invocation - r = client.get(f"/action_invocation/{uuid.uuid4()}/output") - assert r.status_code == 404 - - -def test_openapi(): + # Start an action and wait for it to complete + r = client.post("/thing/make_a_dict", json={}) + r.raise_for_status() + invocation = poll_task(client, r.json()) + assert invocation["status"] == "completed" + assert invocation["output"] == {"key": "value"} + # Retrieve the output directly and check it matches + r = client.get(get_link(invocation, "output")["href"]) + assert r.json() == {"key": "value"} + + # Test an action that doesn't have an output + r = client.post("/thing/action_without_arguments", json={}) + r.raise_for_status() + invocation = poll_task(client, r.json()) + assert invocation["status"] == "completed" + assert invocation["output"] is None + + # If the output is None, retrieving it directly should fail + r = client.get(get_link(invocation, "output")["href"]) + assert r.status_code == 503 + + # Repeat the last check, using a manually generated URL + # (mostly to check the manually generated URL is valid, + # so the next test can be trusted). + r = client.get(f"/action_invocation/{invocation['id']}/output") + assert r.status_code == 404 + + # Test an output on a non-existent invocation + r = client.get(f"/action_invocation/{uuid.uuid4()}/output") + assert r.status_code == 404 + + +def test_openapi(client): """Check the OpenAPI docs are generated OK""" - with TestClient(server.app) as client: - r = client.get("/openapi.json") - r.raise_for_status() + r = client.get("/openapi.json") + r.raise_for_status() def example_decorator(func): @@ -182,19 +198,5 @@ def decorated( assert Example.action.output_model == Example.decorated.output_model # Check we can make the thing and it has a valid TD - example = Example() - example.path = "/example" + example = create_thing_without_server(Example) example.validate_thing_description() - - -def test_affordance_and_fastapi_errors(mocker): - """Check that we get a sensible error if the Thing has no path. - - The thing will not have a ``path`` property before it has been added - to a server. - """ - thing = MyThing() - with pytest.raises(NotConnectedToServerError): - MyThing.anaction.add_to_fastapi(mocker.Mock(), thing) - with pytest.raises(NotConnectedToServerError): - MyThing.anaction.action_affordance(thing, None) diff --git a/tests/test_blob_output.py b/tests/test_blob_output.py index 749bc428..32fdb94b 100644 --- a/tests/test_blob_output.py +++ b/tests/test_blob_output.py @@ -8,6 +8,7 @@ from fastapi.testclient import TestClient import pytest import labthings_fastapi as lt +from labthings_fastapi.thing_server_interface import create_thing_without_server class TextBlob(lt.blob.Blob): @@ -17,7 +18,8 @@ class TextBlob(lt.blob.Blob): class ThingOne(lt.Thing): ACTION_ONE_RESULT = b"Action one result!" - def __init__(self): + def __init__(self, thing_server_interface): + super().__init__(thing_server_interface=thing_server_interface) self._temp_directory = TemporaryDirectory() @lt.thing_action @@ -47,7 +49,7 @@ def passthrough_blob(self, blob: TextBlob) -> TextBlob: return blob -ThingOneDep = lt.deps.direct_thing_client_dependency(ThingOne, "/thing_one/") +ThingOneDep = lt.deps.direct_thing_client_dependency(ThingOne, "thing_one") class ThingTwo(lt.Thing): @@ -66,6 +68,16 @@ def check_passthrough(self, thing_one: ThingOneDep) -> bool: return True +@pytest.fixture +def client(): + """Yield a test client connected to a ThingServer.""" + server = lt.ThingServer() + server.add_thing("thing_one", ThingOne) + server.add_thing("thing_two", ThingTwo) + with TestClient(server.app) as client: + yield client + + def test_blob_type(): """Check we can't put dodgy values into a blob output model""" with pytest.raises(ValueError): @@ -96,30 +108,23 @@ def test_blob_creation(): assert blob.content == TEXT -def test_blob_output_client(): +def test_blob_output_client(client): """Test that blob outputs work as expected when used over HTTP.""" - server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one") - with TestClient(server.app) as client: - tc = lt.ThingClient.from_url("/thing_one/", client=client) - check_actions(tc) + tc = lt.ThingClient.from_url("/thing_one/", client=client) + check_actions(tc) def test_blob_output_direct(): """Check blob outputs work correctly when we use a Thing directly in Python.""" - thing = ThingOne() + thing = create_thing_without_server(ThingOne) check_actions(thing) -def test_blob_output_inserver(): +def test_blob_output_inserver(client): """Test that the blob output works the same when used via a DirectThingClient.""" - server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one") - server.add_thing(ThingTwo(), "/thing_two") - with TestClient(server.app) as client: - tc = lt.ThingClient.from_url("/thing_two/", client=client) - output = tc.check_both() - assert output is True + tc = lt.ThingClient.from_url("/thing_two/", client=client) + output = tc.check_both() + assert output is True def check_blob(output, expected_content: bytes): @@ -145,23 +150,19 @@ def check_actions(thing): check_blob(output, ThingOne.ACTION_ONE_RESULT) -def test_blob_input(): +def test_blob_input(client): """Check that blobs can be used as input.""" - server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one") - server.add_thing(ThingTwo(), "/thing_two") - with TestClient(server.app) as client: - tc = lt.ThingClient.from_url("/thing_one/", client=client) - output = tc.action_one() - print(f"Output is {output}") - assert output is not None - - # Check that the blob can be passed from one action to another, - # via the client - passthrough = tc.passthrough_blob(blob=output) - print(f"Output is {passthrough}") - assert passthrough.content == ThingOne.ACTION_ONE_RESULT - - # Check that the same thing works on the server side - tc2 = lt.ThingClient.from_url("/thing_two/", client=client) - assert tc2.check_passthrough() is True + tc = lt.ThingClient.from_url("/thing_one/", client=client) + output = tc.action_one() + print(f"Output is {output}") + assert output is not None + + # Check that the blob can be passed from one action to another, + # via the client + passthrough = tc.passthrough_blob(blob=output) + print(f"Output is {passthrough}") + assert passthrough.content == ThingOne.ACTION_ONE_RESULT + + # Check that the same thing works on the server side + tc2 = lt.ThingClient.from_url("/thing_two/", client=client) + assert tc2.check_passthrough() is True diff --git a/tests/test_dependency_metadata.py b/tests/test_dependency_metadata.py index c0a04e7d..efec5cbf 100644 --- a/tests/test_dependency_metadata.py +++ b/tests/test_dependency_metadata.py @@ -4,13 +4,14 @@ from typing import Any, Mapping from fastapi.testclient import TestClient +import pytest from .temp_client import poll_task import labthings_fastapi as lt class ThingOne(lt.Thing): - def __init__(self): - lt.Thing.__init__(self) + def __init__(self, thing_server_interface): + super().__init__(thing_server_interface=thing_server_interface) self._a = 0 @lt.property @@ -26,7 +27,7 @@ def thing_state(self): return {"a": self.a} -ThingOneDep = lt.deps.direct_thing_client_dependency(ThingOne, "/thing_one/") +ThingOneDep = lt.deps.direct_thing_client_dependency(ThingOne, "thing_one") class ThingTwo(lt.Thing): @@ -37,7 +38,7 @@ def thing_state(self): return {"a": 1} @lt.thing_action - def count_and_watch( + def count_and_watch_deprecated( self, thing_one: ThingOneDep, get_metadata: lt.deps.GetThingStates ) -> Mapping[str, Mapping[str, Any]]: metadata = {} @@ -46,16 +47,44 @@ def count_and_watch( metadata[f"a_{a}"] = get_metadata() return metadata + @lt.thing_action + def count_and_watch( + self, thing_one: ThingOneDep + ) -> Mapping[str, Mapping[str, Any]]: + metadata = {} + for a in self.A_VALUES: + thing_one.a = a + metadata[f"a_{a}"] = self._thing_server_interface.get_thing_states() + return metadata + -def test_fresh_metadata(): +@pytest.fixture +def client(): + """Yield a test client connected to a ThingServer.""" server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one/") - server.add_thing(ThingTwo(), "/thing_two/") + server.add_thing("thing_one", ThingOne) + server.add_thing("thing_two", ThingTwo) with TestClient(server.app) as client: - r = client.post("/thing_two/count_and_watch") - invocation = poll_task(client, r.json()) - assert invocation["status"] == "completed" - out = invocation["output"] - for a in ThingTwo.A_VALUES: - assert out[f"a_{a}"]["/thing_one/"]["a"] == a - assert out[f"a_{a}"]["/thing_two/"]["a"] == 1 + yield client + + +def test_fresh_metadata(client): + """Check that fresh metadata is retrieved by get_thing_states.""" + r = client.post("/thing_two/count_and_watch") + invocation = poll_task(client, r.json()) + assert invocation["status"] == "completed" + out = invocation["output"] + for a in ThingTwo.A_VALUES: + assert out[f"a_{a}"]["thing_one"]["a"] == a + assert out[f"a_{a}"]["thing_two"]["a"] == 1 + + +def test_fresh_metadata_deprecated(client): + """Test that the old metadata dependency retrieves fresh metadata.""" + r = client.post("/thing_two/count_and_watch") + invocation = poll_task(client, r.json()) + assert invocation["status"] == "completed" + out = invocation["output"] + for a in ThingTwo.A_VALUES: + assert out[f"a_{a}"]["thing_one"]["a"] == a + assert out[f"a_{a}"]["thing_two"]["a"] == 1 diff --git a/tests/test_directthingclient.py b/tests/test_directthingclient.py index a11bd448..8bf8c8f4 100644 --- a/tests/test_directthingclient.py +++ b/tests/test_directthingclient.py @@ -8,6 +8,7 @@ import pytest import labthings_fastapi as lt from labthings_fastapi.deps import DirectThingClient, direct_thing_client_class +from labthings_fastapi.thing_server_interface import create_thing_without_server from .temp_client import poll_task @@ -40,10 +41,9 @@ def counter_client(mocker) -> DirectThingClient: :param mocker: the mocker test fixture from ``pytest-mock``\ . :returns: a ``DirectThingClient`` subclass wrapping a ``Counter``\ . """ - counter = Counter() - counter._labthings_blocking_portal = mocker.Mock(["start_task_soon"]) + counter = create_thing_without_server(Counter) - CounterClient = direct_thing_client_class(Counter, "/counter") + CounterClient = direct_thing_client_class(Counter, "counter") class StandaloneCounterClient(CounterClient): def __init__(self, wrapped): @@ -54,7 +54,7 @@ def __init__(self, wrapped): return StandaloneCounterClient(counter) -CounterDep = lt.deps.direct_thing_client_dependency(Counter, "/counter/") +CounterDep = lt.deps.direct_thing_client_dependency(Counter, "counter") RawCounterDep = lt.deps.raw_thing_dependency(Counter) @@ -145,8 +145,8 @@ def test_directthingclient_in_server(action): This uses the internal thing client mechanism. """ server = lt.ThingServer() - server.add_thing(Counter(), "/counter") - server.add_thing(Controller(), "/controller") + server.add_thing("counter", Counter) + server.add_thing("controller", Controller) with TestClient(server.app) as client: r = client.post(f"/controller/{action}") invocation = poll_task(client, r.json()) diff --git a/tests/test_docs.py b/tests/test_docs.py index 7b705bc3..b27344f5 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,6 +1,7 @@ from pathlib import Path from runpy import run_path from fastapi.testclient import TestClient +import pytest from labthings_fastapi import ThingClient from .test_server_cli import MonitoredProcess @@ -18,6 +19,7 @@ def run_quickstart_counter(): run_path(str(docs / "quickstart" / "counter.py")) +@pytest.mark.slow def test_quickstart_counter(): """Check we can create a server from the command line""" p = MonitoredProcess(target=run_quickstart_counter) diff --git a/tests/test_endpoint_decorator.py b/tests/test_endpoint_decorator.py index 6d4784b6..6360704e 100644 --- a/tests/test_endpoint_decorator.py +++ b/tests/test_endpoint_decorator.py @@ -1,6 +1,5 @@ from fastapi.testclient import TestClient from pydantic import BaseModel -import pytest import labthings_fastapi as lt @@ -26,8 +25,8 @@ def post_method(self, body: PostBodyModel) -> str: def test_endpoints(): """Check endpoints may be added to the app and work as expected.""" server = lt.ThingServer() - thing = MyThing() - server.add_thing(thing, "/thing") + server.add_thing("thing", MyThing) + thing = server.things["thing"] with TestClient(server.app) as client: # Check the function works when used directly assert thing.path_from_name() == "path_from_name" @@ -50,15 +49,3 @@ def test_endpoints(): r = client.post("/thing/path_from_path", json={"a": 1, "b": 2}) r.raise_for_status() assert r.json() == "post_method 1 2" - - -def test_endpoint_notconnected(mocker): - """Check for the correct error if we add endpoints prematurely. - - We should get this error if we call ``add_to_fastapi`` on an endpoint - where the `.Thing` does not have a valid ``path`` attribute. - """ - thing = MyThing() - - with pytest.raises(lt.exceptions.NotConnectedToServerError): - MyThing.get_method.add_to_fastapi(mocker.Mock(), thing) diff --git a/tests/test_example_thing.py b/tests/test_example_thing.py index 14ebe23f..1598c88c 100644 --- a/tests/test_example_thing.py +++ b/tests/test_example_thing.py @@ -6,21 +6,11 @@ ) import pytest - -class DummyBlockingPortal: - """A dummy blocking portal for testing - - This is a blocking portal that doesn't actually do anything. - In the future, we should improve LabThings so this is not required. - """ - - def start_task_soon(self, func, *args, **kwargs): - pass +from labthings_fastapi.thing_server_interface import create_thing_without_server def test_mything(): - thing = MyThing() - thing._labthings_blocking_portal = DummyBlockingPortal() + thing = create_thing_without_server(MyThing) assert isinstance(thing, MyThing) assert thing.counter == 0 ret = thing.anaction(3, 1, title="MyTitle", attempts=["a", "b", "c"]) @@ -40,7 +30,7 @@ def test_mything(): def test_thing_with_broken_affordances(): - thing = ThingWithBrokenAffordances() + thing = create_thing_without_server(ThingWithBrokenAffordances) assert isinstance(thing, ThingWithBrokenAffordances) with pytest.raises(RuntimeError): thing.broken_action() @@ -50,11 +40,11 @@ def test_thing_with_broken_affordances(): def test_thing_that_cannot_instantiate(): with pytest.raises(RuntimeError): - ThingThatCantInstantiate() + create_thing_without_server(ThingThatCantInstantiate) def test_thing_that_cannot_start(): - thing = ThingThatCantStart() + thing = create_thing_without_server(ThingThatCantStart) assert isinstance(thing, ThingThatCantStart) with pytest.raises(RuntimeError): with thing: diff --git a/tests/test_fallback.py b/tests/test_fallback.py index d3cd5f60..257ccda8 100644 --- a/tests/test_fallback.py +++ b/tests/test_fallback.py @@ -49,8 +49,8 @@ def test_fallback_with_server(): html = response.text assert "Something went wrong" in html assert "No logging info available" in html - assert "thing1/" in html - assert "thing2/" in html + assert "thing1" in html + assert "thing2" in html def test_fallback_with_log(): diff --git a/tests/test_locking_decorator.py b/tests/test_locking_decorator.py index 157dfaae..798bfd1d 100644 --- a/tests/test_locking_decorator.py +++ b/tests/test_locking_decorator.py @@ -7,6 +7,7 @@ import pytest import labthings_fastapi as lt +from labthings_fastapi.thing_server_interface import create_thing_without_server from .temp_client import poll_task @@ -35,8 +36,9 @@ class LockedExample(lt.Thing): flag: bool = lt.property(default=False) - def __init__(self): + def __init__(self, **kwargs): """Initialise the lock.""" + super().__init__(**kwargs) self._lock = RLock() # This lock is used by @requires_lock self._event = Event() # This is used to keep tests quick # by stopping waits as soon as they are no longer needed @@ -70,8 +72,7 @@ def wait_with_flag(self, time: float = 1) -> None: @pytest.fixture def thing(mocker) -> LockedExample: """Instantiate the LockedExample thing.""" - thing = LockedExample() - thing._labthings_blocking_portal = mocker.Mock() + thing = create_thing_without_server(LockedExample) return thing @@ -115,8 +116,8 @@ def echo_via_client(client): def test_locking_in_server(): """Check the lock works within LabThings.""" server = lt.ThingServer() - thing = LockedExample() - server.add_thing(thing, "/thing") + server.add_thing("thing", LockedExample) + thing = server.things["thing"] with TestClient(server.app) as client: # Start a long task r1 = client.post("/thing/wait_wrapper", json={}) diff --git a/tests/test_mjpeg_stream.py b/tests/test_mjpeg_stream.py index b88df318..a780d10c 100644 --- a/tests/test_mjpeg_stream.py +++ b/tests/test_mjpeg_stream.py @@ -3,6 +3,7 @@ import time from PIL import Image from fastapi.testclient import TestClient +import pytest import labthings_fastapi as lt @@ -35,16 +36,23 @@ def _make_images(self): i = 0 while self._streaming and (i < self.frame_limit or self.frame_limit < 0): - self.stream.add_frame( - jpegs[i % len(jpegs)], self._labthings_blocking_portal - ) + self.stream.add_frame(jpegs[i % len(jpegs)]) time.sleep(1 / self.framerate) i = i + 1 - self.stream.stop(self._labthings_blocking_portal) + self.stream.stop() self._streaming = False -def test_mjpeg_stream(): +@pytest.fixture +def client(): + """Yield a test client connected to a ThingServer""" + server = lt.ThingServer() + server.add_thing("telly", Telly) + with TestClient(server.app) as client: + yield client + + +def test_mjpeg_stream(client): """Verify the MJPEG stream contains at least one frame marker. A limitation of the TestClient is that it can't actually stream. @@ -55,24 +63,19 @@ def test_mjpeg_stream(): but it might be possible in the future to check there are three images there. """ - server = lt.ThingServer() - telly = Telly() - server.add_thing(telly, "telly") - with TestClient(server.app) as client: - with client.stream("GET", "/telly/stream") as stream: - stream.raise_for_status() - received = 0 - for b in stream.iter_bytes(): - received += 1 - assert b.startswith(b"--frame") + with client.stream("GET", "/telly/stream") as stream: + stream.raise_for_status() + received = 0 + for b in stream.iter_bytes(): + received += 1 + assert b.startswith(b"--frame") if __name__ == "__main__": import uvicorn server = lt.ThingServer() - telly = Telly() + telly = server.add_thing("telly", Telly) telly.framerate = 6 telly.frame_limit = -1 - server.add_thing(telly, "telly") uvicorn.run(server.app, port=5000) diff --git a/tests/test_numpy_type.py b/tests/test_numpy_type.py index b9b1d0cc..10fb5bf8 100644 --- a/tests/test_numpy_type.py +++ b/tests/test_numpy_type.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, RootModel import numpy as np +from labthings_fastapi.thing_server_interface import create_thing_without_server from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict import labthings_fastapi as lt @@ -63,19 +64,21 @@ class Model(BaseModel): class MyNumpyThing(lt.Thing): + """A thing that uses numpy types.""" + @lt.thing_action def action_with_arrays(self, a: NDArray) -> NDArray: return a * 2 def test_thing_description(): - thing = MyNumpyThing() - # We must mock a path, or it can't generate a Thing Description. - thing.path = "/mynumpything" + """Make sure the TD validates when numpy types are used.""" + thing = create_thing_without_server(MyNumpyThing) assert thing.validate_thing_description() is None def test_denumpifying_dict(): + """Check DenumpifyingDict converts arrays to lists.""" d = DenumpifyingDict( root={ "a": np.array([1, 2, 3]), @@ -94,6 +97,7 @@ def test_denumpifying_dict(): def test_rootmodel(): + """Check that RootModels with NDArray convert between array and list.""" for input in [[0, 1, 2], np.arange(3)]: m = ArrayModel(root=input) assert isinstance(m.root, np.ndarray) diff --git a/tests/test_properties.py b/tests/test_properties.py index 9380798e..a50a8b4a 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,13 +1,11 @@ from threading import Thread from typing import Any -from pytest import raises from pydantic import BaseModel, RootModel from fastapi.testclient import TestClient import pytest import labthings_fastapi as lt -from labthings_fastapi.exceptions import NotConnectedToServerError from .temp_client import poll_task @@ -53,9 +51,8 @@ def toggle_boolprop_from_thread(self): @pytest.fixture def server(): - thing = PropertyTestThing() server = lt.ThingServer() - server.add_thing(thing, "/thing") + server.add_thing("thing", PropertyTestThing) return server @@ -233,15 +230,3 @@ def test_setting_from_thread(server): r = client.get("/thing/boolprop") assert r.status_code == 200 assert r.json() is True - - -def test_setting_without_event_loop(server): - """Test that an exception is raised if updating a DataProperty - without connecting the Thing to a running server with an event loop. - """ - # This test may need to change, if we change the intended behaviour - # Currently it should never be necessary to change properties from the - # main thread, so we raise an error if you try to do so - thing = PropertyTestThing() - with raises(NotConnectedToServerError): - thing.boolprop = False # Can't call it until the event loop's running diff --git a/tests/test_server.py b/tests/test_server.py index f3edb8fb..bf12abb3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,39 +7,10 @@ """ import pytest -from fastapi.testclient import TestClient -import labthings_fastapi as lt from labthings_fastapi import server as ts -def test_thing_with_blocking_portal_error(mocker): - """Test that a thing with a _labthings_blocking_portal causes an error. - - The blocking portal is added when the server starts. If one is there already, - this is an error and the server should fail to start. - - As this ends up in an async context manager, the exception will be wrapped - in an ExceptionGroup, hence the slightly complicated code to test the exception. - - This is not an error condition that we expect to happen often. Handling - it more elegantly would result in enough additional code that the burden of - maintaining and testing that code outweighs the benefit of a more elegant - error message. - """ - - class Example(lt.Thing): - def __init__(self): - super().__init__() - self._labthings_blocking_portal = mocker.Mock() - - server = lt.ThingServer() - server.add_thing(Example(), "/example") - with pytest.RaisesGroup(pytest.RaisesExc(RuntimeError, match="blocking portal")): - with TestClient(server.app): - pass - - def test_server_from_config_non_thing_error(): """Test a typeerror is raised if something that's not a Thing is added.""" with pytest.raises(TypeError, match="not a Thing"): - ts.server_from_config({"things": {"/thingone": {"class": "builtins:object"}}}) + ts.server_from_config({"things": {"thingone": {"class": "builtins:object"}}}) diff --git a/tests/test_server_cli.py b/tests/test_server_cli.py index 06a3d973..0ec3a84b 100644 --- a/tests/test_server_cli.py +++ b/tests/test_server_cli.py @@ -4,6 +4,7 @@ import tempfile from pytest import raises +import pytest from labthings_fastapi import ThingServer from labthings_fastapi.server import server_from_config @@ -93,12 +94,14 @@ def check_serve_from_cli(args: list[str] | None = None): p.run_monitored(terminate_outputs=["Application startup complete"]) +@pytest.mark.slow def test_serve_from_cli_with_config_json(): """Check we can create a server from the command line, using JSON""" config_json = json.dumps(CONFIG) check_serve_from_cli(["-j", config_json]) +@pytest.mark.slow def test_serve_from_cli_with_config_file(): """Check we can create a server from the command line, using a file""" config_json = json.dumps(CONFIG) @@ -109,11 +112,13 @@ def test_serve_from_cli_with_config_file(): check_serve_from_cli(["-c", temp.name]) +@pytest.mark.slow def test_serve_with_no_config_without_multiprocessing(): with raises(RuntimeError): serve_from_cli([], dry_run=True) +@pytest.mark.slow def test_serve_with_no_config(): """Check an empty config fails, using multiprocessing. This is important, because if it passes it means our tests above @@ -123,6 +128,7 @@ def test_serve_with_no_config(): check_serve_from_cli([]) +@pytest.mark.slow def test_invalid_thing(): """Check it fails for invalid things""" config_json = json.dumps( @@ -136,6 +142,7 @@ def test_invalid_thing(): check_serve_from_cli(["-j", config_json]) +@pytest.mark.slow def test_fallback(): """test the fallback option @@ -152,12 +159,14 @@ def test_fallback(): check_serve_from_cli(["-j", config_json, "--fallback"]) +@pytest.mark.slow def test_invalid_config(): """Check it fails for invalid config""" with raises(FileNotFoundError): check_serve_from_cli(["-c", "non_existent_file.json"]) +@pytest.mark.slow def test_thing_that_cannot_start(): """Check it fails for a thing that can't start""" config_json = json.dumps( @@ -169,7 +178,3 @@ def test_thing_that_cannot_start(): ) with raises(SystemExit): check_serve_from_cli(["-j", config_json]) - - -if __name__ == "__main__": - test_serve_from_cli_with_config_json() diff --git a/tests/test_settings.py b/tests/test_settings.py index f31e3292..b138dcd0 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,6 +1,7 @@ from threading import Thread import tempfile import json +from typing import Any import pytest import os import logging @@ -8,15 +9,15 @@ from fastapi.testclient import TestClient import labthings_fastapi as lt -from labthings_fastapi.exceptions import NotConnectedToServerError +from labthings_fastapi.thing_server_interface import create_thing_without_server from .temp_client import poll_task class ThingWithSettings(lt.Thing): """A test `.Thing` with some settings and actions.""" - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # Initialize functional settings with default values self._floatsetting: float = 1.0 self._localonlysetting = "Local-only default." @@ -88,7 +89,7 @@ def toggle_boolsetting_from_thread(self): ThingWithSettingsClientDep = lt.deps.direct_thing_client_dependency( - ThingWithSettings, "/thing/" + ThingWithSettings, "thing" ) ThingWithSettingsDep = lt.deps.raw_thing_dependency(ThingWithSettings) @@ -174,16 +175,6 @@ def _settings_dict( } -@pytest.fixture -def thing(): - return ThingWithSettings() - - -@pytest.fixture -def client_thing(): - return ClientThing() - - @pytest.fixture def server(): with tempfile.TemporaryDirectory() as tempdir: @@ -192,8 +183,9 @@ def server(): yield lt.ThingServer(settings_folder=tempdir) -def test_setting_available(thing): +def test_setting_available(): """Check default settings are available before connecting to server""" + thing = create_thing_without_server(ThingWithSettings) assert not thing.boolsetting assert thing.stringsetting == "foo" assert thing.floatsetting == 1.0 @@ -201,13 +193,13 @@ def test_setting_available(thing): assert thing.dictsetting == {"a": 1, "b": 2} -def test_functional_settings_save(thing, server): +def test_functional_settings_save(server): """Check updated settings are saved to disk ``floatsetting`` is a functional setting, we should also test a `.DataSetting` for completeness.""" setting_file = _get_setting_file(server, "/thing") - server.add_thing(thing, "/thing") + server.add_thing("thing", ThingWithSettings) # No setting file created when first added assert not os.path.isfile(setting_file) with TestClient(server.app) as client: @@ -227,13 +219,13 @@ def test_functional_settings_save(thing, server): assert json.load(file_obj) == _settings_dict(floatsetting=2.0) -def test_data_settings_save(thing, server): +def test_data_settings_save(server): """Check updated settings are saved to disk This uses ``intsetting`` which is a `.DataSetting` so it tests a different code path to the functional setting above.""" setting_file = _get_setting_file(server, "/thing") - server.add_thing(thing, "/thing") + server.add_thing("thing", ThingWithSettings) # The settings file should not be created yet - it's created the # first time we write to a setting. assert not os.path.isfile(setting_file) @@ -264,7 +256,7 @@ def test_data_settings_save(thing, server): "method", ["http", "direct_thing_client", "direct"], ) -def test_readonly_setting(thing, client_thing, server, endpoint, value, method): +def test_readonly_setting(server, endpoint, value, method): """Check read-only functional settings cannot be set remotely. Functional settings must always have a setter, and will be @@ -280,8 +272,8 @@ def test_readonly_setting(thing, client_thing, server, endpoint, value, method): block of code inside the ``with`` block each time. """ setting_file = _get_setting_file(server, "/thing") - server.add_thing(thing, "/thing") - server.add_thing(client_thing, "/client_thing") + server.add_thing("thing", ThingWithSettings) + server.add_thing("client_thing", ClientThing) # No setting file created when first added assert not os.path.isfile(setting_file) @@ -327,10 +319,10 @@ def test_readonly_setting(thing, client_thing, server, endpoint, value, method): assert not os.path.isfile(setting_file) # No file created -def test_settings_dict_save(thing, server): +def test_settings_dict_save(server): """Check settings are saved if the dict is updated in full""" setting_file = _get_setting_file(server, "/thing") - server.add_thing(thing, "/thing") + thing = server.add_thing("thing", ThingWithSettings) # No setting file created when first added assert not os.path.isfile(setting_file) with TestClient(server.app): @@ -341,24 +333,14 @@ def test_settings_dict_save(thing, server): assert json.load(file_obj) == _settings_dict(dictsetting={"c": 3}) -def test_premature_Settings_save(thing): - """Check a helpful error is raised if the settings path is missing. - - The settings path is only set when a thing is connected to a server, - so if we use an unconnected thing, we should see the error. - """ - with pytest.raises(NotConnectedToServerError): - thing.save_settings() - - -def test_settings_dict_internal_update(thing, server): +def test_settings_dict_internal_update(server): """Confirm settings are not saved if the internal value of a dictionary is updated This behaviour is not ideal, but it is documented. If the behaviour is updated then the documentation should be updated and this test removed """ setting_file = _get_setting_file(server, "/thing") - server.add_thing(thing, "/thing") + thing = server.add_thing("thing", ThingWithSettings) # No setting file created when first added assert not os.path.isfile(setting_file) with TestClient(server.app): @@ -368,7 +350,7 @@ def test_settings_dict_internal_update(thing, server): assert not os.path.isfile(setting_file) -def test_settings_load(thing, server): +def test_settings_load(server): """Check settings can be loaded from disk when added to server""" setting_file = _get_setting_file(server, "/thing") setting_json = json.dumps(_settings_dict(floatsetting=3.0, stringsetting="bar")) @@ -377,13 +359,13 @@ def test_settings_load(thing, server): with open(setting_file, "w", encoding="utf-8") as file_obj: file_obj.write(setting_json) # Add thing to server and check new settings are loaded - server.add_thing(thing, "/thing") + thing = server.add_thing("thing", ThingWithSettings) assert not thing.boolsetting assert thing.stringsetting == "bar" assert thing.floatsetting == 3.0 -def test_load_extra_settings(thing, server, caplog): +def test_load_extra_settings(server, caplog): """Load from setting file. Extra setting in file should create a warning.""" setting_file = _get_setting_file(server, "/thing") setting_dict = _settings_dict(floatsetting=3.0, stringsetting="bar") @@ -396,7 +378,7 @@ def test_load_extra_settings(thing, server, caplog): with caplog.at_level(logging.WARNING): # Add thing to server - server.add_thing(thing, "/thing") + thing = server.add_thing("thing", ThingWithSettings) assert len(caplog.records) == 1 assert caplog.records[0].levelname == "WARNING" assert caplog.records[0].name == "labthings_fastapi.thing" @@ -407,7 +389,7 @@ def test_load_extra_settings(thing, server, caplog): assert thing.floatsetting == 3.0 -def test_try_loading_corrupt_settings(thing, server, caplog): +def test_try_loading_corrupt_settings(server, caplog): """Load from setting file. Extra setting in file should create a warning.""" setting_file = _get_setting_file(server, "/thing") setting_dict = _settings_dict(floatsetting=3.0, stringsetting="bar") @@ -421,7 +403,7 @@ def test_try_loading_corrupt_settings(thing, server, caplog): with caplog.at_level(logging.WARNING): # Add thing to server - server.add_thing(thing, "/thing") + thing = server.add_thing("thing", ThingWithSettings) assert len(caplog.records) == 1 assert caplog.records[0].levelname == "WARNING" assert caplog.records[0].name == "labthings_fastapi.thing" diff --git a/tests/test_thing.py b/tests/test_thing.py index 88c2e100..375c33e5 100644 --- a/tests/test_thing.py +++ b/tests/test_thing.py @@ -1,26 +1,16 @@ -import pytest from labthings_fastapi.example_things import MyThing from labthings_fastapi import ThingServer +from labthings_fastapi.thing_server_interface import create_thing_without_server def test_td_validates(): """This will raise an exception if it doesn't validate OK""" - thing = MyThing() - thing.path = "/mything" # can't generate a TD without a path + thing = create_thing_without_server(MyThing) assert thing.validate_thing_description() is None def test_add_thing(): """Check that thing can be added to the server""" - thing = MyThing() server = ThingServer() - server.add_thing(thing, "/thing") - - -def test_add_naughty_thing(): - """Check that a thing trying to access server resources - using .. is not allowed""" - thing = MyThing() - server = ThingServer() - with pytest.raises(ValueError): - server.add_thing(thing, "/../../../../bin") + server.add_thing("thing", MyThing) + assert isinstance(server.things["thing"], MyThing) diff --git a/tests/test_thing_dependencies.py b/tests/test_thing_dependencies.py index a9fd7a30..ef7b81bd 100644 --- a/tests/test_thing_dependencies.py +++ b/tests/test_thing_dependencies.py @@ -24,7 +24,7 @@ def action_one_internal(self) -> str: return self.ACTION_ONE_RESULT -ThingOneDep = lt.deps.direct_thing_client_dependency(ThingOne, "/thing_one/") +ThingOneDep = lt.deps.direct_thing_client_dependency(ThingOne, "thing_one") class ThingTwo(lt.Thing): @@ -39,7 +39,7 @@ def action_two_a(self, thing_one: ThingOneDep) -> str: return thing_one.action_one() -ThingTwoDep = lt.deps.direct_thing_client_dependency(ThingTwo, "/thing_two/") +ThingTwoDep = lt.deps.direct_thing_client_dependency(ThingTwo, "thing_two") class ThingThree(lt.Thing): @@ -57,8 +57,8 @@ def dependency_names(func: callable) -> list[str]: def test_direct_thing_dependency(): """Check that direct thing clients are distinct classes""" - ThingOneClient = direct_thing_client_class(ThingOne, "/thing_one/") - ThingTwoClient = direct_thing_client_class(ThingTwo, "/thing_two/") + ThingOneClient = direct_thing_client_class(ThingOne, "thing_one") + ThingTwoClient = direct_thing_client_class(ThingTwo, "thing_two") print(f"{ThingOneClient}: ThingOneClient{inspect.signature(ThingOneClient)}") for k in dir(ThingOneClient): if k.startswith("__"): @@ -81,8 +81,8 @@ def test_interthing_dependency(): This uses the internal thing client mechanism. """ server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one") - server.add_thing(ThingTwo(), "/thing_two") + server.add_thing("thing_one", ThingOne) + server.add_thing("thing_two", ThingTwo) with TestClient(server.app) as client: r = client.post("/thing_two/action_two") invocation = poll_task(client, r.json()) @@ -97,9 +97,9 @@ def test_interthing_dependency_with_dependencies(): dependency injection for the called action """ server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one") - server.add_thing(ThingTwo(), "/thing_two") - server.add_thing(ThingThree(), "/thing_three") + server.add_thing("thing_one", ThingOne) + server.add_thing("thing_two", ThingTwo) + server.add_thing("thing_three", ThingThree) with TestClient(server.app) as client: r = client.post("/thing_three/action_three") r.raise_for_status() @@ -122,8 +122,8 @@ def action_two(self, thing_one: ThingOneDep) -> str: return thing_one.action_one() server = lt.ThingServer() - server.add_thing(ThingOne(), "/thing_one") - server.add_thing(ThingTwo(), "/thing_two") + server.add_thing("thing_one", ThingOne) + server.add_thing("thing_two", ThingTwo) with TestClient(server.app) as client: r = client.post("/thing_two/action_two") invocation = poll_task(client, r.json()) @@ -153,7 +153,7 @@ def action_five(self, thing_two: ThingTwoDep) -> str: return thing_two.action_two() with pytest.raises(lt.client.in_server.DependencyNameClashError): - lt.deps.direct_thing_client_dependency(ThingFour, "/thing_four/") + lt.deps.direct_thing_client_dependency(ThingFour, "thing_four") def check_request(): diff --git a/tests/test_thing_lifecycle.py b/tests/test_thing_lifecycle.py index 1024cd5c..6be75547 100644 --- a/tests/test_thing_lifecycle.py +++ b/tests/test_thing_lifecycle.py @@ -16,14 +16,14 @@ def __exit__(self, *args): self.alive = False -thing = TestThing() server = lt.ThingServer() -server.add_thing(thing, "/thing") +thing = server.add_thing("thing", TestThing) def test_thing_alive(): assert thing.alive is False with TestClient(server.app) as client: + assert thing.alive is True r = client.get("/thing/alive") assert r.json() is True assert thing.alive is False diff --git a/tests/test_thing_server_interface.py b/tests/test_thing_server_interface.py new file mode 100644 index 00000000..82914d01 --- /dev/null +++ b/tests/test_thing_server_interface.py @@ -0,0 +1,168 @@ +"""Test the ThingServerInterface class and associated features.""" + +import gc +import os +import tempfile + +from fastapi.testclient import TestClient +import pytest + +import labthings_fastapi as lt +from labthings_fastapi.exceptions import ServerNotRunningError +from labthings_fastapi import thing_server_interface as tsi + + +NAME = "testname" +EXAMPLE_THING_STATE = {"foo": "bar"} + + +class ExampleThing(lt.Thing): + @lt.property + def thing_state(self): + return EXAMPLE_THING_STATE + + +@pytest.fixture +def server(): + """Return a LabThings server""" + with tempfile.TemporaryDirectory() as dir: + server = lt.ThingServer(settings_folder=dir) + server.add_thing("example", ExampleThing) + yield server + + +@pytest.fixture +def interface(server): + """Return a ThingServerInterface, connected to a server.""" + return tsi.ThingServerInterface(server, NAME) + + +@pytest.fixture +def mockinterface(): + """Return a MockThingServerInterface.""" + return tsi.MockThingServerInterface(NAME) + + +def test_get_server(server, interface): + """Check the server is retrieved correctly. + + This also tests for the right error if it's missing. + """ + assert interface._get_server() is server + + +def test_get_server_error(): + """Ensure a helpful error is raised if the server is deleted. + + This is an error condition that I would find surprising if it + ever occurred, but it's worth checking. + """ + server = lt.ThingServer() + interface = tsi.ThingServerInterface(server, NAME) + assert interface._get_server() is server + del server + gc.collect() + with pytest.raises(tsi.ThingServerMissingError): + interface._get_server() + + +def test_start_async_task_soon(server, interface): + """Check async tasks may be run in the event loop.""" + mutable = [False] + + async def set_mutable(val): + mutable[0] = val + + with pytest.raises(ServerNotRunningError): + # You can't run async code unless the server + # is running: this should raise a helpful + # error. + interface.start_async_task_soon(set_mutable, True) + + with TestClient(server.app) as _: + # TestClient starts an event loop in the background + # so this should work + interface.start_async_task_soon(set_mutable, True) + + # Check the async code really did run. + assert mutable[0] is True + + +def test_settings_folder(server, interface): + """Check the interface returns the right settings folder.""" + assert interface.settings_folder == os.path.join(server.settings_folder, NAME) + + +def test_settings_file_path(server, interface): + """Check the settings file path is as expected.""" + assert interface.settings_file_path == os.path.join( + server.settings_folder, NAME, "settings.json" + ) + + +def test_name(server, interface): + """Check the thing's name is passed on correctly.""" + assert interface.name is NAME + assert server.things["example"]._thing_server_interface.name == "example" + + +def test_path(interface, server): + """Check the thing's path is generated predictably.""" + with pytest.raises(KeyError): + # `interface` is for a thing called NAME, which isn't + # added to the server, so when we try to get its path + # it should raise an error. + _ = interface.path + # If we put something in the dictionary of things, it should work. + server._things[NAME] = None + assert interface.path == f"/{NAME}/" + # We can also check the example thing, which is actually added to the server. + # This doesn't need any mocking. + assert server.things["example"].path == "/example/" + + +def test_get_thing_states(interface): + """Check thing metadata is retrieved properly.""" + states = interface.get_thing_states() + assert states == {"example": EXAMPLE_THING_STATE} + + +def test_mock_start_async_task_soon(mockinterface): + """Check nothing happens when we run an async task.""" + mutable = [False] + + async def set_mutable(val): + mutable[0] = val + + mockinterface.start_async_task_soon(set_mutable, True) + + # Check the async code didn't run + assert mutable[0] is False + + +def test_mock_settings_folder(mockinterface): + """Check a temporary settings folder is provided.""" + # The temp folder should be created when accessed, + # so is None initially. + assert mockinterface._settings_tempdir is None + f = mockinterface.settings_folder + assert f == mockinterface._settings_tempdir.name + assert mockinterface.settings_file_path == os.path.join(f, "settings.json") + + +def test_mock_path(mockinterface): + """Check the path is generated predictably.""" + assert mockinterface.path == f"/{NAME}/" + + +def test_mock_get_thing_states(mockinterface): + """Check an empty dictionary is returned.""" + assert mockinterface.get_thing_states() == {} + + +def test_create_thing_without_server(): + """Check the test harness for creating things without a server.""" + example = tsi.create_thing_without_server(ExampleThing) + assert isinstance(example, ExampleThing) + assert example.path == "/examplething/" + assert isinstance(example._thing_server_interface, tsi.MockThingServerInterface) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c6e8d588..00d7947a 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -5,6 +5,7 @@ PropertyNotObservableError, InvocationCancelledError, ) +from labthings_fastapi.thing_server_interface import create_thing_without_server class ThingWithProperties(lt.Thing): @@ -52,16 +53,10 @@ def cancel_myself(self): @pytest.fixture -def thing(): - """Instantiate and return a test Thing.""" - return ThingWithProperties() - - -@pytest.fixture -def server(thing): +def server(): """Create a server, and add a MyThing test Thing to it.""" server = lt.ThingServer() - server.add_thing(thing, "/thing") + server.add_thing("thing", ThingWithProperties) return server @@ -74,7 +69,7 @@ def client(server): @pytest.fixture def ws(client): - """Yield a websocket connection to a server hosting a MyThing(). + """Yield a websocket connection to a server hosting a MyThing. This ensures the websocket is properly closed after the test, and avoids lots of indent levels. @@ -86,6 +81,12 @@ def ws(client): ws.close(1000) +@pytest.fixture +def thing(): + """Create a ThingWithProperties, not connected to a server.""" + return create_thing_without_server(ThingWithProperties) + + def test_observing_dataprop(thing, mocker): """Check `observe_property` is OK on a data property. diff --git a/typing_tests/thing_definitions.py b/typing_tests/thing_definitions.py index 553e4866..01c4c27a 100644 --- a/typing_tests/thing_definitions.py +++ b/typing_tests/thing_definitions.py @@ -26,6 +26,8 @@ from typing_extensions import assert_type import typing +from labthings_fastapi.thing_server_interface import create_thing_without_server + def optional_int_factory() -> int | None: """Return an optional int.""" @@ -107,7 +109,8 @@ class TestPropertyDefaultsMatch(lt.Thing): # Check that the type hints on an instance of the class are correct. -test_defaults_match = TestPropertyDefaultsMatch() +test_defaults_match = create_thing_without_server(TestPropertyDefaultsMatch) +assert_type(test_defaults_match, TestPropertyDefaultsMatch) assert_type(test_defaults_match.intprop, int) assert_type(test_defaults_match.intprop2, int) assert_type(test_defaults_match.intprop3, int) @@ -167,7 +170,8 @@ class TestExplicitDescriptor(lt.Thing): # Check instance attributes are typed correctly. -test_explicit_descriptor = TestExplicitDescriptor() +test_explicit_descriptor = create_thing_without_server(TestExplicitDescriptor) +assert_type(test_explicit_descriptor, TestExplicitDescriptor) assert_type(test_explicit_descriptor.intprop1, int) assert_type(test_explicit_descriptor.intprop2, int) assert_type(test_explicit_descriptor.intprop3, int) @@ -270,7 +274,8 @@ def strprop(self, val: str) -> None: # Don't check ``strprop`` because it caused an error and thus will # not have the right type, even though the error is ignored. -test_functional_property = TestFunctionalProperty() +test_functional_property = create_thing_without_server(TestFunctionalProperty) +assert_type(test_functional_property, TestFunctionalProperty) assert_type(test_functional_property.intprop1, int) assert_type(test_functional_property.intprop2, int) assert_type(test_functional_property.intprop3, int)