Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ cdef class CoreProtocol:

ConnectionStatus con_status
ProtocolState state
ProtocolState cancelled_from_state
TransactionStatus xact_status

str encoding
Expand Down
3 changes: 3 additions & 0 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cdef class CoreProtocol:
self.con_params = con_params
self.con_status = CONNECTION_BAD
self.state = PROTOCOL_IDLE
self.cancelled_from_state = PROTOCOL_IDLE
self.xact_status = PQTRANS_IDLE
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
Expand Down Expand Up @@ -835,11 +836,13 @@ cdef class CoreProtocol:
pass
else:
self.state = new_state
self.cancelled_from_state = PROTOCOL_IDLE

elif new_state == PROTOCOL_FAILED:
self.state = PROTOCOL_FAILED

elif new_state == PROTOCOL_CANCELLED:
self.cancelled_from_state = self.state
self.state = PROTOCOL_CANCELLED

elif new_state == PROTOCOL_TERMINATING:
Expand Down
31 changes: 19 additions & 12 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -851,39 +851,46 @@ cdef class BaseProtocol(CoreProtocol):
waiter.set_exception(exc)
return

state = self.state
if state == PROTOCOL_CANCELLED:
state = self.cancelled_from_state
if state == PROTOCOL_IDLE:
waiter.set_exception(asyncio.CancelledError())
return

try:
if self.state == PROTOCOL_AUTH:
if state == PROTOCOL_AUTH:
self._on_result__connect(waiter)

elif self.state == PROTOCOL_PREPARE:
elif state == PROTOCOL_PREPARE:
self._on_result__prepare(waiter)

elif self.state == PROTOCOL_BIND_EXECUTE:
elif state == PROTOCOL_BIND_EXECUTE:
self._on_result__bind_and_exec(waiter)

elif self.state == PROTOCOL_BIND_EXECUTE_MANY:
elif state == PROTOCOL_BIND_EXECUTE_MANY:
self._on_result__bind_and_exec(waiter)

elif self.state == PROTOCOL_EXECUTE:
elif state == PROTOCOL_EXECUTE:
self._on_result__bind_and_exec(waiter)

elif self.state == PROTOCOL_BIND:
elif state == PROTOCOL_BIND:
self._on_result__bind(waiter)

elif self.state == PROTOCOL_CLOSE_STMT_PORTAL:
elif state == PROTOCOL_CLOSE_STMT_PORTAL:
self._on_result__close_stmt_or_portal(waiter)

elif self.state == PROTOCOL_SIMPLE_QUERY:
elif state == PROTOCOL_SIMPLE_QUERY:
self._on_result__simple_query(waiter)

elif (self.state == PROTOCOL_COPY_OUT_DATA or
self.state == PROTOCOL_COPY_OUT_DONE):
elif (state == PROTOCOL_COPY_OUT_DATA or
state == PROTOCOL_COPY_OUT_DONE):
self._on_result__copy_out(waiter)

elif self.state == PROTOCOL_COPY_IN_DATA:
elif state == PROTOCOL_COPY_IN_DATA:
self._on_result__copy_in(waiter)

elif self.state == PROTOCOL_TERMINATING:
elif state == PROTOCOL_TERMINATING:
# We are waiting for the connection to drop, so
# ignore any stray results at this point.
pass
Expand Down