@@ -61,8 +61,6 @@ type Conn struct {
6161}
6262
6363func (c * Conn ) close (err error ) {
64- err = xerrors .Errorf ("websocket closed: %w" , err )
65-
6664 c .closeOnce .Do (func () {
6765 runtime .SetFinalizer (c , nil )
6866
@@ -71,7 +69,7 @@ func (c *Conn) close(err error) {
7169 cerr = err
7270 }
7371
74- c .closeErr = cerr
72+ c .closeErr = xerrors . Errorf ( "websocket closed: %w" , cerr )
7573
7674 close (c .closed )
7775 })
@@ -98,7 +96,7 @@ func (c *Conn) init() {
9896 c .readDone = make (chan int )
9997
10098 runtime .SetFinalizer (c , func (c * Conn ) {
101- c .Close ( StatusInternalError , "connection garbage collected" )
99+ c .close ( xerrors . New ( "connection garbage collected" ) )
102100 })
103101
104102 go c .writeLoop ()
@@ -238,7 +236,7 @@ func (c *Conn) handleControl(h header) {
238236 case opClose :
239237 ce , err := parseClosePayload (b )
240238 if err != nil {
241- c .close (xerrors .Errorf ("read invalid close payload: %w" , err ))
239+ c .close (xerrors .Errorf ("received invalid close payload: %w" , err ))
242240 return
243241 }
244242 if ce .Code == StatusNoStatusRcvd {
@@ -302,7 +300,7 @@ func (c *Conn) readLoop() {
302300 }
303301}
304302
305- func (c * Conn ) dataReadLoop (h header ) ( err error ) {
303+ func (c * Conn ) dataReadLoop (h header ) error {
306304 maskPos := 0
307305 left := h .payloadLength
308306 firstReadDone := false
@@ -355,7 +353,6 @@ func (c *Conn) writePong(p []byte) error {
355353
356354// Close closes the WebSocket connection with the given status code and reason.
357355// It will write a WebSocket close frame with a timeout of 5 seconds.
358- // Concurrent calls to Close are ok.
359356func (c * Conn ) Close (code StatusCode , reason string ) error {
360357 err := c .exportedClose (code , reason )
361358 if err != nil {
@@ -400,7 +397,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
400397 return err
401398 }
402399
403- if cerr != c .closeErr {
400+ if ! xerrors . Is ( c .closeErr , cerr ) {
404401 return c .closeErr
405402 }
406403
@@ -420,9 +417,8 @@ func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) er
420417 payload : p ,
421418 }:
422419 case <- ctx .Done ():
423- err := xerrors .Errorf ("control frame write timed out: %w" , ctx .Err ())
424- c .close (err )
425- return err
420+ c .close (xerrors .Errorf ("control frame write timed out: %w" , ctx .Err ()))
421+ return ctx .Err ()
426422 }
427423
428424 select {
@@ -487,7 +483,7 @@ func (w messageWriter) write(p []byte) (int, error) {
487483 select {
488484 case <- w .ctx .Done ():
489485 w .c .close (xerrors .Errorf ("data write timed out: %w" , w .ctx .Err ()))
490- // Wait for writeLoop to complete so we know p is done.
486+ // Wait for writeLoop to complete so we know p is done with .
491487 <- w .c .writeDone
492488 return 0 , w .ctx .Err ()
493489 case _ , ok := <- w .c .writeDone :
@@ -542,25 +538,21 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
542538}
543539
544540func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
545- for ! atomic .CompareAndSwapInt64 (& c .activeReader , 0 , 1 ) {
546- select {
547- case <- c .closed :
548- return 0 , nil , c .closeErr
549- case c .readBytes <- nil :
550- select {
551- case <- ctx .Done ():
552- return 0 , nil , ctx .Err ()
553- case _ , ok := <- c .readDone :
554- if ! ok {
555- return 0 , nil , c .closeErr
556- }
557- if atomic .LoadInt64 (& c .activeReader ) == 1 {
558- return 0 , nil , xerrors .New ("previous message not fully read" )
559- }
560- }
561- case <- ctx .Done ():
562- return 0 , nil , ctx .Err ()
541+ if ! atomic .CompareAndSwapInt64 (& c .activeReader , 0 , 1 ) {
542+ // If the next read yields io.EOF we are good to go.
543+ r := messageReader {
544+ ctx : ctx ,
545+ c : c ,
563546 }
547+ _ , err := r .Read (nil )
548+ if err == nil {
549+ return 0 , nil , xerrors .New ("previous message not fully read" )
550+ }
551+ if ! xerrors .Is (err , io .EOF ) {
552+ return 0 , nil , xerrors .Errorf ("failed to check if last message at io.EOF: %w" , err )
553+ }
554+
555+ atomic .StoreInt64 (& c .activeReader , 1 )
564556 }
565557
566558 select {
@@ -586,7 +578,8 @@ type messageReader struct {
586578func (r messageReader ) Read (p []byte ) (int , error ) {
587579 n , err := r .read (p )
588580 if err != nil {
589- // Have to return io.EOF directly for now, cannot wrap.
581+ // Have to return io.EOF directly for now, we cannot wrap as xerrors
582+ // isn't used in stdlib.
590583 if err == io .EOF {
591584 return n , io .EOF
592585 }
0 commit comments