From 052aa667ee437c46a40a6852a9c481b2226991fb Mon Sep 17 00:00:00 2001 From: Chris Busbey <cbusbey@connamara.com> Date: Thu, 11 Aug 2016 12:54:01 -0500 Subject: [PATCH] KISS on registry, session management --- acceptor.go | 6 +-- connection.go | 26 ++++----- initiator.go | 6 +-- message_router_test.go | 6 ++- registry.go | 116 ++++++++--------------------------------- session.go | 29 +++++++++-- 6 files changed, 73 insertions(+), 116 deletions(-) diff --git a/acceptor.go b/acceptor.go index 68752896..f23c0982 100644 --- a/acceptor.go +++ b/acceptor.go @@ -135,9 +135,9 @@ func (a *Acceptor) run(server net.Listener, connections chan net.Conn) { } for sessionID := range a.qualifiedSessionIDs { - session, err := lookupSession(sessionID) - if err != nil { - a.globalLog.OnEventf("Error getting session: %v", err) + session, ok := lookupSession(sessionID) + if !ok { + a.globalLog.OnEventf("Session %v not found", sessionID) } else { go session.disconnect() } diff --git a/connection.go b/connection.go index be620e6a..0a4aea5c 100644 --- a/connection.go +++ b/connection.go @@ -17,14 +17,12 @@ func handleInitiatorConnection( tlsConfig *tls.Config, stopChan <-chan interface{}, ) { - session := activate(sessID) - if session == nil { + session, ok := lookupSession(sessID) + if !ok { log.OnEventf("Session not found for SessionID: %v", sessID) return } - defer deactivate(sessID) - for { msgIn := make(chan fixIn) msgOut := make(chan []byte) @@ -52,8 +50,12 @@ func handleInitiatorConnection( } } + if err := session.initiate(msgIn, msgOut); err != nil { + log.OnEventf("Failed to initiate: %v", err) + goto reconnect + } + go readLoop(newParser(bufio.NewReader(netConn)), msgIn) - session.initiate(msgIn, msgOut) writeLoop(netConn, msgOut, log) if err := netConn.Close(); err != nil { log.OnEvent(err.Error()) @@ -132,25 +134,25 @@ func handleAcceptorConnection(netConn net.Conn, qualifiedSessionIDs map[SessionI return } - session := activate(qualifiedSessID) - - if session == nil { + session, ok := lookupSession(qualifiedSessID) + if !ok { log.OnEventf("Cannot activate session for incoming message: %v", msg.String()) return } - defer func() { - deactivate(qualifiedSessID) - }() msgIn := make(chan fixIn) msgOut := make(chan []byte) + if err := session.accept(msgIn, msgOut); err != nil { + log.OnEventf("Unable to accept %v", err.Error()) + return + } + go func() { msgIn <- fixIn{msgBytes, parser.lastRead} readLoop(parser, msgIn) }() - session.accept(msgIn, msgOut) writeLoop(netConn, msgOut, log) } diff --git a/initiator.go b/initiator.go index c85e517a..8f528db8 100644 --- a/initiator.go +++ b/initiator.go @@ -66,9 +66,9 @@ func (i *Initiator) Stop() { close(i.stopChan) for sessionID := range i.sessionSettings { - session, err := lookupSession(sessionID) - if err != nil { - i.globalLog.OnEventf("Error getting session: %v", err) + session, ok := lookupSession(sessionID) + if !ok { + i.globalLog.OnEventf("Session %v not found", sessionID) } else { go session.disconnect() } diff --git a/message_router_test.go b/message_router_test.go index caaf8a6d..e008d03f 100644 --- a/message_router_test.go +++ b/message_router_test.go @@ -58,7 +58,7 @@ func (suite *MessageRouterTestSuite) givenTargetDefaultApplVerIDForSession(defau sessionID: sessionID, targetDefaultApplVerID: defaultApplVerID, } - sessions.newSession <- s + suite.Nil(registerSession(s)) } func (suite *MessageRouterTestSuite) givenAFIX42NewOrderSingle() { @@ -95,6 +95,10 @@ func (suite *MessageRouterTestSuite) resetRouter() { func (suite *MessageRouterTestSuite) SetupTest() { suite.resetRouter() + sessionsLock.Lock() + defer sessionsLock.Unlock() + + sessions = make(map[SessionID]*session) } func (suite *MessageRouterTestSuite) TestNoRoute() { diff --git a/registry.go b/registry.go index 7e0c909d..87d55d4d 100644 --- a/registry.go +++ b/registry.go @@ -1,9 +1,15 @@ package quickfix import ( - "fmt" + "errors" + "sync" ) +var sessionsLock sync.RWMutex +var sessions = make(map[SessionID]*session) +var errDuplicateSessionID = errors.New("Duplicate SessionID") +var errUnknownSession = errors.New("Unknown session") + //Messagable is a Message or something that can be converted to a Message type Messagable interface { ToMessage() Message @@ -36,106 +42,30 @@ func Send(m Messagable) (err error) { //SendToTarget sends a message based on the sessionID. Convenient for use in FromApp since it provides a session ID for incoming messages func SendToTarget(m Messagable, sessionID SessionID) error { msg := m.ToMessage() - session, err := lookupSession(sessionID) - if err != nil { - return err + session, ok := lookupSession(sessionID) + if !ok { + return errUnknownSession } return session.queueForSend(msg) } -type sessionActivate struct { - SessionID - reply chan *session -} - -type sessionResource struct { - session *session - active bool -} - -type sessionLookupResponse struct { - session *session - err error -} - -type sessionLookup struct { - SessionID - reply chan sessionLookupResponse -} - -type registry struct { - newSession chan *session - activate chan sessionActivate - deactivate chan SessionID - lookup chan sessionLookup -} - -var sessions *registry - -func init() { - sessions = new(registry) - sessions.newSession = make(chan *session) - sessions.activate = make(chan sessionActivate) - sessions.deactivate = make(chan SessionID) - sessions.lookup = make(chan sessionLookup) - - go sessions.sessionResourceServerLoop() -} +func registerSession(s *session) error { + sessionsLock.Lock() + defer sessionsLock.Unlock() -func activate(sessionID SessionID) *session { - response := make(chan *session) - sessions.activate <- sessionActivate{sessionID, response} - return <-response -} + if _, ok := sessions[s.sessionID]; ok { + return errDuplicateSessionID + } -func deactivate(sessionID SessionID) { - sessions.deactivate <- sessionID + sessions[s.sessionID] = s + return nil } -//lookupSession returns the Session associated with the sessionID. -func lookupSession(sessionID SessionID) (*session, error) { - responseChannel := make(chan sessionLookupResponse) - sessions.lookup <- sessionLookup{sessionID, responseChannel} +func lookupSession(sessionID SessionID) (s *session, ok bool) { + sessionsLock.RLock() + defer sessionsLock.RUnlock() - response := <-responseChannel - return response.session, response.err -} - -func (r *registry) sessionResourceServerLoop() { - sessions := make(map[SessionID]*sessionResource) - - for { - select { - case session := <-r.newSession: - sessions[session.sessionID] = &sessionResource{session, false} - - case deactivatedID := <-r.deactivate: - if resource, ok := sessions[deactivatedID]; ok { - resource.active = false - } - - case lookup := <-r.lookup: - if resource, ok := sessions[lookup.SessionID]; ok { - lookup.reply <- sessionLookupResponse{resource.session, nil} - } else { - lookup.reply <- sessionLookupResponse{nil, fmt.Errorf("session not found")} - } - - case request := <-r.activate: - resource, ok := sessions[request.SessionID] - - switch { - case !ok: - request.reply <- nil - - case resource.active: - request.reply <- nil - - default: - resource.active = true - request.reply <- resource.session - } - } - } + s, ok = sessions[sessionID] + return } diff --git a/session.go b/session.go index 06a6de03..b1ec9ee7 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package quickfix import ( + "errors" "fmt" "sync" "time" @@ -257,9 +258,11 @@ func createSession( return err } + if err := registerSession(session); err != nil { + return err + } application.OnCreate(session.sessionID) session.log.OnEvent("Created session") - sessions.newSession <- session go session.run() return nil @@ -269,25 +272,33 @@ type connect struct { messageOut chan<- []byte messageIn <-chan fixIn initiateLogon bool + err chan<- error } type disconnectReq chan interface{} //kicks off session as an initiator -func (s *session) initiate(msgIn <-chan fixIn, msgOut chan<- []byte) { +func (s *session) initiate(msgIn <-chan fixIn, msgOut chan<- []byte) error { + rep := make(chan error) s.admin <- connect{ messageOut: msgOut, messageIn: msgIn, initiateLogon: true, + err: rep, } + + return <-rep } //kicks off session as an acceptor -func (s *session) accept(msgIn chan fixIn, msgOut chan []byte) { +func (s *session) accept(msgIn chan fixIn, msgOut chan []byte) error { + rep := make(chan error) s.admin <- connect{ messageOut: msgOut, messageIn: msgIn, + err: rep, } + return <-rep } //blocks until the session has disconnected @@ -809,8 +820,16 @@ func (s *session) onAdmin(msg interface{}) { switch msg := msg.(type) { case connect: + defer func() { + if msg.err != nil { + close(msg.err) + } + }() + if s.IsConnected() { - s.log.OnEvent("Already connected") + if msg.err != nil { + msg.err <- errors.New("Already connected") + } return } @@ -820,6 +839,8 @@ func (s *session) onAdmin(msg interface{}) { s.initiateLogon = msg.initiateLogon s.Start(s) + return + case disconnectReq: s.Stop(s, msg) } -- GitLab