From 052aa667ee437c46a40a6852a9c481b2226991fb Mon Sep 17 00:00:00 2001
From: Chris Busbey <cbusbey@connamara.com>
Date: Thu, 11 Aug 2016 12:54:01 -0500
Subject: [PATCH] KISS on registry, session management

---
 acceptor.go            |   6 +--
 connection.go          |  26 ++++-----
 initiator.go           |   6 +--
 message_router_test.go |   6 ++-
 registry.go            | 116 ++++++++---------------------------------
 session.go             |  29 +++++++++--
 6 files changed, 73 insertions(+), 116 deletions(-)

diff --git a/acceptor.go b/acceptor.go
index 68752896..f23c0982 100644
--- a/acceptor.go
+++ b/acceptor.go
@@ -135,9 +135,9 @@ func (a *Acceptor) run(server net.Listener, connections chan net.Conn) {
 			}
 
 			for sessionID := range a.qualifiedSessionIDs {
-				session, err := lookupSession(sessionID)
-				if err != nil {
-					a.globalLog.OnEventf("Error getting session: %v", err)
+				session, ok := lookupSession(sessionID)
+				if !ok {
+					a.globalLog.OnEventf("Session %v not found", sessionID)
 				} else {
 					go session.disconnect()
 				}
diff --git a/connection.go b/connection.go
index be620e6a..0a4aea5c 100644
--- a/connection.go
+++ b/connection.go
@@ -17,14 +17,12 @@ func handleInitiatorConnection(
 	tlsConfig *tls.Config,
 	stopChan <-chan interface{},
 ) {
-	session := activate(sessID)
-	if session == nil {
+	session, ok := lookupSession(sessID)
+	if !ok {
 		log.OnEventf("Session not found for SessionID: %v", sessID)
 		return
 	}
 
-	defer deactivate(sessID)
-
 	for {
 		msgIn := make(chan fixIn)
 		msgOut := make(chan []byte)
@@ -52,8 +50,12 @@ func handleInitiatorConnection(
 			}
 		}
 
+		if err := session.initiate(msgIn, msgOut); err != nil {
+			log.OnEventf("Failed to initiate: %v", err)
+			goto reconnect
+		}
+
 		go readLoop(newParser(bufio.NewReader(netConn)), msgIn)
-		session.initiate(msgIn, msgOut)
 		writeLoop(netConn, msgOut, log)
 		if err := netConn.Close(); err != nil {
 			log.OnEvent(err.Error())
@@ -132,25 +134,25 @@ func handleAcceptorConnection(netConn net.Conn, qualifiedSessionIDs map[SessionI
 		return
 	}
 
-	session := activate(qualifiedSessID)
-
-	if session == nil {
+	session, ok := lookupSession(qualifiedSessID)
+	if !ok {
 		log.OnEventf("Cannot activate session for incoming message: %v", msg.String())
 		return
 	}
-	defer func() {
-		deactivate(qualifiedSessID)
-	}()
 
 	msgIn := make(chan fixIn)
 	msgOut := make(chan []byte)
 
+	if err := session.accept(msgIn, msgOut); err != nil {
+		log.OnEventf("Unable to accept %v", err.Error())
+		return
+	}
+
 	go func() {
 		msgIn <- fixIn{msgBytes, parser.lastRead}
 		readLoop(parser, msgIn)
 	}()
 
-	session.accept(msgIn, msgOut)
 	writeLoop(netConn, msgOut, log)
 }
 
diff --git a/initiator.go b/initiator.go
index c85e517a..8f528db8 100644
--- a/initiator.go
+++ b/initiator.go
@@ -66,9 +66,9 @@ func (i *Initiator) Stop() {
 	close(i.stopChan)
 
 	for sessionID := range i.sessionSettings {
-		session, err := lookupSession(sessionID)
-		if err != nil {
-			i.globalLog.OnEventf("Error getting session: %v", err)
+		session, ok := lookupSession(sessionID)
+		if !ok {
+			i.globalLog.OnEventf("Session %v not found", sessionID)
 		} else {
 			go session.disconnect()
 		}
diff --git a/message_router_test.go b/message_router_test.go
index caaf8a6d..e008d03f 100644
--- a/message_router_test.go
+++ b/message_router_test.go
@@ -58,7 +58,7 @@ func (suite *MessageRouterTestSuite) givenTargetDefaultApplVerIDForSession(defau
 		sessionID:              sessionID,
 		targetDefaultApplVerID: defaultApplVerID,
 	}
-	sessions.newSession <- s
+	suite.Nil(registerSession(s))
 }
 
 func (suite *MessageRouterTestSuite) givenAFIX42NewOrderSingle() {
@@ -95,6 +95,10 @@ func (suite *MessageRouterTestSuite) resetRouter() {
 
 func (suite *MessageRouterTestSuite) SetupTest() {
 	suite.resetRouter()
+	sessionsLock.Lock()
+	defer sessionsLock.Unlock()
+
+	sessions = make(map[SessionID]*session)
 }
 
 func (suite *MessageRouterTestSuite) TestNoRoute() {
diff --git a/registry.go b/registry.go
index 7e0c909d..87d55d4d 100644
--- a/registry.go
+++ b/registry.go
@@ -1,9 +1,15 @@
 package quickfix
 
 import (
-	"fmt"
+	"errors"
+	"sync"
 )
 
+var sessionsLock sync.RWMutex
+var sessions = make(map[SessionID]*session)
+var errDuplicateSessionID = errors.New("Duplicate SessionID")
+var errUnknownSession = errors.New("Unknown session")
+
 //Messagable is a Message or something that can be converted to a Message
 type Messagable interface {
 	ToMessage() Message
@@ -36,106 +42,30 @@ func Send(m Messagable) (err error) {
 //SendToTarget sends a message based on the sessionID. Convenient for use in FromApp since it provides a session ID for incoming messages
 func SendToTarget(m Messagable, sessionID SessionID) error {
 	msg := m.ToMessage()
-	session, err := lookupSession(sessionID)
-	if err != nil {
-		return err
+	session, ok := lookupSession(sessionID)
+	if !ok {
+		return errUnknownSession
 	}
 
 	return session.queueForSend(msg)
 }
 
-type sessionActivate struct {
-	SessionID
-	reply chan *session
-}
-
-type sessionResource struct {
-	session *session
-	active  bool
-}
-
-type sessionLookupResponse struct {
-	session *session
-	err     error
-}
-
-type sessionLookup struct {
-	SessionID
-	reply chan sessionLookupResponse
-}
-
-type registry struct {
-	newSession chan *session
-	activate   chan sessionActivate
-	deactivate chan SessionID
-	lookup     chan sessionLookup
-}
-
-var sessions *registry
-
-func init() {
-	sessions = new(registry)
-	sessions.newSession = make(chan *session)
-	sessions.activate = make(chan sessionActivate)
-	sessions.deactivate = make(chan SessionID)
-	sessions.lookup = make(chan sessionLookup)
-
-	go sessions.sessionResourceServerLoop()
-}
+func registerSession(s *session) error {
+	sessionsLock.Lock()
+	defer sessionsLock.Unlock()
 
-func activate(sessionID SessionID) *session {
-	response := make(chan *session)
-	sessions.activate <- sessionActivate{sessionID, response}
-	return <-response
-}
+	if _, ok := sessions[s.sessionID]; ok {
+		return errDuplicateSessionID
+	}
 
-func deactivate(sessionID SessionID) {
-	sessions.deactivate <- sessionID
+	sessions[s.sessionID] = s
+	return nil
 }
 
-//lookupSession returns the Session associated with the sessionID.
-func lookupSession(sessionID SessionID) (*session, error) {
-	responseChannel := make(chan sessionLookupResponse)
-	sessions.lookup <- sessionLookup{sessionID, responseChannel}
+func lookupSession(sessionID SessionID) (s *session, ok bool) {
+	sessionsLock.RLock()
+	defer sessionsLock.RUnlock()
 
-	response := <-responseChannel
-	return response.session, response.err
-}
-
-func (r *registry) sessionResourceServerLoop() {
-	sessions := make(map[SessionID]*sessionResource)
-
-	for {
-		select {
-		case session := <-r.newSession:
-			sessions[session.sessionID] = &sessionResource{session, false}
-
-		case deactivatedID := <-r.deactivate:
-			if resource, ok := sessions[deactivatedID]; ok {
-				resource.active = false
-			}
-
-		case lookup := <-r.lookup:
-			if resource, ok := sessions[lookup.SessionID]; ok {
-				lookup.reply <- sessionLookupResponse{resource.session, nil}
-			} else {
-				lookup.reply <- sessionLookupResponse{nil, fmt.Errorf("session not found")}
-			}
-
-		case request := <-r.activate:
-			resource, ok := sessions[request.SessionID]
-
-			switch {
-			case !ok:
-				request.reply <- nil
-
-			case resource.active:
-				request.reply <- nil
-
-			default:
-				resource.active = true
-				request.reply <- resource.session
-			}
-		}
-	}
+	s, ok = sessions[sessionID]
+	return
 }
diff --git a/session.go b/session.go
index 06a6de03..b1ec9ee7 100644
--- a/session.go
+++ b/session.go
@@ -1,6 +1,7 @@
 package quickfix
 
 import (
+	"errors"
 	"fmt"
 	"sync"
 	"time"
@@ -257,9 +258,11 @@ func createSession(
 		return err
 	}
 
+	if err := registerSession(session); err != nil {
+		return err
+	}
 	application.OnCreate(session.sessionID)
 	session.log.OnEvent("Created session")
-	sessions.newSession <- session
 	go session.run()
 
 	return nil
@@ -269,25 +272,33 @@ type connect struct {
 	messageOut    chan<- []byte
 	messageIn     <-chan fixIn
 	initiateLogon bool
+	err           chan<- error
 }
 
 type disconnectReq chan interface{}
 
 //kicks off session as an initiator
-func (s *session) initiate(msgIn <-chan fixIn, msgOut chan<- []byte) {
+func (s *session) initiate(msgIn <-chan fixIn, msgOut chan<- []byte) error {
+	rep := make(chan error)
 	s.admin <- connect{
 		messageOut:    msgOut,
 		messageIn:     msgIn,
 		initiateLogon: true,
+		err:           rep,
 	}
+
+	return <-rep
 }
 
 //kicks off session as an acceptor
-func (s *session) accept(msgIn chan fixIn, msgOut chan []byte) {
+func (s *session) accept(msgIn chan fixIn, msgOut chan []byte) error {
+	rep := make(chan error)
 	s.admin <- connect{
 		messageOut: msgOut,
 		messageIn:  msgIn,
+		err:        rep,
 	}
+	return <-rep
 }
 
 //blocks until the session has disconnected
@@ -809,8 +820,16 @@ func (s *session) onAdmin(msg interface{}) {
 	switch msg := msg.(type) {
 
 	case connect:
+		defer func() {
+			if msg.err != nil {
+				close(msg.err)
+			}
+		}()
+
 		if s.IsConnected() {
-			s.log.OnEvent("Already connected")
+			if msg.err != nil {
+				msg.err <- errors.New("Already connected")
+			}
 			return
 		}
 
@@ -820,6 +839,8 @@ func (s *session) onAdmin(msg interface{}) {
 		s.initiateLogon = msg.initiateLogon
 		s.Start(s)
 
+		return
+
 	case disconnectReq:
 		s.Stop(s, msg)
 	}
-- 
GitLab