@@ -130,45 +130,38 @@ def self.value2net(v)
130130 # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
131131
132132 # make socket connection to server.
133- # === Argument
134- # host :: [String] if "localhost" or "" nil then use UNIXSocket. Otherwise use TCPSocket
135- # port :: [Integer] port number using by TCPSocket
136- # socket :: [String] socket file name using by UNIXSocket
137- # conn_timeout :: [Integer] connect timeout (sec).
138- # read_timeout :: [Integer] read timeout (sec).
139- # write_timeout :: [Integer] write timeout (sec).
140- # local_infile :: [String] local infile path
141- # ssl_mode :: [Integer]
142- # get_server_public_key :: [Boolean]
143- # === Exception
144- # [ClientError] :: connection timeout
145- def initialize ( host , port , socket , conn_timeout , read_timeout , write_timeout , local_infile , ssl_mode , get_server_public_key )
133+ # @param host [String] if "localhost" or "" or nil then use UNIX socket. Otherwise use TCP socket
134+ # @param port [Integer] port number using by TCP socket
135+ # @param socket [String] socket file name using by UNIX socket
136+ # @param [Hash] opts
137+ # @option opts :conn_timeout [Integer] connect timeout (sec).
138+ # @option opts :read_timeout [Integer] read timeout (sec).
139+ # @option opts :write_timeout [Integer] write timeout (sec).
140+ # @option opts :local_infile [String] local infile path
141+ # @option opts :get_server_public_key [Boolean]
142+ # @raise [ClientError] connection timeout
143+ def initialize ( host , port , socket , opts )
144+ @opts = opts
146145 @insert_id = 0
147146 @warning_count = 0
148147 @gc_stmt_queue = [ ] # stmt id list which GC destroy.
149148 set_state :INIT
150- @read_timeout = read_timeout
151- @write_timeout = write_timeout
152- @local_infile = local_infile
153- @ssl_mode = ssl_mode
154- @get_server_public_key = get_server_public_key
149+ @get_server_public_key = @opts [ :get_server_public_key ]
155150 begin
156- Timeout . timeout conn_timeout do
157- if host . nil? or host . empty? or host == "localhost"
158- socket ||= ENV [ "MYSQL_UNIX_PORT" ] || MYSQL_UNIX_PORT
159- @sock = UNIXSocket . new socket
160- else
161- port ||= ENV [ "MYSQL_TCP_PORT" ] || ( Socket . getservbyname ( "mysql" , "tcp" ) rescue MYSQL_TCP_PORT )
162- @sock = TCPSocket . new host , port
163- end
151+ if host . nil? or host . empty? or host == "localhost"
152+ socket ||= ENV [ "MYSQL_UNIX_PORT" ] || MYSQL_UNIX_PORT
153+ @socket = Socket . unix ( socket )
154+ else
155+ port ||= ENV [ "MYSQL_TCP_PORT" ] || ( Socket . getservbyname ( "mysql" , "tcp" ) rescue MYSQL_TCP_PORT )
156+ @socket = Socket . tcp ( host , port , connect_timeout : @opts [ :connect_timeout ] )
164157 end
165- rescue Timeout :: Error
158+ rescue Errno :: ETIMEDOUT
166159 raise ClientError , "connection timeout"
167160 end
168161 end
169162
170163 def close
171- @sock . close
164+ @socket . close
172165 end
173166
174167 # initial negotiate and authenticate.
@@ -190,7 +183,7 @@ def authenticate(user, passwd, db, flag, charset)
190183 @server_capabilities = init_packet . server_capabilities
191184 @thread_id = init_packet . thread_id
192185 @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
193- @client_flags |= CLIENT_LOCAL_FILES if @local_infile
186+ @client_flags |= CLIENT_LOCAL_FILES if @opts [ : local_infile]
194187 @client_flags |= CLIENT_CONNECT_WITH_DB if db
195188 @client_flags |= flag
196189 @charset = charset
@@ -204,28 +197,28 @@ def authenticate(user, passwd, db, flag, charset)
204197 end
205198
206199 def enable_ssl
207- case @ssl_mode
200+ case @opts [ : ssl_mode]
208201 when SSL_MODE_DISABLED
209202 return
210203 when SSL_MODE_PREFERRED
211- return if @sock . is_a? UNIXSocket
204+ return if @socket . local_address . unix?
212205 return if @server_capabilities & CLIENT_SSL == 0
213206 when SSL_MODE_REQUIRED
214207 if @server_capabilities & CLIENT_SSL == 0
215208 raise ClientError ::SslConnectionError , "SSL is required but the server doesn't support it"
216209 end
217210 else
218- raise ClientError , "ssl_mode #{ @ssl_mode } is not supported"
211+ raise ClientError , "ssl_mode #{ @opts [ : ssl_mode] } is not supported"
219212 end
220213 begin
221214 @client_flags |= CLIENT_SSL
222215 write Protocol ::TlsAuthenticationPacket . serialize ( @client_flags , 1024 **3 , @charset . number )
223- @sock = OpenSSL ::SSL ::SSLSocket . new ( @sock )
224- @sock . sync_close = true
225- @sock . connect
216+ @socket = OpenSSL ::SSL ::SSLSocket . new ( @socket )
217+ @socket . sync_close = true
218+ @socket . connect
226219 rescue => e
227220 @client_flags &= ~CLIENT_SSL
228- return if @ssl_mode == SSL_MODE_PREFERRED
221+ return if @opts [ : ssl_mode] == SSL_MODE_PREFERRED
229222 raise e
230223 end
231224 end
@@ -282,7 +275,7 @@ def get_result
282275 # send local file to server
283276 def send_local_file ( filename )
284277 filename = File . absolute_path ( filename )
285- if filename . start_with? @local_infile
278+ if filename . start_with? @opts [ : local_infile]
286279 File . open ( filename ) { |f | write f }
287280 else
288281 raise ClientError ::LoadDataLocalInfileRejected , 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
@@ -482,7 +475,7 @@ def check_state(st)
482475
483476 def set_state ( st )
484477 @state = st
485- if st == :READY
478+ if st == :READY && ! @gc_stmt_queue . empty?
486479 gc_disabled = GC . disable
487480 begin
488481 while st = @gc_stmt_queue . shift
@@ -518,14 +511,14 @@ def read
518511 data = ''
519512 len = nil
520513 begin
521- Timeout . timeout @read_timeout do
522- header = @sock . read ( 4 )
514+ Timeout . timeout @opts [ : read_timeout] do
515+ header = @socket . read ( 4 )
523516 raise EOFError unless header && header . length == 4
524517 len1 , len2 , seq = header . unpack ( "CvC" )
525518 len = ( len2 << 8 ) + len1
526519 raise ProtocolError , "invalid packet: sequence number mismatch(#{ seq } != #{ @seq } (expected))" if @seq != seq
527520 @seq = ( @seq + 1 ) % 256
528- ret = @sock . read ( len )
521+ ret = @socket . read ( len )
529522 raise EOFError unless ret && ret . length == len
530523 data . concat ret
531524 end
@@ -558,25 +551,21 @@ def read
558551 # data :: [String / IO] packet data. If data is nil, write empty packet.
559552 def write ( data )
560553 begin
561- @sock . sync = false
562- if data . nil?
563- Timeout . timeout @write_timeout do
564- @sock . write [ 0 , 0 , @seq ] . pack ( "CvC" )
565- end
566- @seq = ( @seq + 1 ) % 256
567- else
568- data = StringIO . new data if data . is_a? String
569- while d = data . read ( MAX_PACKET_LENGTH )
570- Timeout . timeout @write_timeout do
571- @sock . write [ d . length %256 , d . length /256 , @seq ] . pack ( "CvC" )
572- @sock . write d
573- end
554+ Timeout . timeout @opts [ :write_timeout ] do
555+ @socket . sync = false
556+ if data . nil?
557+ @socket . write [ 0 , 0 , @seq ] . pack ( "CvC" )
574558 @seq = ( @seq + 1 ) % 256
559+ else
560+ data = StringIO . new data if data . is_a? String
561+ while d = data . read ( MAX_PACKET_LENGTH )
562+ @socket . write [ d . length %256 , d . length /256 , @seq ] . pack ( "CvC" )
563+ @socket . write d
564+ @seq = ( @seq + 1 ) % 256
565+ end
575566 end
576- end
577- @sock . sync = true
578- Timeout . timeout @write_timeout do
579- @sock . flush
567+ @socket . sync = true
568+ @socket . flush
580569 end
581570 rescue Errno ::EPIPE
582571 raise ClientError ::ServerGoneError , 'MySQL server has gone away'
0 commit comments