Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/hotfix-message/src/encoding/field_types/date.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::encoding::{Buffer, FieldType};
#[cfg(feature = "utils-chrono")]
use chrono::Datelike;
use std::convert::{TryFrom, TryInto};

const LEN_IN_BYTES: usize = 8;
Expand Down Expand Up @@ -112,6 +114,17 @@ impl Date {
}
}

#[cfg(feature = "utils-chrono")]
impl From<chrono::NaiveDate> for Date {
fn from(chrono_date: chrono::NaiveDate) -> Self {
Self {
year: chrono_date.year() as u32,
month: chrono_date.month(),
day: chrono_date.day(),
}
}
}

impl<'a> FieldType<'a> for Date {
type Error = &'static str;
type SerializeSettings = ();
Expand Down
14 changes: 14 additions & 0 deletions crates/hotfix-message/src/encoding/field_types/time.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::encoding::{Buffer, FieldType};
#[cfg(feature = "utils-chrono")]
use chrono::Timelike;

const ERR_INVALID: &str = "Invalid time.";

Expand Down Expand Up @@ -142,6 +144,18 @@ impl Time {
}
}

#[cfg(feature = "utils-chrono")]
impl From<chrono::NaiveTime> for Time {
fn from(chrono_time: chrono::NaiveTime) -> Self {
Self {
hour: chrono_time.hour(),
minute: chrono_time.minute(),
second: chrono_time.second(),
milli: chrono_time.nanosecond() / 1_000_000,
}
}
}

impl<'a> FieldType<'a> for Time {
type Error = &'static str;
type SerializeSettings = ();
Expand Down
10 changes: 10 additions & 0 deletions crates/hotfix-message/src/encoding/field_types/timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ impl Timestamp {
}
}

#[cfg(feature = "utils-chrono")]
impl From<chrono::NaiveDateTime> for Timestamp {
fn from(chrono_datetime: chrono::NaiveDateTime) -> Self {
Self {
date: chrono_datetime.date().into(),
time: chrono_datetime.time().into(),
}
}
}

impl<'a> FieldType<'a> for Timestamp {
type Error = &'static str;
type SerializeSettings = ();
Expand Down
2 changes: 2 additions & 0 deletions crates/hotfix/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,7 @@ tracing = { workspace = true }
uuid = { workspace = true, features = ["v4"] }

[dev-dependencies]
hotfix-message = { version = "0.2.1", path = "../hotfix-message", features = ["utils-chrono"] }

testcontainers = { workspace = true }
tokio = { workspace = true, features = ["test-util"] }
14 changes: 13 additions & 1 deletion crates/hotfix/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use hotfix_message::field_types::Timestamp;
use thiserror::Error;

#[derive(Debug, Error)]
Expand All @@ -21,13 +22,24 @@ pub enum MessageVerificationError {
IncorrectBeginString(String),

/// The comp ID is different from our expectations.
#[allow(dead_code)]
#[error("incorrect comp id {comp_id} ({comp_id_type:?})")]
IncorrectCompId {
comp_id: String,
comp_id_type: CompIdType,
msg_seq_num: u64,
},
/// Original sending time is not provided despite PossDupFlag being set.
#[error("original sending time missing")]
OriginalSendingTimeMissing { msg_seq_num: u64 },
/// The original sending time is after the sending time of the message.
#[error(
"original sending time {original_sending_time:?} is after sending time {sending_time:?}"
)]
OriginalSendingTimeAfterSendingTime {
msg_seq_num: u64,
original_sending_time: Timestamp,
sending_time: Timestamp,
},
}

#[derive(Debug)]
Expand Down
94 changes: 74 additions & 20 deletions crates/hotfix/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ use crate::message_utils::{is_admin, prepare_message_for_resend};
use crate::session::state::{AwaitingResendTransitionOutcome, TestRequestId};
use crate::session_schedule::SessionSchedule;
use event::SessionEvent;
use hotfix_message::fix44::{POSS_DUP_FLAG, SessionRejectReason};
use hotfix_message::field_types::Timestamp;
use hotfix_message::fix44::{ORIG_SENDING_TIME, POSS_DUP_FLAG, SENDING_TIME, SessionRejectReason};
use hotfix_message::parsed_message::{InvalidReason, ParsedMessage};
use state::SessionState;

Expand Down Expand Up @@ -255,25 +256,6 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
let expected_seq_number = self.store.next_target_seq_number();
let actual_seq_number: u64 = message.header().get(fix44::MSG_SEQ_NUM).unwrap_or_default();

match actual_seq_number.cmp(&expected_seq_number) {
Ordering::Greater => {
return Err(MessageVerificationError::SeqNumberTooHigh {
expected: expected_seq_number,
actual: actual_seq_number,
});
}
Ordering::Less => {
let possible_duplicate =
message.header().get::<bool>(POSS_DUP_FLAG).unwrap_or(false);
return Err(MessageVerificationError::SeqNumberTooLow {
expected: expected_seq_number,
actual: actual_seq_number,
possible_duplicate,
});
}
_ => {}
}

