diff --git a/client/auth.go b/client/auth.go index 950d67505..1d2e45768 100644 --- a/client/auth.go +++ b/client/auth.go @@ -92,7 +92,7 @@ func (c *Conn) readInitialHandshake() error { pos += 2 // The upper 2 bytes of the Capabilities Flags - c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability + c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0 @@ -209,10 +209,8 @@ func (c *Conn) writeAuthHandshake() error { // Set default client capabilities that reflect the abilities of this library capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION | - mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH - // Adjust client capability flags based on server support - capability |= c.capability & mysql.CLIENT_LONG_FLAG - capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES + mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH | + mysql.CLIENT_LONG_FLAG | mysql.CLIENT_QUERY_ATTRIBUTES | mysql.CLIENT_DEPRECATE_EOF // Adjust client capability flags on specific client requests // Only flags that would make any sense setting and aren't handled elsewhere // in the library are supported here @@ -275,6 +273,7 @@ func (c *Conn) writeAuthHandshake() error { data := make([]byte, length+4) // capability [32 bit] + c.capability &= capability data[4] = byte(capability) data[5] = byte(capability >> 8) data[6] = byte(capability >> 16) diff --git a/client/client_test.go b/client/client_test.go index 7443e2877..05c8089df 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -101,6 +101,18 @@ func (s *clientTestSuite) TestConn_Compress() { require.NoError(s.T(), err) } +func (s *clientTestSuite) TestConn_NoDeprecateEOF() { + addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) + conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error { + conn.UnsetCapability(mysql.CLIENT_DEPRECATE_EOF) + return nil + }) + require.NoError(s.T(), err) + + _, err = conn.Execute("SELECT VERSION()") + require.NoError(s.T(), err) +} + func (s *clientTestSuite) TestConn_SetCapability() { caps := []uint32{ mysql.CLIENT_LONG_PASSWORD, @@ -125,6 +137,7 @@ func (s *clientTestSuite) TestConn_SetCapability() { mysql.CLIENT_PLUGIN_AUTH, mysql.CLIENT_CONNECT_ATTRS, mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, + mysql.CLIENT_DEPRECATE_EOF, } for _, capI := range caps { diff --git a/client/conn.go b/client/conn.go index 572fe2b09..dd81e5ebd 100644 --- a/client/conn.go +++ b/client/conn.go @@ -252,7 +252,7 @@ func (c *Conn) UnsetCapability(cap uint32) { // HasCapability returns true if the connection has the specific capability func (c *Conn) HasCapability(cap uint32) bool { - return c.ccaps&cap > 0 + return c.ccaps&cap != 0 } // UseSSL: use default SSL diff --git a/client/resp.go b/client/resp.go index 8d77b480e..1e8549549 100644 --- a/client/resp.go +++ b/client/resp.go @@ -14,24 +14,11 @@ import ( "github.com/go-mysql-org/go-mysql/utils" ) -func (c *Conn) readUntilEOF() (err error) { - var data []byte - - for { - data, err = c.ReadPacket() - if err != nil { - return err - } - - // EOF Packet - if c.isEOFPacket(data) { - return err - } - } -} - func (c *Conn) isEOFPacket(data []byte) bool { - return data[0] == mysql.EOF_HEADER && len(data) <= 5 + // 0xffffff due to https://dev.mysql.com/worklog/task/?id=7766 + // "Server will never send OK packet longer than 16777216 bytes thus limiting + // size of OK packet to be 16777215 bytes" + return data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff } func (c *Conn) handleOKPacket(data []byte) (*mysql.Result, error) { @@ -336,10 +323,9 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *mysql.Re } func (c *Conn) readResultColumns(result *mysql.Result) (err error) { - i := 0 var data []byte - for { + for i := range result.Fields { rawPkgLen := len(result.RawPkg) result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg) if err != nil { @@ -347,22 +333,6 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) { } data = result.RawPkg[rawPkgLen:] - // EOF Packet - if c.isEOFPacket(data) { - if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { - result.Warnings = binary.LittleEndian.Uint16(data[1:]) - // todo add strict_mode, warning will be treat as error - result.Status = binary.LittleEndian.Uint16(data[3:]) - c.status = result.Status - } - - if i != len(result.Fields) { - err = mysql.ErrMalformPacket - } - - return err - } - if result.Fields[i] == nil { result.Fields[i] = &mysql.Field{} } @@ -372,8 +342,30 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) { } result.FieldNames[utils.ByteSliceToString(result.Fields[i].Name)] = i + } + + if c.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { + // EOF Packet + rawPkgLen := len(result.RawPkg) + result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg) + if err != nil { + return err + } + data = result.RawPkg[rawPkgLen:] - i++ + if c.isEOFPacket(data) { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { + result.Warnings = binary.LittleEndian.Uint16(data[1:]) + // todo add strict_mode, warning will be treat as error + result.Status = binary.LittleEndian.Uint16(data[3:]) + c.status = result.Status + } + return nil + } else { + return mysql.ErrMalformPacket + } + } else { + return nil } } @@ -388,15 +380,21 @@ func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) { } data = result.RawPkg[rawPkgLen:] - // EOF Packet if c.isEOFPacket(data) { - if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 { + // Treat like OK + affectedRows, _, n := mysql.LengthEncodedInt(data[1:]) + insertId, _, m := mysql.LengthEncodedInt(data[1+n:]) + result.Status = binary.LittleEndian.Uint16(data[1+n+m:]) + result.AffectedRows = affectedRows + result.InsertId = insertId + c.status = result.Status + } else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { result.Warnings = binary.LittleEndian.Uint16(data[1:]) // todo add strict_mode, warning will be treat as error result.Status = binary.LittleEndian.Uint16(data[3:]) c.status = result.Status } - break } @@ -435,9 +433,16 @@ func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perR return err } - // EOF Packet if c.isEOFPacket(data) { - if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 { + // Treat like OK + affectedRows, _, n := mysql.LengthEncodedInt(data[1:]) + insertId, _, m := mysql.LengthEncodedInt(data[1+n:]) + result.Status = binary.LittleEndian.Uint16(data[1+n+m:]) + result.AffectedRows = affectedRows + result.InsertId = insertId + c.status = result.Status + } else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { result.Warnings = binary.LittleEndian.Uint16(data[1:]) // todo add strict_mode, warning will be treat as error result.Status = binary.LittleEndian.Uint16(data[3:]) diff --git a/client/stmt.go b/client/stmt.go index bab79c7ba..106e176de 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -275,14 +275,33 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } if s.params > 0 { - if err := s.conn.readUntilEOF(); err != nil { - return nil, errors.Trace(err) + for range s.params { + if _, err := s.conn.ReadPacket(); err != nil { + return nil, errors.Trace(err) + } + } + if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { + if packet, err := s.conn.ReadPacket(); err != nil { + return nil, errors.Trace(err) + } else if !c.isEOFPacket(packet) { + return nil, mysql.ErrMalformPacket + } } } if s.columns > 0 { - if err := s.conn.readUntilEOF(); err != nil { - return nil, errors.Trace(err) + // TODO process when CLIENT_CACHE_METADATA enabled + for range s.columns { + if _, err := s.conn.ReadPacket(); err != nil { + return nil, errors.Trace(err) + } + } + if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { + if packet, err := s.conn.ReadPacket(); err != nil { + return nil, errors.Trace(err) + } else if !c.isEOFPacket(packet) { + return nil, mysql.ErrMalformPacket + } } }