diff --git a/field_map.go b/field_map.go index 82d8454ab080d255d1c17d92d7bcc2b01a6acaa4..7410242acaa5a22617fd9483b5174534148a9011 100644 --- a/field_map.go +++ b/field_map.go @@ -70,6 +70,16 @@ func (m FieldMap) GetField(tag Tag, parser FieldValueReader) MessageRejectError return nil } +//GetBytes is a zero-copy GetField wrapper for []bytes fields +func (m FieldMap) GetBytes(tag Tag) ([]byte, MessageRejectError) { + tagValues, ok := m.tagLookup[tag] + if !ok { + return nil, ConditionallyRequiredFieldMissing(tag) + } + + return tagValues[0].value, nil +} + //GetBool is a GetField wrapper for bool fields func (m FieldMap) GetBool(tag Tag) (bool, MessageRejectError) { var val FIXBoolean diff --git a/field_map_test.go b/field_map_test.go index febb07818a8faa92928e21e231ad1b6e432c3be9..ce6a791198fc10aaa8397196acb5205144e91454 100644 --- a/field_map_test.go +++ b/field_map_test.go @@ -1,7 +1,10 @@ package quickfix import ( + "bytes" "testing" + + "github.com/stretchr/testify/assert" ) func TestFieldMap_Clear(t *testing.T) { @@ -40,19 +43,12 @@ func TestFieldMap_SetAndGet(t *testing.T) { err := fMap.GetField(tc.tag, &testField) if tc.expectErr { - if err == nil { - t.Error("Expected Error") - } + assert.NotNil(t, err, "Expected Error") continue } - if err != nil { - t.Error("Unexpected Error", err) - } - - if string(testField) != tc.expectValue { - t.Errorf("Expected %v got %v", tc.expectValue, testField) - } + assert.Nil(t, err, "Unexpected error") + assert.Equal(t, tc.expectValue, string(testField)) } } @@ -64,10 +60,7 @@ func TestFieldMap_Length(t *testing.T) { fMap.SetField(8, FIXString("FIX.4.4")) fMap.SetField(9, FIXInt(100)) fMap.SetField(10, FIXString("100")) - - if fMap.length() != 16 { - t.Error("Length should include all fields but beginString, bodyLength, and checkSum- got ", fMap.length()) - } + assert.Equal(t, 16, fMap.length(), "Length should include all fields but beginString, bodyLength, and checkSum") } func TestFieldMap_Total(t *testing.T) { @@ -80,9 +73,7 @@ func TestFieldMap_Total(t *testing.T) { fMap.SetField(Tag(9), FIXInt(100)) fMap.SetField(10, FIXString("100")) - if fMap.total() != 2116 { - t.Error("Total should includes all fields but checkSum- got ", fMap.total()) - } + assert.Equal(t, 2116, fMap.total(), "Total should includes all fields but checkSum") } func TestFieldMap_TypedSetAndGet(t *testing.T) { @@ -93,30 +84,23 @@ func TestFieldMap_TypedSetAndGet(t *testing.T) { fMap.SetInt(2, 256) s, err := fMap.GetString(1) - if err != nil { - t.Error("Unexpected Error", err) - } else if s != "hello" { - t.Errorf("Expected %v got %v", "hello", s) - } + assert.Nil(t, err) + assert.Equal(t, "hello", s) i, err := fMap.GetInt(2) - if err != nil { - t.Error("Unexpected Error", err) - } else if i != 256 { - t.Errorf("Expected %v got %v", 256, i) - } + assert.Nil(t, err) + assert.Equal(t, 256, i) _, err = fMap.GetInt(1) - if err == nil { - t.Error("Type mismatch should occur error but nil") - } + assert.NotNil(t, err, "Type mismatch should occur error") s, err = fMap.GetString(2) - if err != nil { - t.Error("Type mismatch should occur error but nil") - } else if s != "256" { - t.Errorf("Expected %v got %v", "256", s) - } + assert.Nil(t, err) + assert.Equal(t, "256", s) + + b, err := fMap.GetBytes(1) + assert.Nil(t, err) + assert.True(t, bytes.Equal([]byte("hello"), b)) } func TestFieldMap_BoolTypedSetAndGet(t *testing.T) { @@ -125,25 +109,19 @@ func TestFieldMap_BoolTypedSetAndGet(t *testing.T) { fMap.SetBool(1, true) v, err := fMap.GetBool(1) - if err != nil { - t.Error("Unexpected Error", err) - } else if !v { - t.Errorf("Expected %v got %v", true, v) - } - s, _ := fMap.GetString(1) - if s != "Y" { - t.Errorf("Expected %v got %v", "Y", s) - } + assert.Nil(t, err) + assert.True(t, v) + + s, err := fMap.GetString(1) + assert.Nil(t, err) + assert.Equal(t, "Y", s) fMap.SetBool(2, false) v, err = fMap.GetBool(2) - if err != nil { - t.Error("Unexpected Error", err) - } else if v { - t.Errorf("Expected %v got %v", false, v) - } - s, _ = fMap.GetString(2) - if s != "N" { - t.Errorf("Expected %v got %v", "N", s) - } + assert.Nil(t, err) + assert.False(t, v) + + s, err = fMap.GetString(2) + assert.Nil(t, err) + assert.Equal(t, "N", s) } diff --git a/fix_bytes.go b/fix_bytes.go new file mode 100644 index 0000000000000000000000000000000000000000..5e3bb13003acac90ae41205338c4ab450ba5ef71 --- /dev/null +++ b/fix_bytes.go @@ -0,0 +1,13 @@ +package quickfix + +//FIXBytes is a generic FIX field value, implements FieldValue. Enables zero copy read from a FieldMap +type FIXBytes []byte + +func (f *FIXBytes) Read(bytes []byte) (err error) { + *f = FIXBytes(bytes) + return +} + +func (f FIXBytes) Write() []byte { + return []byte(f) +} diff --git a/fix_bytes_test.go b/fix_bytes_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4dc94b56b855ab513e82a15298bd7e329022194a --- /dev/null +++ b/fix_bytes_test.go @@ -0,0 +1,22 @@ +package quickfix + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFIXBytesWrite(t *testing.T) { + val := []byte("hello") + field := FIXBytes(val) + b := field.Write() + + assert.Equal(t, val, b) +} + +func TestFIXBytesRead(t *testing.T) { + val := []byte("world") + var field FIXBytes + assert.Nil(t, field.Read(val)) + assert.Equal(t, val, []byte(field)) +} diff --git a/in_session.go b/in_session.go index 87cbd977ff35ca145bc9d0f3d005134a6b8704fa..527de63ada7889f0fe4f75d7aedeb5aec3c0b064 100644 --- a/in_session.go +++ b/in_session.go @@ -212,7 +212,7 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int nextSeqNum := seqNum for _, msgBytes := range msgs { msg, _ := ParseMessage(msgBytes) - msgType, _ := msg.Header.GetString(tagMsgType) + msgType, _ := msg.Header.GetBytes(tagMsgType) sentMessageSeqNum, _ := msg.Header.GetInt(tagMsgSeqNum) if isAdminMessageType(msgType) { diff --git a/message_router.go b/message_router.go index 586b40573840042d11c0f515b412250710aa9f06..0e2343ac540b9ba83b2c1a9e5a70b2a378e9d312 100644 --- a/message_router.go +++ b/message_router.go @@ -44,7 +44,9 @@ func (c MessageRouter) Route(msg Message, sessionID SessionID) MessageRejectErro func (c MessageRouter) tryRoute(beginString string, msgType string, msg Message, sessionID SessionID) MessageRejectError { fixVersion := beginString - if beginString == enum.BeginStringFIXT11 && !isAdminMessageType(msgType) { + isAdminMsg := isAdminMessageType([]byte(msgType)) + + if beginString == enum.BeginStringFIXT11 && !isAdminMsg { var applVerID FIXString if err := msg.Header.GetField(tagApplVerID, &applVerID); err != nil { session, _ := lookupSession(sessionID) @@ -72,7 +74,7 @@ func (c MessageRouter) tryRoute(beginString string, msgType string, msg Message, } switch { - case isAdminMessageType(msgType) || msgType == "j": + case isAdminMsg || msgType == "j": return nil } diff --git a/msg_type.go b/msg_type.go index e1d8aadbe485894569656834c396cbaa81ce1f94..03faed69814599a5ed79543321db01d07bdec05d 100644 --- a/msg_type.go +++ b/msg_type.go @@ -1,9 +1,25 @@ package quickfix -//IsAdminMessageType returns true if the message type is a sesion level message. -func isAdminMessageType(m string) bool { - switch m { - case "0", "A", "1", "2", "3", "4", "5": +import "bytes" + +var msgTypeHeartbeat = []byte("0") +var msgTypeLogon = []byte("A") +var msgTypeTestRequest = []byte("1") +var msgTypeResendRequest = []byte("2") +var msgTypeReject = []byte("3") +var msgTypeSequenceReset = []byte("4") +var msgTypeLogout = []byte("5") + +//isAdminMessageType returns true if the message type is a session level message. +func isAdminMessageType(m []byte) bool { + switch { + case bytes.Equal(msgTypeHeartbeat, m), + bytes.Equal(msgTypeLogon, m), + bytes.Equal(msgTypeTestRequest, m), + bytes.Equal(msgTypeResendRequest, m), + bytes.Equal(msgTypeReject, m), + bytes.Equal(msgTypeSequenceReset, m), + bytes.Equal(msgTypeLogout, m): return true } diff --git a/session.go b/session.go index 23e694808fe0c7b40be2cd9f16dd2f16d13500d6..fc68adc9be2b0b439cd61a265f48b2fa12ff8fd3 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package quickfix import ( + "bytes" "errors" "fmt" "sync" @@ -274,15 +275,15 @@ func (s *session) prepMessageForSend(msg Message, inReplyTo *Message) (msgBytes seqNum := s.store.NextSenderMsgSeqNum() msg.Header.SetField(tagMsgSeqNum, FIXInt(seqNum)) - msgType, err := msg.MsgType() + msgType, err := msg.Header.GetBytes(tagMsgType) if err != nil { return } - if isAdminMessageType(string(msgType)) { + if isAdminMessageType(msgType) { s.application.ToAdmin(msg, s.sessionID) - if msgType == enum.MsgType_LOGON { + if bytes.Equal(msgType, msgTypeLogon) { var resetSeqNumFlag FIXBoolean if msg.Body.Has(tagResetSeqNumFlag) { if err = msg.Body.GetField(tagResetSeqNumFlag, &resetSeqNumFlag); err != nil { @@ -508,12 +509,12 @@ func (s *session) verifySelect(msg Message, checkTooHigh bool, checkTooLow bool) } func (s *session) fromCallback(msg Message) MessageRejectError { - msgType, err := msg.MsgType() + msgType, err := msg.Header.GetBytes(tagMsgType) if err != nil { return err } - if isAdminMessageType(string(msgType)) { + if isAdminMessageType(msgType) { return s.application.FromAdmin(msg, s.sessionID) } diff --git a/validation.go b/validation.go index 5c4013bdbd2dc63368fbfe866c39bf529191b34e..f408a5aabf48f436ee4d094b599f10a6591f8e0d 100644 --- a/validation.go +++ b/validation.go @@ -53,7 +53,7 @@ func (v *fixtValidator) Validate(msg Message) MessageRejectError { return err } - if isAdminMessageType(msgType) { + if isAdminMessageType([]byte(msgType)) { return validateFIX(v.transportDataDictionary, v.settings, msgType, msg) } return validateFIXT(v.transportDataDictionary, v.appDataDictionary, v.settings, msgType, msg)