// our TargetCompId is always the same as the expected SenderCompId for them
let expected_sender_comp_id: &str = self.config.target_comp_id.as_str();
let actual_sender_comp_id: &str = message.header().get(fix44::SENDER_COMP_ID).unwrap_or("");
Expand All @@ -296,6 +278,49 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
});
}

let possible_duplicate = message.header().get::<bool>(POSS_DUP_FLAG).unwrap_or(false);
if possible_duplicate {
match message.header().get::<Timestamp>(ORIG_SENDING_TIME) {
Ok(original_sending_time) => {
if let Ok(sending_time) = message.header().get::<Timestamp>(SENDING_TIME) {
// TODO: check presence of sending time (see related test cases https://www.fixtrading.org/standards/fix-session-testcases-online/#scenario-2-receive-message-standard-header)
if original_sending_time > sending_time {
return Err(
MessageVerificationError::OriginalSendingTimeAfterSendingTime {
msg_seq_num: actual_seq_number,
original_sending_time,
sending_time,
},
);
}
}
}
Err(err) => {
error!(error = debug(err), "original sending time is missing");
return Err(MessageVerificationError::OriginalSendingTimeMissing {
msg_seq_num: actual_seq_number,
});
}
}
}

match actual_seq_number.cmp(&expected_seq_number) {
Ordering::Greater => {
return Err(MessageVerificationError::SeqNumberTooHigh {
expected: expected_seq_number,
actual: actual_seq_number,
});
}
Ordering::Less => {
return Err(MessageVerificationError::SeqNumberTooLow {
expected: expected_seq_number,
actual: actual_seq_number,
possible_duplicate,
});
}
_ => {}
}

Ok(())
}

Expand Down Expand Up @@ -473,6 +498,15 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
self.handle_incorrect_comp_id(comp_id, comp_id_type, msg_seq_num)
.await;
}
MessageVerificationError::OriginalSendingTimeMissing { msg_seq_num } => {
self.handle_original_sending_time_missing(msg_seq_num).await;
}
MessageVerificationError::OriginalSendingTimeAfterSendingTime {
msg_seq_num, ..
} => {
self.handle_original_sending_time_after_sending_time(msg_seq_num)
.await
}
}
}

Expand Down Expand Up @@ -575,6 +609,26 @@ impl<M: FixMessage, S: MessageStore> Session<M, S> {
}
}

async fn handle_original_sending_time_after_sending_time(&mut self, msg_seq_num: u64) {
let reject = Reject::new(msg_seq_num)
.session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem)
.text("original sending time is after sending time");
self.send_message(reject).await;
if let Err(err) = self.store.increment_target_seq_number().await {
error!("failed to increment target seq number: {:?}", err);
};
}

async fn handle_original_sending_time_missing(&mut self, msg_seq_num: u64) {
let reject = Reject::new(msg_seq_num)
.session_reject_reason(SessionRejectReason::RequiredTagMissing)
.text("original sending time is required");
self.send_message(reject).await;
if let Err(err) = self.store.increment_target_seq_number().await {
error!("failed to increment target seq number: {:?}", err);
};
}

async fn resend_messages(&mut self, begin: usize, end: usize, _message: &Message) {
debug!(begin, end, "resending messages as requested");
let messages = self.store.get_slice(begin, end).await.unwrap();
Expand Down
40 changes: 40 additions & 0 deletions crates/hotfix/tests/common/test_messages.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::common::setup::{COUNTERPARTY_COMP_ID, OUR_COMP_ID};
use chrono::TimeDelta;
use hotfix::Message as HotfixMessage;
use hotfix::message::{FixMessage, generate_message};
use hotfix_message::dict::{FieldLocation, FixDatatype};
use hotfix_message::field_types::Timestamp;
use hotfix_message::message::{Config, Message};
use hotfix_message::{HardCodedFixFieldDefinition, Part, fix44};
use std::ops::Add;

/// Business messages used for testing.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -277,6 +279,44 @@ pub fn build_execution_report_with_custom_msg_type(msg_seq_num: u64, msg_type: &
msg.encode(&Config::default()).unwrap()
}

pub fn build_execution_report_with_incorrect_orig_sending_time(msg_seq_num: u64) -> Vec<u8> {
let report = TestMessage::dummy_execution_report();

let mut msg = Message::new("FIX.4.4", "8");
msg.set(fix44::SENDER_COMP_ID, COUNTERPARTY_COMP_ID);
msg.set(fix44::TARGET_COMP_ID, OUR_COMP_ID);
msg.set(fix44::MSG_SEQ_NUM, msg_seq_num);

let sending_time = Timestamp::utc_now();
let original_sending_time: Timestamp = sending_time
.to_chrono_naive()
.unwrap()
.add(TimeDelta::seconds(1))
.into();
msg.set(fix44::SENDING_TIME, sending_time);
msg.set(fix44::ORIG_SENDING_TIME, original_sending_time);
msg.set(fix44::POSS_DUP_FLAG, "Y");

report.write(&mut msg);

msg.encode(&Config::default()).unwrap()
}

