From 24317173f3d8424388715083a211ee48c1d2a63e Mon Sep 17 00:00:00 2001 From: Matt Brown Date: Sun, 26 Oct 2025 08:42:54 +0000 Subject: [PATCH] Handle PROTOCOL_CANCELLED (state 3) in _dispatch_result. Prior to this patch the state machine did not handle results arriving after cancellation. This add explicit support for that scenario by keeping track of the state prior the cancellation (cancelled_from_state), and using cancelled_from_state when consuming the results. The alternative option is just silently dropping the results (similar to PROTOCOL_TERMINATING). --- asyncpg/protocol/coreproto.pxd | 1 + asyncpg/protocol/coreproto.pyx | 3 +++ asyncpg/protocol/protocol.pyx | 31 +++++++++++++++++++------------ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 34c7c712..fa657231 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -80,6 +80,7 @@ cdef class CoreProtocol: ConnectionStatus con_status ProtocolState state + ProtocolState cancelled_from_state TransactionStatus xact_status str encoding diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index da96c412..c978a675 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -33,6 +33,7 @@ cdef class CoreProtocol: self.con_params = con_params self.con_status = CONNECTION_BAD self.state = PROTOCOL_IDLE + self.cancelled_from_state = PROTOCOL_IDLE self.xact_status = PQTRANS_IDLE self.encoding = 'utf-8' # type of `scram` is `SCRAMAuthentcation` @@ -835,11 +836,13 @@ cdef class CoreProtocol: pass else: self.state = new_state + self.cancelled_from_state = PROTOCOL_IDLE elif new_state == PROTOCOL_FAILED: self.state = PROTOCOL_FAILED elif new_state == PROTOCOL_CANCELLED: + self.cancelled_from_state = self.state self.state = PROTOCOL_CANCELLED elif new_state == PROTOCOL_TERMINATING: diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index acce4e9f..be1b65c7 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -851,39 +851,46 @@ cdef class BaseProtocol(CoreProtocol): waiter.set_exception(exc) return + state = self.state + if state == PROTOCOL_CANCELLED: + state = self.cancelled_from_state + if state == PROTOCOL_IDLE: + waiter.set_exception(asyncio.CancelledError()) + return + try: - if self.state == PROTOCOL_AUTH: + if state == PROTOCOL_AUTH: self._on_result__connect(waiter) - elif self.state == PROTOCOL_PREPARE: + elif state == PROTOCOL_PREPARE: self._on_result__prepare(waiter) - elif self.state == PROTOCOL_BIND_EXECUTE: + elif state == PROTOCOL_BIND_EXECUTE: self._on_result__bind_and_exec(waiter) - elif self.state == PROTOCOL_BIND_EXECUTE_MANY: + elif state == PROTOCOL_BIND_EXECUTE_MANY: self._on_result__bind_and_exec(waiter) - elif self.state == PROTOCOL_EXECUTE: + elif state == PROTOCOL_EXECUTE: self._on_result__bind_and_exec(waiter) - elif self.state == PROTOCOL_BIND: + elif state == PROTOCOL_BIND: self._on_result__bind(waiter) - elif self.state == PROTOCOL_CLOSE_STMT_PORTAL: + elif state == PROTOCOL_CLOSE_STMT_PORTAL: self._on_result__close_stmt_or_portal(waiter) - elif self.state == PROTOCOL_SIMPLE_QUERY: + elif state == PROTOCOL_SIMPLE_QUERY: self._on_result__simple_query(waiter) - elif (self.state == PROTOCOL_COPY_OUT_DATA or - self.state == PROTOCOL_COPY_OUT_DONE): + elif (state == PROTOCOL_COPY_OUT_DATA or + state == PROTOCOL_COPY_OUT_DONE): self._on_result__copy_out(waiter) - elif self.state == PROTOCOL_COPY_IN_DATA: + elif state == PROTOCOL_COPY_IN_DATA: self._on_result__copy_in(waiter) - elif self.state == PROTOCOL_TERMINATING: + elif state == PROTOCOL_TERMINATING: # We are waiting for the connection to drop, so # ignore any stray results at this point. pass