diff --git a/src/labthings_fastapi/actions/__init__.py b/src/labthings_fastapi/actions/__init__.py index 7532a87e..0e224ff9 100644 --- a/src/labthings_fastapi/actions/__init__.py +++ b/src/labthings_fastapi/actions/__init__.py @@ -19,7 +19,7 @@ InvocationCancelledError, invocation_logger, ) -from ..outputs.blob import BlobIOContextDep +from ..outputs.blob import Blob, BlobDataManager if TYPE_CHECKING: # We only need these imports for type hints, so this avoids circular imports. @@ -40,6 +40,7 @@ def __init__( self, action: ActionDescriptor, thing: Thing, + blob_data_manager: BlobDataManager, input: Optional[BaseModel] = None, dependencies: Optional[dict[str, Any]] = None, default_stop_timeout: float = 5, @@ -56,6 +57,8 @@ def __init__( self.dependencies = dependencies if dependencies is not None else {} self.cancel_hook = cancel_hook + self._blob_data_manager = blob_data_manager + # A UUID for the Invocation (not the same as the threading.Thread ident) self._ID = id if id is not None else uuid.uuid4() # Task ID @@ -181,6 +184,9 @@ def run(self): ret = action.__get__(thing)(**kwargs, **self.dependencies) with self._status_lock: + if isinstance(ret, Blob): + blob_id = self._blob_data_manager.add_blob(ret.data) + ret.href = f"/blob/{blob_id}" self._return_value = ret self._status = InvocationStatus.COMPLETED self.action.emit_changed_event(self.thing, self._status) @@ -241,7 +247,8 @@ def emit(self, record): class ActionManager: """A class to manage a collection of actions""" - def __init__(self): + def __init__(self, server): + self._server = server self._invocations = {} self._invocations_lock = Lock() @@ -271,6 +278,7 @@ def invoke_action( dependencies=dependencies, id=id, cancel_hook=cancel_hook, + blob_data_manager=self._server.blob_data_manager, ) self.append_invocation(thread) thread.start() @@ -312,7 +320,7 @@ def attach_to_app(self, app: FastAPI): """Add /action_invocations and /action_invocation/{id} endpoints to FastAPI""" @app.get(ACTION_INVOCATIONS_PATH, response_model=list[InvocationModel]) - def list_all_invocations(request: Request, _blob_manager: BlobIOContextDep): + def list_all_invocations(request: Request): return self.list_invocations(as_responses=True, request=request) @app.get( @@ -320,9 +328,7 @@ def list_all_invocations(request: Request, _blob_manager: BlobIOContextDep): response_model=InvocationModel, responses={404: {"description": "Invocation ID not found"}}, ) - def action_invocation( - id: uuid.UUID, request: Request, _blob_manager: BlobIOContextDep - ): + def action_invocation(id: uuid.UUID, request: Request): try: with self._invocations_lock: return self._invocations[id].response(request=request) @@ -346,7 +352,7 @@ def action_invocation( 503: {"description": "No result is available for this invocation"}, }, ) - def action_invocation_output(id: uuid.UUID, _blob_manager: BlobIOContextDep): + def action_invocation_output(id: uuid.UUID): """Get the output of an action invocation This returns just the "output" component of the action invocation. If the diff --git a/src/labthings_fastapi/descriptors/action.py b/src/labthings_fastapi/descriptors/action.py index d2bf1084..0b577a84 100644 --- a/src/labthings_fastapi/descriptors/action.py +++ b/src/labthings_fastapi/descriptors/action.py @@ -23,7 +23,7 @@ input_model_from_signature, return_type, ) -from ..outputs.blob import BlobIOContextDep + from ..thing_description import type_to_dataschema from ..thing_description.model import ActionAffordance, ActionOp, Form, Union from ..utilities import labthings_data, get_blocking_portal @@ -178,7 +178,6 @@ def add_to_fastapi(self, app: FastAPI, thing: Thing): # the function to the decorator. def start_action( action_manager: ActionManagerContextDep, - _blob_manager: BlobIOContextDep, request: Request, body, id: InvocationID, diff --git a/src/labthings_fastapi/outputs/blob.py b/src/labthings_fastapi/outputs/blob.py index 23ff1f8f..009b8892 100644 --- a/src/labthings_fastapi/outputs/blob.py +++ b/src/labthings_fastapi/outputs/blob.py @@ -43,33 +43,24 @@ def get_image(self) -> MyImageBlob: """ from __future__ import annotations -from contextvars import ContextVar import io import os -import re import shutil from typing import ( - Annotated, - Callable, Literal, Mapping, Optional, ) from weakref import WeakValueDictionary -from typing_extensions import TypeAlias from tempfile import TemporaryDirectory import uuid -from fastapi import FastAPI, Depends, Request +from fastapi import FastAPI from fastapi.responses import FileResponse, Response from pydantic import ( BaseModel, - create_model, model_serializer, - model_validator, ) -from labthings_fastapi.dependencies.thing_server import find_thing_server -from starlette.exceptions import HTTPException from typing_extensions import Self, Protocol, runtime_checkable @@ -203,88 +194,25 @@ class Blob(BaseModel): documentation. """ - href: str + href: str = "blob://local" """The URL where the data may be retrieved. This will be `blob://local` if the data is stored locally.""" - media_type: str = "*/*" - """The MIME type of the data. This should be overridden in subclasses.""" rel: Literal["output"] = "output" description: str = ( "The output from this action is not serialised to JSON, so it must be " "retrieved as a file. This link will return the file." ) + media_type: str = "*/*" + """The MIME type of the data. This should be overridden in subclasses.""" - _data: Optional[ServerSideBlobData] = None - """This object holds the data, either in memory or as a file. - - If `_data` is `None`, then the Blob has not been deserialised yet, and the - `href` should point to a valid address where the data may be downloaded. - """ - - @model_validator(mode="after") - def retrieve_data(self): - """Retrieve the data from the URL - - When a [`Blob`](#labthings_fastapi.outputs.blob.Blob) is created - using its constructor, [`pydantic`](https://docs.pydantic.dev/latest/) - will attempt to deserialise it by retrieving the data from the URL - specified in `href`. Currently, this must be a URL pointing to a - [`Blob`](#labthings_fastapi.outputs.blob.Blob) that already exists on - this server. - - This validator will only work if the function to resolve URLs to - [`BlobData`](#labthings_fastapi.outputs.blob.BlobData) objects - has been set in the context variable - [`url_to_blobdata_ctx`](#labthings_fastapi.outputs.blob.url_to_blobdata_ctx). - This is done when actions are being invoked over HTTP by the - [`BlobIOContextDep`](#labthings_fastapi.outputs.blob.BlobIOContextDep) dependency. - """ - if self.href == "blob://local": - if self._data: - return self - raise ValueError("Blob objects must have data if the href is blob://local") - try: - url_to_blobdata = url_to_blobdata_ctx.get() - self._data = url_to_blobdata(self.href) - self.href = "blob://local" - except LookupError: - raise LookupError( - "Blobs may only be created from URLs passed in over HTTP." - f"The URL in question was {self.href}." - ) - return self + _data: ServerSideBlobData + """This object holds the data, either in memory or as a file.""" @model_serializer(mode="plain", when_used="always") def to_dict(self) -> Mapping[str, str]: - """Serialise the Blob to a dictionary and make it downloadable - - When [`pydantic`](https://docs.pydantic.dev/latest/) serialises this object, - it will call this method to convert it to a dictionary. There is a - significant side-effect, which is that we will add the blob to the - [`BlobDataManager`](#labthings_fastapi.outputs.blob.BlobDataManager) so it - can be downloaded. - - This serialiser will only work if the function to assign URLs to - [`BlobData`](#labthings_fastapi.outputs.blob.BlobData) objects - has been set in the context variable - [`blobdata_to_url_ctx`](#labthings_fastapi.outputs.blob.blobdata_to_url_ctx). - This is done when actions are being returned over HTTP by the - [`BlobIOContextDep`](#labthings_fastapi.outputs.blob.BlobIOContextDep) dependency. - """ - if self.href == "blob://local": - try: - blobdata_to_url = blobdata_to_url_ctx.get() - # MyPy seems to miss that `self.data` is a property, hence the ignore - href = blobdata_to_url(self.data) # type: ignore[arg-type] - except LookupError: - raise LookupError( - "Blobs may only be serialised inside the " - "context created by BlobIOContextDep." - ) - else: - href = self.href + """Serialise the Blob to a dictionary and make it downloadable""" return { - "href": href, + "href": self.href, "media_type": self.media_type, "rel": self.rel, "description": self.description, @@ -348,9 +276,8 @@ def open(self) -> io.IOBase: @classmethod def from_bytes(cls, data: bytes) -> Self: """Create a BlobOutput from a bytes object""" - return cls.model_construct( # type: ignore[return-value] - href="blob://local", - _data=BlobBytes(data, media_type=cls.default_media_type()), + return cls.model_construct( + _data=BlobBytes(data, media_type=cls.default_media_type()) ) @classmethod @@ -362,8 +289,7 @@ def from_temporary_directory(cls, folder: TemporaryDirectory, file: str) -> Self collected. """ file_path = os.path.join(folder.name, file) - return cls.model_construct( # type: ignore[return-value] - href="blob://local", + return cls.model_construct( _data=BlobFile( file_path, media_type=cls.default_media_type(), @@ -381,9 +307,8 @@ def from_file(cls, file: str) -> Self: temporary. If you are using temporary files, consider creating your Blob with `from_temporary_directory` instead. """ - return cls.model_construct( # type: ignore[return-value] - href="blob://local", - _data=BlobFile(file, media_type=cls.default_media_type()), + return cls.model_construct( + _data=BlobFile(file, media_type=cls.default_media_type()) ) def response(self): @@ -391,26 +316,6 @@ def response(self): return self.data.response() -def blob_type(media_type: str) -> type[Blob]: - """Create a BlobOutput subclass for a given media type - - This convenience function may confuse static type checkers, so it is usually - clearer to make a subclass instead, e.g.: - - ```python - class MyImageBlob(Blob): - media_type = "image/png" - ``` - """ - if "'" in media_type or "\\" in media_type: - raise ValueError("media_type must not contain single quotes or backslashes") - return create_model( - f"{media_type.replace('/', '_')}_blob", - __base__=Blob, - media_type=(eval(f"Literal[r'{media_type}']"), media_type), - ) - - class BlobDataManager: """A class to manage BlobData objects @@ -452,59 +357,3 @@ def download_blob(self, blob_id: uuid.UUID): def attach_to_app(self, app: FastAPI): """Attach the BlobDataManager to a FastAPI app""" app.get("/blob/{blob_id}")(self.download_blob) - - -blobdata_to_url_ctx = ContextVar[Callable[[ServerSideBlobData], str]]("blobdata_to_url") -"""This context variable gives access to a function that makes BlobData objects -downloadable, by assigning a URL and adding them to the -[`BlobDataManager`](#labthings_fastapi.outputs.blob.BlobDataManager). - -It is only available within a -[`blob_serialisation_context_manager`](#labthings_fastapi.outputs.blob.blob_serialisation_context_manager) -because it requires access to the `BlobDataManager` and the `url_for` function -from the FastAPI app. -""" - -url_to_blobdata_ctx = ContextVar[Callable[[str], BlobData]]("url_to_blobdata") -"""This context variable gives access to a function that makes BlobData objects -from a URL, by retrieving them from the -[`BlobDataManager`](#labthings_fastapi.outputs.blob.BlobDataManager). - -It is only available within a -[`blob_serialisation_context_manager`](#labthings_fastapi.outputs.blob.blob_serialisation_context_manager) -because it requires access to the `BlobDataManager`. -""" - - -async def blob_serialisation_context_manager(request: Request): - """Set context variables to allow blobs to be [de]serialised""" - thing_server = find_thing_server(request.app) - blob_manager: BlobDataManager = thing_server.blob_data_manager - url_for = request.url_for - - def blobdata_to_url(blob: ServerSideBlobData) -> str: - blob_id = blob_manager.add_blob(blob) - return str(url_for("download_blob", blob_id=blob_id)) - - def url_to_blobdata(url: str) -> BlobData: - m = re.search(r"blob/([0-9a-z\-]+)", url) - if not m: - raise HTTPException( - status_code=404, detail="Could not find blob ID in href" - ) - invocation_id = uuid.UUID(m.group(1)) - return blob_manager.get_blob(invocation_id) - - t1 = blobdata_to_url_ctx.set(blobdata_to_url) - t2 = url_to_blobdata_ctx.set(url_to_blobdata) - try: - yield blob_manager - finally: - blobdata_to_url_ctx.reset(t1) - url_to_blobdata_ctx.reset(t2) - - -BlobIOContextDep: TypeAlias = Annotated[ - BlobDataManager, Depends(blob_serialisation_context_manager) -] -"""A dependency that enables `Blob`s to be serialised and deserialised.""" diff --git a/src/labthings_fastapi/server/__init__.py b/src/labthings_fastapi/server/__init__.py index 4d7a5290..3335aceb 100644 --- a/src/labthings_fastapi/server/__init__.py +++ b/src/labthings_fastapi/server/__init__.py @@ -30,7 +30,7 @@ def __init__(self, settings_folder: Optional[str] = None): self.app = FastAPI(lifespan=self.lifespan) self.set_cors_middleware() self.settings_folder = settings_folder or "./settings" - self.action_manager = ActionManager() + self.action_manager = ActionManager(self) self.action_manager.attach_to_app(self.app) self.blob_data_manager = BlobDataManager() self.blob_data_manager.attach_to_app(self.app)