pub fn build_execution_report_with_missing_orig_sending_time(msg_seq_num: u64) -> Vec<u8> {
let report = TestMessage::dummy_execution_report();

let mut msg = Message::new("FIX.4.4", "8");
msg.set(fix44::SENDER_COMP_ID, COUNTERPARTY_COMP_ID);
msg.set(fix44::TARGET_COMP_ID, OUR_COMP_ID);
msg.set(fix44::MSG_SEQ_NUM, msg_seq_num);
msg.set(fix44::SENDING_TIME, Timestamp::utc_now());
msg.set(fix44::POSS_DUP_FLAG, "Y");

report.write(&mut msg);

msg.encode(&Config::default()).unwrap()
}

/// Replaces the value of a field in a raw FIX message.
pub fn replace_field_value(raw_message: &mut Vec<u8>, tag: u32, new_value: &[u8]) {
let tag_bytes = format!("{}=", tag).into_bytes();
Expand Down
77 changes: 76 additions & 1 deletion crates/hotfix/tests/session_test_cases/invalid_message_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use crate::common::test_messages::{
build_execution_report_with_custom_msg_type,
build_execution_report_with_incorrect_begin_string,
build_execution_report_with_incorrect_body_length,
build_execution_report_with_incorrect_orig_sending_time,
build_execution_report_with_missing_orig_sending_time,
};
use hotfix::session::Status;
use hotfix_message::Part;
use hotfix_message::fix44::{MSG_TYPE, SESSION_REJECT_REASON};
use hotfix_message::fix44::{MSG_TYPE, SESSION_REJECT_REASON, SessionRejectReason};

/// Tests that when a counterparty sends a message containing an invalid/unrecognised field,
/// the session rejects the message by sending a Reject (MsgType=3) message back.
Expand Down Expand Up @@ -227,3 +229,76 @@ async fn test_message_with_sequence_number_too_low_possdup_ignored() {
when(&session).requests_disconnect().await;
then(&mut mock_counterparty).gets_disconnected().await;
}

/// Tests that a message with `OrigSendingTime` after `SendingTime` is rejected
/// with an appropriate rejection reason.
#[tokio::test]
async fn test_message_with_incorrect_orig_sending_time_is_rejected() {
let (session, mut mock_counterparty) = given_an_active_session().await;

// A valid execution report is sent and processed normally
let seq_number = mock_counterparty.next_target_sequence_number();
when(&mut mock_counterparty)
.sends_message(TestMessage::dummy_execution_report())
.await;
then(&session)
.target_sequence_number_reaches(seq_number)
.await;

// the same is resent with PossDupFlag=Y, but with OriginalSendingTime after SendingTime
when(&mut mock_counterparty)
.sends_raw_message(build_execution_report_with_incorrect_orig_sending_time(
seq_number,
))
.await;
then(&mut mock_counterparty)
.receives(|msg| {
assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "3");
assert_eq!(
msg.get::<SessionRejectReason>(SESSION_REJECT_REASON)
.unwrap(),
SessionRejectReason::SendingtimeAccuracyProblem
);
})
.await;

when(&session).requests_disconnect().await;
then(&mut mock_counterparty).gets_disconnected().await;
}

/// Tests that a message with missing `OrigSendingTime` is rejected.
///
/// `OrigSendingTime` is required when `PossDupFlag` is set to `Y`.
#[tokio::test]
async fn test_message_with_missing_orig_sending_time_is_rejected() {
let (session, mut mock_counterparty) = given_an_active_session().await;

// A valid execution report is sent and processed normally
let seq_number = mock_counterparty.next_target_sequence_number();
when(&mut mock_counterparty)
.sends_message(TestMessage::dummy_execution_report())
.await;
then(&session)
.target_sequence_number_reaches(seq_number)
.await;

// the same is resent with PossDupFlag=Y, but with OriginalSendingTime after SendingTime
when(&mut mock_counterparty)
.sends_raw_message(build_execution_report_with_missing_orig_sending_time(
seq_number,
))
.await;
then(&mut mock_counterparty)
.receives(|msg| {
assert_eq!(msg.header().get::<&str>(MSG_TYPE).unwrap(), "3");
assert_eq!(
msg.get::<SessionRejectReason>(SESSION_REJECT_REASON)
.unwrap(),
SessionRejectReason::RequiredTagMissing
);
})
.await;

when(&session).requests_disconnect().await;
then(&mut mock_counterparty).gets_disconnected().await;
}