Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c *Conn) readInitialHandshake() error {
pos += 2

// The upper 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2

// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
Expand Down Expand Up @@ -209,10 +209,8 @@ func (c *Conn) writeAuthHandshake() error {

// Set default client capabilities that reflect the abilities of this library
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
// Adjust client capability flags based on server support
capability |= c.capability & mysql.CLIENT_LONG_FLAG
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH |
mysql.CLIENT_LONG_FLAG | mysql.CLIENT_QUERY_ATTRIBUTES | mysql.CLIENT_DEPRECATE_EOF
// Adjust client capability flags on specific client requests
// Only flags that would make any sense setting and aren't handled elsewhere
// in the library are supported here
Expand Down Expand Up @@ -275,6 +273,7 @@ func (c *Conn) writeAuthHandshake() error {
data := make([]byte, length+4)

// capability [32 bit]
c.capability &= capability
data[4] = byte(capability)
data[5] = byte(capability >> 8)
data[6] = byte(capability >> 16)
Expand Down
13 changes: 13 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ func (s *clientTestSuite) TestConn_Compress() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_NoDeprecateEOF() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
conn.UnsetCapability(mysql.CLIENT_DEPRECATE_EOF)
return nil
})
require.NoError(s.T(), err)

_, err = conn.Execute("SELECT VERSION()")
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCapability() {
caps := []uint32{
mysql.CLIENT_LONG_PASSWORD,
Expand All @@ -125,6 +137,7 @@ func (s *clientTestSuite) TestConn_SetCapability() {
mysql.CLIENT_PLUGIN_AUTH,
mysql.CLIENT_CONNECT_ATTRS,
mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
mysql.CLIENT_DEPRECATE_EOF,
}

for _, capI := range caps {
Expand Down
2 changes: 1 addition & 1 deletion client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (c *Conn) UnsetCapability(cap uint32) {

// HasCapability returns true if the connection has the specific capability
func (c *Conn) HasCapability(cap uint32) bool {
return c.ccaps&cap > 0
return c.ccaps&cap != 0
}

// UseSSL: use default SSL
Expand Down
87 changes: 46 additions & 41 deletions client/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,11 @@ import (
"github.com/go-mysql-org/go-mysql/utils"
)

func (c *Conn) readUntilEOF() (err error) {
var data []byte

for {
data, err = c.ReadPacket()
if err != nil {
return err
}

// EOF Packet
if c.isEOFPacket(data) {
return err
}
}
}

func (c *Conn) isEOFPacket(data []byte) bool {
return data[0] == mysql.EOF_HEADER && len(data) <= 5
// 0xffffff due to https://dev.mysql.com/worklog/task/?id=7766
// "Server will never send OK packet longer than 16777216 bytes thus limiting
// size of OK packet to be 16777215 bytes"
return data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff
}

func (c *Conn) handleOKPacket(data []byte) (*mysql.Result, error) {
Expand Down Expand Up @@ -336,33 +323,16 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *mysql.Re
}

func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
i := 0
var data []byte

for {
for i := range result.Fields {
rawPkgLen := len(result.RawPkg)
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
if err != nil {
return err
}
data = result.RawPkg[rawPkgLen:]

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

if i != len(result.Fields) {
err = mysql.ErrMalformPacket
}

return err
}

if result.Fields[i] == nil {
result.Fields[i] = &mysql.Field{}
}
Expand All @@ -372,8 +342,30 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
}

result.FieldNames[utils.ByteSliceToString(result.Fields[i].Name)] = i
}

if c.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
// EOF Packet
rawPkgLen := len(result.RawPkg)
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
if err != nil {
return err
}
data = result.RawPkg[rawPkgLen:]

i++
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}
return nil
} else {
return mysql.ErrMalformPacket
}
} else {
return nil
}
}

Expand All @@ -388,15 +380,21 @@ func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) {
}
data = result.RawPkg[rawPkgLen:]

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
// Treat like OK
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
result.AffectedRows = affectedRows
result.InsertId = insertId
c.status = result.Status
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

break
}

Expand Down Expand Up @@ -435,9 +433,16 @@ func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perR
return err
}

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
// Treat like OK
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
result.AffectedRows = affectedRows
result.InsertId = insertId
c.status = result.Status
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
Expand Down
27 changes: 23 additions & 4 deletions client/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,33 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
}

if s.params > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
for range s.params {
if _, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
}
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
if packet, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
} else if !c.isEOFPacket(packet) {
return nil, mysql.ErrMalformPacket
}
}
}

if s.columns > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
// TODO process when CLIENT_CACHE_METADATA enabled
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean MARIADB_CLIENT_CACHE_METADATA? https://mariadb.com/docs/server/reference/clientserver-protocol/mariadb-protocol-differences-with-mysql

I want to understand what's the effect after this PR. Hope it will not break existing program with that capabilities flag.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't currently support those flags, so it won't affect. We have code that reads until eof to skip info that certain options would omit from protocol

for range s.columns {
if _, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
}
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
if packet, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
} else if !c.isEOFPacket(packet) {
return nil, mysql.ErrMalformPacket
}
}
}

Expand Down
Loading