From f04f037975b7e7188cb8b214f00261a2a9024dd2 Mon Sep 17 00:00:00 2001
From: Christopher Ryan <clryan22@gmail.com>
Date: Thu, 15 Mar 2018 15:18:21 -0700
Subject: [PATCH] Incorporate gorm DB

Addresses: https://github.com/quickfixgo/quickfix/issues/284. Provides the abstraction needed, and compatibility for all dialects listed here: http://gorm.io/docs/dialects.html
---
 Gopkg.lock       | 43 +++++++++++++--------
 sqlstore.go      | 99 ++++++++++++++++++++++--------------------------
 sqlstore_test.go |  6 +--
 store_test.go    |  6 +--
 4 files changed, 78 insertions(+), 76 deletions(-)

diff --git a/Gopkg.lock b/Gopkg.lock
index 156d40f2..0d17c096 100644
--- a/Gopkg.lock
+++ b/Gopkg.lock
@@ -7,11 +7,23 @@
   revision = "346938d642f2ec3594ed81d874461961cd0faa76"
   version = "v1.1.0"
 
+[[projects]]
+  name = "github.com/jinzhu/gorm"
+  packages = ["."]
+  revision = "6ed508ec6a4ecb3531899a69cbc746ccf65a4166"
+  version = "v1.9.1"
+
+[[projects]]
+  branch = "master"
+  name = "github.com/jinzhu/inflection"
+  packages = ["."]
+  revision = "04140366298a54a039076d798123ffa108fff46c"
+
 [[projects]]
   name = "github.com/mattn/go-sqlite3"
   packages = ["."]
-  revision = "ca5e3819723d8eeaf170ad510e7da1d6d2e94a08"
-  version = "v1.2.0"
+  revision = "6c771bb9887719704b210e87e934f08be014bdb1"
+  version = "v1.6.0"
 
 [[projects]]
   name = "github.com/pmezard/go-difflib"
@@ -20,32 +32,31 @@
   version = "v1.0.0"
 
 [[projects]]
-  branch = "master"
   name = "github.com/shopspring/decimal"
   packages = ["."]
-  revision = "aed1bfe463fa3c9cc268d60dcc1491db613bff7e"
+  revision = "69b3a8ad1f5f2c8bd855cb6506d18593064a346b"
+  version = "1.0.1"
 
 [[projects]]
-  branch = "master"
   name = "github.com/stretchr/objx"
   packages = ["."]
-  revision = "1a9d0bb9f541897e62256577b352fdbc1fb4fd94"
+  revision = "facf9a85c22f48d2f52f2380e4efce1768749a89"
+  version = "v0.1"
 
 [[projects]]
   name = "github.com/stretchr/testify"
-  packages = ["assert","mock","require","suite"]
-  revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0"
-  version = "v1.1.4"
-
-[[projects]]
-  branch = "master"
-  name = "golang.org/x/net"
-  packages = ["context"]
-  revision = "a04bdaca5b32abe1c069418fb7088ae607de5bd0"
+  packages = [
+    "assert",
+    "mock",
+    "require",
+    "suite"
+  ]
+  revision = "12b6f73e6084dad08a7c6e575284b177ecafbc71"
+  version = "v1.2.1"
 
 [solve-meta]
   analyzer-name = "dep"
   analyzer-version = 1
-  inputs-digest = "6efc9f467166be5af0c9b9f4b98d7860ba12b50ce641a2fba765049bd1ea4f27"
+  inputs-digest = "2e6225023090014b16e956b72d63654d37f1ad0a58e218933eed1c9831266522"
   solver-name = "gps-cdcl"
   solver-version = 1
diff --git a/sqlstore.go b/sqlstore.go
index b16b0a74..6b9058b5 100644
--- a/sqlstore.go
+++ b/sqlstore.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/jinzhu/gorm"
 	"github.com/quickfixgo/quickfix/config"
 )
 
@@ -18,7 +19,7 @@ type sqlStore struct {
 	sqlDriver          string
 	sqlDataSourceName  string
 	sqlConnMaxLifetime time.Duration
-	db                 *sql.DB
+	db                 *gorm.DB
 }
 
 // NewSQLStoreFactory returns a sql-based implementation of MessageStoreFactory
@@ -60,12 +61,12 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn
 	}
 	store.cache.Reset()
 
