Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package quickfix
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"io"
"net"
Expand Down Expand Up @@ -361,6 +362,7 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {
a.sessionAddr.Store(sessID, netConn.RemoteAddr())
msgIn := make(chan fixIn)
msgOut := make(chan []byte)
ctx := context.Background()

if err := session.connect(msgIn, msgOut); err != nil {
a.globalLog.OnEventf("Unable to accept session %v connection: %v", sessID, err.Error())
Expand All @@ -369,10 +371,10 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {

go func() {
msgIn <- fixIn{msgBytes, parser.lastRead}
readLoop(parser, msgIn, a.globalLog)
readLoop(ctx, parser, msgIn, a.globalLog)
}()

writeLoop(netConn, msgOut, a.globalLog)
writeLoop(ctx, netConn, msgOut, a.globalLog)
}

func (a *Acceptor) dynamicSessionsLoop() {
Expand Down
21 changes: 18 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@

package quickfix

import "io"
import (
"context"
"io"
)

func writeLoop(connection io.Writer, messageOut chan []byte, log Log) {
func writeLoop(ctx context.Context, connection io.Writer, messageOut chan []byte, log Log) {
for {
select {
case <-ctx.Done():
return
default:
}

msg, ok := <-messageOut
if !ok {
return
Expand All @@ -30,10 +39,16 @@ func writeLoop(connection io.Writer, messageOut chan []byte, log Log) {
}
}

func readLoop(parser *parser, msgIn chan fixIn, log Log) {
func readLoop(ctx context.Context, parser *parser, msgIn chan fixIn, log Log) {
defer close(msgIn)

for {
select {
case <-ctx.Done():
return
default:
}

msg, err := parser.ReadMessage()
if err != nil {
log.OnEvent(err.Error())
Expand Down
41 changes: 39 additions & 2 deletions connection_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package quickfix

import (
"bytes"
"context"
"strings"
"testing"
)

func TestWriteLoop(t *testing.T) {
ctx := context.Background()
writer := bytes.NewBufferString("")
msgOut := make(chan []byte)

Expand All @@ -31,7 +33,7 @@ func TestWriteLoop(t *testing.T) {
msgOut <- []byte("test msg 3")
close(msgOut)
}()
writeLoop(writer, msgOut, nullLog{})
writeLoop(ctx, writer, msgOut, nullLog{})

expected := "test msg 1 test msg 2 test msg 3"

Expand All @@ -40,12 +42,32 @@ func TestWriteLoop(t *testing.T) {
}
}

func TestWriteLoopCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
writer := bytes.NewBufferString("")
msgOut := make(chan []byte)

go func() {
msgOut <- []byte("test msg 1")
cancel()
}()
writeLoop(ctx, writer, msgOut, nullLog{})

expected := "test msg 1"

if writer.String() != expected {
t.Errorf("expected %v got %v", expected, writer.String())
}
}

func TestReadLoop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
msgIn := make(chan fixIn)
stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103"

parser := newParser(strings.NewReader(stream))
go readLoop(parser, msgIn, nullLog{})
go readLoop(ctx, parser, msgIn, nullLog{})

var tests = []struct {
expectedMsg string
Expand All @@ -71,3 +93,18 @@ func TestReadLoop(t *testing.T) {
}
}
}

func TestReadLoopCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
msgIn := make(chan fixIn)
stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103"

parser := newParser(strings.NewReader(stream))

cancel()
go readLoop(ctx, parser, msgIn, nullLog{})
_, ok := <-msgIn
if ok {
t.Error("Channel should be closed on context cancel")
}
}
26 changes: 15 additions & 11 deletions initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,17 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
return
}

ctx, cancel := context.WithCancel(context.Background())
ctx := context.Background()
dialCtx, dialCancel := context.WithCancel(ctx)
readWriteCtx, readWriteCancel := context.WithCancel(ctx)

// We start a goroutine in order to be able to cancel the dialer mid-connection
// on receiving a stop signal to stop the initiator.
go func() {
select {
case <-i.stopChan:
cancel()
dialCancel()
readWriteCancel()
case <-ctx.Done():
return
}
Expand All @@ -183,7 +186,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

netConn, err := dialer.DialContext(ctx, "tcp", address)
netConn, err := dialer.DialContext(dialCtx, "tcp", address)
if err != nil {
session.log.OnEventf("Failed to connect: %v", err)
goto reconnect
Expand All @@ -207,24 +210,25 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di

msgIn = make(chan fixIn)
msgOut = make(chan []byte)
if err := session.connect(msgIn, msgOut); err != nil {
session.log.OnEventf("Failed to initiate: %v", err)
goto reconnect
}

go readLoop(newParser(bufio.NewReader(netConn)), msgIn, session.log)
go readLoop(readWriteCtx, newParser(bufio.NewReader(netConn)), msgIn, session.log)
disconnected = make(chan interface{})
go func() {
writeLoop(netConn, msgOut, session.log)
writeLoop(readWriteCtx, netConn, msgOut, session.log)
if err := netConn.Close(); err != nil {
session.log.OnEvent(err.Error())
}
close(disconnected)
}()

if err := session.connect(msgIn, msgOut); err != nil {
session.log.OnEventf("Failed to initiate: %v", err)
goto reconnect
}

// This ensures we properly cleanup the goroutine and context used for
// dial cancelation after successful connection.
cancel()
dialCancel()

select {
case <-disconnected:
Expand All @@ -233,7 +237,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
}

reconnect:
cancel()
dialCancel()

connectionAttempt++
session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval)
Expand Down
177 changes: 177 additions & 0 deletions initiator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package quickfix

import (
"context"
"errors"
"net"
"testing"
"time"

"github.com/quickfixgo/quickfix/config"
)

func TestNewInitiatorKeepReconnectingAfterLogonError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

logonCount := 0
app := &mockApplication{}
storeFactory := &mockMessageStoreFactory{saveMessageAndIncrError: errDBError}
logFactory := &mockLogFactory{
onEvent: func(s string) {
if s == "Sending logon request" {
logonCount++
if logonCount >= 5 {
cancel()
}
}
},
}

settings := NewSettings()
sessionSettings := newSession()
sessionID, err := settings.AddSession(sessionSettings)
if err != nil {
t.Fatalf("Expected no error adding session, got %v", err)
}

initiator, err := NewInitiator(app, storeFactory, settings, logFactory)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

s, ok := initiator.sessions[sessionID]
if !ok {
t.Fatal("Expected session to be created")
}

initiator.stopChan = make(chan interface{})
go initiator.handleConnection(s, nil, &mockDialer{})

select {
case <-ctx.Done():
initiator.Stop()
return
case <-time.After(10 * time.Second):
t.Error("retry stopped after logon error")
return
}
}

func newSession() *SessionSettings {
sessionSettings := NewSessionSettings()
sessionSettings.Set(config.BeginString, "FIX.4.4")
sessionSettings.Set(config.SenderCompID, "X")
sessionSettings.Set(config.TargetCompID, "X")
sessionSettings.Set(config.HeartBtInt, "30")
sessionSettings.Set(config.SocketConnectHost, "localhost")
sessionSettings.Set(config.SocketConnectPort, "9878")
sessionSettings.Set(config.ReconnectInterval, "1")
return sessionSettings
}

type mockApplication struct{}

func (m *mockApplication) OnCreate(_ SessionID) {}
func (m *mockApplication) OnLogon(_ SessionID) {}
func (m *mockApplication) OnLogout(_ SessionID) {}
func (m *mockApplication) ToAdmin(_ *Message, _ SessionID) {}
func (m *mockApplication) ToApp(_ *Message, _ SessionID) error { return nil }
func (m *mockApplication) FromAdmin(_ *Message, _ SessionID) MessageRejectError {
return nil
}
func (m *mockApplication) FromApp(_ *Message, _ SessionID) MessageRejectError {
return nil
}

type mockMessageStoreFactory struct {
saveMessageAndIncrError error
}

func (m *mockMessageStoreFactory) Create(_ SessionID) (MessageStore, error) {
return &mockMessageStore{saveMessageAndIncrError: m.saveMessageAndIncrError}, nil
}

var errDBError = errors.New("db error")

type mockMessageStore struct {
saveMessageAndIncrError error
}

func (m *mockMessageStore) NextSenderMsgSeqNum() int { return 1 }
func (m *mockMessageStore) NextTargetMsgSeqNum() int { return 1 }
func (m *mockMessageStore) IncrSenderMsgSeqNum() error { return nil }
func (m *mockMessageStore) IncrTargetMsgSeqNum() error { return nil }
func (m *mockMessageStore) SetNextSenderMsgSeqNum(_ int) error { return nil }
func (m *mockMessageStore) SetNextTargetMsgSeqNum(_ int) error { return nil }
func (m *mockMessageStore) CreationTime() time.Time { return time.Now() }
func (m *mockMessageStore) SaveMessage(_ int, _ []byte) error { return nil }
func (m *mockMessageStore) GetMessages(_, _ int) ([][]byte, error) { return nil, nil }
func (m *mockMessageStore) Refresh() error { return nil }
func (m *mockMessageStore) Reset() error { return nil }
func (m *mockMessageStore) Close() error { return nil }
func (m *mockMessageStore) IncrNextSenderMsgSeqNum() error { return nil }
func (m *mockMessageStore) IncrNextTargetMsgSeqNum() error { return nil }
func (m *mockMessageStore) IterateMessages(int, int, func([]byte) error) error { return nil }
func (m *mockMessageStore) SaveMessageAndIncrNextSenderMsgSeqNum(_ int, _ []byte) error {
return m.saveMessageAndIncrError
}
func (m *mockMessageStore) SetCreationTime(time.Time) {}

type mockLogFactory struct {
shouldFail bool
onEvent func(string)
}

func (m *mockLogFactory) Create() (Log, error) {
if m.shouldFail {
return nil, errors.New("log factory error")
}
return &mockLog{
onEvent: m.onEvent,
}, nil
}

func (m *mockLogFactory) CreateSessionLog(_ SessionID) (Log, error) {
return &mockLog{
onEvent: m.onEvent,
}, nil
}

type mockDialer struct{}

type mockAddr struct {
network string
address string
}

func (m *mockAddr) Network() string { return m.network }
func (m *mockAddr) String() string { return m.address }

type mockConn struct{}

func (m *mockConn) Read(_ []byte) (n int, err error) { return 0, nil }
func (m *mockConn) Write(_ []byte) (n int, err error) { return 0, nil }
func (m *mockConn) Close() error { return nil }
func (m *mockConn) LocalAddr() net.Addr { return &mockAddr{network: "tcp", address: "127.0.0.1:8080"} }
func (m *mockConn) RemoteAddr() net.Addr { return &mockAddr{network: "tcp", address: "127.0.0.1:9090"} }
func (m *mockConn) SetDeadline(_ time.Time) error { return nil }
func (m *mockConn) SetReadDeadline(_ time.Time) error { return nil }
func (m *mockConn) SetWriteDeadline(_ time.Time) error { return nil }

func (m *mockDialer) DialContext(_ context.Context, _, _ string) (net.Conn, error) {
return &mockConn{}, nil
}

type mockLog struct {
onEvent func(string)
}

func (m *mockLog) OnIncoming(_ []byte) {}
func (m *mockLog) OnOutgoing(_ []byte) {}
func (m *mockLog) OnEvent(s string) {
if m.onEvent != nil {
m.onEvent(s)
}
}
func (m *mockLog) OnEventf(_ string, _ ...interface{}) {}
Loading
Loading