-	if store.db, err = sql.Open(store.sqlDriver, store.sqlDataSourceName); err != nil {
+	if store.db, err = gorm.Open(store.sqlDriver, store.sqlDataSourceName); err != nil {
 		return nil, err
 	}
-	store.db.SetConnMaxLifetime(store.sqlConnMaxLifetime)
+	store.db.DB().SetConnMaxLifetime(store.sqlConnMaxLifetime)
 
-	if err = store.db.Ping(); err != nil { // ensure immediate connection
+	if err = store.db.DB().Ping(); err != nil { // ensure immediate connection
 		return nil, err
 	}
 	if err = store.populateCache(); err != nil {
@@ -76,16 +77,15 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn
 }
 
 // Reset deletes the store records and sets the seqnums back to 1
-func (store *sqlStore) Reset() error {
+func (store *sqlStore) Reset() (err error) {
 	s := store.sessionID
-	_, err := store.db.Exec(`DELETE FROM messages
-		WHERE beginstring=? AND session_qualifier=?
-		AND sendercompid=? AND sendersubid=? AND senderlocid=?
-		AND targetcompid=? AND targetsubid=? AND targetlocid=?`,
+	if err = store.db.Exec(`DELETE FROM messages
+		WHERE beginstring = ? AND session_qualifier = ?
+		AND sendercompid = ? AND sendersubid = ? AND senderlocid = ?
+		AND targetcompid = ? AND targetsubid = ? AND targetlocid = ?`,
 		s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
-	if err != nil {
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Error; err != nil {
 		return err
 	}
 
@@ -93,22 +93,20 @@ func (store *sqlStore) Reset() error {
 		return err
 	}
 
-	_, err = store.db.Exec(`UPDATE sessions
-		SET creation_time=?, incoming_seqnum=?, outgoing_seqnum=?
-		WHERE beginstring=? AND session_qualifier=?
-		AND sendercompid=? AND sendersubid=? AND senderlocid=?
-		AND targetcompid=? AND targetsubid=? AND targetlocid=?`,
+	return store.db.Exec(`UPDATE sessions
+		SET creation_time = ?, incoming_seqnum = ?, outgoing_seqnum = ?
+		WHERE beginstring = ? AND session_qualifier = ?
+		AND sendercompid= ? AND sendersubid = ? AND senderlocid = ?
+		AND targetcompid = ? AND targetsubid = ? AND targetlocid = ?`,
 		store.cache.CreationTime(), store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(),
 		s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
-
-	return err
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Error
 }
 
 // Refresh reloads the store from the database
-func (store *sqlStore) Refresh() error {
-	if err := store.cache.Reset(); err != nil {
+func (store *sqlStore) Refresh() (err error) {
+	if err = store.cache.Reset(); err != nil {
 		return err
 	}
 	return store.populateCache()
@@ -118,17 +116,16 @@ func (store *sqlStore) populateCache() (err error) {
 	s := store.sessionID
 	var creationTime time.Time
 	var incomingSeqNum, outgoingSeqNum int
-	row := store.db.QueryRow(`SELECT creation_time, incoming_seqnum, outgoing_seqnum
-	  FROM sessions
-		WHERE beginstring=? AND session_qualifier=?
-		AND sendercompid=? AND sendersubid=? AND senderlocid=?
-		AND targetcompid=? AND targetsubid=? AND targetlocid=?`,
+	row := store.db.Raw(`SELECT creation_time, incoming_seqnum, outgoing_seqnum
+	  	FROM sessions
+		WHERE beginstring = ? AND session_qualifier = ?
+		AND sendercompid = ? AND sendersubid = ? AND senderlocid = ?
+		AND targetcompid = ? AND targetsubid = ? AND targetlocid = ?`,
 		s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Row()
 
 	err = row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum)
-
 	// session record found, load it
 	if err == nil {
 		store.cache.creationTime = creationTime
@@ -143,7 +140,7 @@ func (store *sqlStore) populateCache() (err error) {
 	}
 
 	// session record not found, create it
-	_, err = store.db.Exec(`INSERT INTO sessions (
+	return store.db.Exec(`INSERT INTO sessions (
 			creation_time, incoming_seqnum, outgoing_seqnum,
 			beginstring, session_qualifier,
 			sendercompid, sendersubid, senderlocid,
@@ -154,9 +151,7 @@ func (store *sqlStore) populateCache() (err error) {
 		store.cache.NextSenderMsgSeqNum(),
 		s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
-
-	return err
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Error
 }
 
 // NextSenderMsgSeqNum returns the next MsgSeqNum that will be sent
@@ -172,14 +167,13 @@ func (store *sqlStore) NextTargetMsgSeqNum() int {
 // SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent
 func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error {
 	s := store.sessionID
-	_, err := store.db.Exec(`UPDATE sessions SET outgoing_seqnum = ?
-		WHERE beginstring=? AND session_qualifier=?
-		AND sendercompid=? AND sendersubid=? AND senderlocid=?
-		AND targetcompid=? AND targetsubid=? AND targetlocid=?`,
+	if err := store.db.Exec(`UPDATE sessions SET outgoing_seqnum = ?
+		WHERE beginstring = ? AND session_qualifier = ?
+		AND sendercompid = ? AND sendersubid = ? AND senderlocid = ?
+		AND targetcompid = ? AND targetsubid = ? AND targetlocid = ?`,
 		next, s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
-	if err != nil {
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Error; err != nil {
 		return err
 	}
 	return store.cache.SetNextSenderMsgSeqNum(next)
@@ -188,14 +182,13 @@ func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error {
 // SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received
 func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error {
 	s := store.sessionID
-	_, err := store.db.Exec(`UPDATE sessions SET incoming_seqnum = ?
-		WHERE beginstring=? AND session_qualifier=?
-		AND sendercompid=? AND sendersubid=? AND senderlocid=?
-		AND targetcompid=? AND targetsubid=? AND targetlocid=?`,
+	if err := store.db.Exec(`UPDATE sessions SET incoming_seqnum = ?
+		WHERE beginstring = ? AND session_qualifier = ?
+		AND sendercompid = ? AND sendersubid = ? AND senderlocid = ?
+		AND targetcompid = ? AND targetsubid = ? AND targetlocid = ?`,
 		next, s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
-	if err != nil {
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Error; err != nil {
 		return err
 	}
 	return store.cache.SetNextTargetMsgSeqNum(next)
@@ -221,7 +214,7 @@ func (store *sqlStore) CreationTime() time.Time {
 func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error {
 	s := store.sessionID
 
-	_, err := store.db.Exec(`INSERT INTO messages (
+	return store.db.Exec(`INSERT INTO messages (
 			msgseqnum, message,
 			beginstring, session_qualifier,
 			sendercompid, sendersubid, senderlocid,
@@ -230,24 +223,22 @@ func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error {
 		seqNum, string(msg),
 		s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
-		s.TargetCompID, s.TargetSubID, s.TargetLocationID)
-
-	return err
+		s.TargetCompID, s.TargetSubID, s.TargetLocationID).Error
 }
 
 func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
 	s := store.sessionID
 	var msgs [][]byte
-	rows, err := store.db.Query(`SELECT message FROM messages
-		WHERE beginstring=? AND session_qualifier=?
-		AND sendercompid=? AND sendersubid=? AND senderlocid=?
-		AND targetcompid=? AND targetsubid=? AND targetlocid=?
-		AND msgseqnum>=? AND msgseqnum<=?
+	rows, err := store.db.Raw(`SELECT message FROM messages
+		WHERE beginstring= ? AND session_qualifier= ?
+		AND sendercompid= ? AND sendersubid= ? AND senderlocid= ?
+		AND targetcompid= ? AND targetsubid= ? AND targetlocid= ?
+		AND msgseqnum>= ? AND msgseqnum<= ?
 		ORDER BY msgseqnum`,
 		s.BeginString, s.Qualifier,
 		s.SenderCompID, s.SenderSubID, s.SenderLocationID,
 		s.TargetCompID, s.TargetSubID, s.TargetLocationID,
-		beginSeqNum, endSeqNum)
+		beginSeqNum, endSeqNum).Rows()
 	if err != nil {
 		return nil, err
 	}
diff --git a/sqlstore_test.go b/sqlstore_test.go
index b252b2c2..977ff5fe 100644
--- a/sqlstore_test.go
+++ b/sqlstore_test.go
@@ -1,7 +1,6 @@
 package quickfix
 
 import (
-	"database/sql"
 	"fmt"
 	"io/ioutil"
 	"os"
@@ -11,6 +10,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/jinzhu/gorm"
 	_ "github.com/mattn/go-sqlite3"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
@@ -30,14 +30,14 @@ func (suite *SQLStoreTestSuite) SetupTest() {
 	sqlDsn := path.Join(suite.sqlStoreRootPath, fmt.Sprintf("%d.db", time.Now().UnixNano()))
 
 	// create tables
-	db, err := sql.Open(sqlDriver, sqlDsn)
+	db, err := gorm.Open(sqlDriver, sqlDsn)
 	require.Nil(suite.T(), err)
 	ddlFnames, err := filepath.Glob(fmt.Sprintf("_sql/%s/*.sql", sqlDriver))
 	require.Nil(suite.T(), err)
 	for _, fname := range ddlFnames {
 		sqlBytes, err := ioutil.ReadFile(fname)
 		require.Nil(suite.T(), err)
-		_, err = db.Exec(string(sqlBytes))
+		err = db.Exec(string(sqlBytes)).Error
 		require.Nil(suite.T(), err)
 	}
 
diff --git a/store_test.go b/store_test.go
index a61ccbb0..e9d6350e 100644
--- a/store_test.go
+++ b/store_test.go
@@ -34,8 +34,8 @@ func (suite *MessageStoreTestSuite) TestMessageStore_SetNextMsgSeqNum_Refresh_In
 	t := suite.T()
 
 	// Given a MessageStore with the following sender and target seqnums
-	suite.msgStore.SetNextSenderMsgSeqNum(867)
-	suite.msgStore.SetNextTargetMsgSeqNum(5309)
+	require.Nil(t, suite.msgStore.SetNextSenderMsgSeqNum(867))
+	require.Nil(t, suite.msgStore.SetNextTargetMsgSeqNum(5309))
 
 	// When the store is refreshed from its backing store
 	suite.msgStore.Refresh()
@@ -75,7 +75,7 @@ func (suite *MessageStoreTestSuite) TestMessageStore_Reset() {
 	assert.Equal(t, 1, suite.msgStore.NextTargetMsgSeqNum())
 
 	// When the store is refreshed from its backing store
-	suite.msgStore.Refresh()
+	require.Nil(t, suite.msgStore.Refresh())
 
 	// Then the sender and target seqnums should still be
 	assert.Equal(t, 1, suite.msgStore.NextSenderMsgSeqNum())
-- 
GitLab