Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions libs/api/include/rtbot/Program.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ class Program {

// Message processing
ProgramMsgBatch receive(const Message<NumberData>& msg, const std::string& port_id = "i1") {
send_to_entry(msg, port_id);
send_to_entry(msg, port_id, false);
ProgramMsgBatch result = collect_outputs(false);
clear_all_outputs();
return result;
}

ProgramMsgBatch receive_debug(const Message<NumberData>& msg, const std::string& port_id = "i1") {
send_to_entry(msg, port_id);
send_to_entry(msg, port_id, true);
ProgramMsgBatch result = collect_outputs(true);
clear_all_outputs();
return result;
Expand Down Expand Up @@ -229,10 +229,10 @@ class Program {
throw runtime_error("Could not resolve operator ID: " + id);
}

void send_to_entry(const Message<NumberData>& msg, const std::string& port_id) {
void send_to_entry(const Message<NumberData>& msg, const std::string& port_id, bool debug=false) {
auto port_info = OperatorJson::parse_port_name(port_id);
operators_[entry_operator_id_]->receive_data(create_message<NumberData>(msg.time, msg.data), port_info.index);
operators_[entry_operator_id_]->execute();
operators_[entry_operator_id_]->execute(debug);
}

ProgramMsgBatch collect_outputs(bool debug_mode = false) {
Expand All @@ -255,7 +255,7 @@ class Program {
// In debug mode, collect all ports
if (debug_mode) {
for (size_t i = 0; i < op->num_output_ports(); i++) {
const auto& queue = op->get_output_queue(i);
const auto& queue = op->get_debug_output_queue(i);
if (!queue.empty()) {
PortMsgBatch port_msgs;
for (const auto& msg : queue) {
Expand Down
126 changes: 78 additions & 48 deletions libs/core/include/rtbot/Buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,38 @@ class Buffer : public Operator {
return std::sqrt(variance());
}

Bytes collect() override {
Bytes bytes = Operator::collect();
bool equals(const Buffer& other) const {

if (window_size_ != other.window_size_) return false;

if (buffer_.size() != other.buffer_.size()) return false;

auto it1 = buffer_.begin();
auto it2 = other.buffer_.begin();

for (; it1 != buffer_.end() && it2 != other.buffer_.end(); ++it1, ++it2) {
const auto& msg1 = *it1;
const auto& msg2 = *it2;

bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&window_size_),
reinterpret_cast<const uint8_t*>(&window_size_) + sizeof(window_size_));
if (msg1 && msg2) {
if (msg1->time != msg2->time) return false;
if (msg1->hash() != msg2->hash()) return false;
} else return false;
}

if constexpr (Features::TRACK_SUM) {
if (StateSerializer::hash_double(sum_) != StateSerializer::hash_double(other.sum_)) return false;
}

if constexpr (Features::TRACK_VARIANCE) {
if (StateSerializer::hash_double(M2_) != StateSerializer::hash_double(other.M2_)) return false;
}

return Operator::equals(other);
}

Bytes collect() override {
Bytes bytes = Operator::collect();

size_t buffer_size = buffer_.size();
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&buffer_size),
Expand All @@ -90,65 +117,68 @@ class Buffer : public Operator {
bytes.insert(bytes.end(), msg_bytes.begin(), msg_bytes.end());
}

if constexpr (Features::TRACK_SUM) {
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&sum_),
reinterpret_cast<const uint8_t*>(&sum_) + sizeof(sum_));
}

if constexpr (Features::TRACK_VARIANCE) {
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&M2_),
reinterpret_cast<const uint8_t*>(&M2_) + sizeof(M2_));
}

return bytes;
}

void restore(Bytes::const_iterator& it) override {
Operator::restore(it);

window_size_ = *reinterpret_cast<const size_t*>(&(*it));
it += sizeof(size_t);
// Call base restore first
Operator::restore(it);

size_t buffer_size = *reinterpret_cast<const size_t*>(&(*it));
it += sizeof(size_t);
// ---- Read buffer_size safely ----
size_t buffer_size;
std::memcpy(&buffer_size, &(*it), sizeof(buffer_size));
it += sizeof(buffer_size);

// ---- Deserialize buffer ----
buffer_.clear();
for (size_t i = 0; i < buffer_size; ++i) {
size_t msg_size = *reinterpret_cast<const size_t*>(&(*it));
it += sizeof(size_t);

Bytes msg_bytes(it, it + msg_size);
buffer_.push_back(
std::unique_ptr<Message<T>>(dynamic_cast<Message<T>*>(BaseMessage::deserialize(msg_bytes).release())));
it += msg_size;
// Read size of each message
size_t msg_size;
std::memcpy(&msg_size, &(*it), sizeof(msg_size));
it += sizeof(msg_size);

// Extract message bytes
Bytes msg_bytes(it, it + msg_size);

// Deserialize message and cast to derived type
buffer_.push_back(
std::unique_ptr<Message<T>>(
dynamic_cast<Message<T>*>(BaseMessage::deserialize(msg_bytes).release())
)
);

it += msg_size;
}

// ---- Optional statistics ----
if constexpr (Features::TRACK_SUM) {
sum_ = *reinterpret_cast<const double*>(&(*it));
it += sizeof(double);
sum_ = 0.0;
if (!buffer_.empty()) {
// First pass: compute sum
for (const auto& msg : buffer_) {
sum_ += msg->data.value;
}
}
}

if constexpr (Features::TRACK_VARIANCE) {
M2_ = *reinterpret_cast<const double*>(&(*it));
it += sizeof(double);

// Recompute statistics from buffer to ensure consistency
sum_ = 0.0;
M2_ = 0.0;

if (!buffer_.empty()) {
// First pass: compute mean
for (const auto& msg : buffer_) {
sum_ += msg->data.value;
// Recompute statistics from buffer to ensure consistency
sum_ = 0.0;
M2_ = 0.0;

if (!buffer_.empty()) {
// First pass: compute sum
for (const auto& msg : buffer_) {
sum_ += msg->data.value;
}

// Second pass: compute M2
double mean = sum_ / buffer_.size();
for (const auto& msg : buffer_) {
double delta = msg->data.value - mean;
M2_ += delta * delta;
}
}

// Second pass: compute M2
double mean = sum_ / buffer_.size();
for (const auto& msg : buffer_) {
double delta = msg->data.value - mean;
M2_ += delta * delta;
}
}
}
}

Expand Down
169 changes: 43 additions & 126 deletions libs/core/include/rtbot/Demultiplexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ class Demultiplexer : public Operator {
}

// Add single data input port with type T
add_data_port<T>();
data_time_tracker_ = std::set<timestamp_t>();
add_data_port<T>();

// Add corresponding control ports (always boolean)
for (size_t i = 0; i < num_ports; ++i) {
add_control_port<BooleanData>();
control_time_tracker_[i] = std::map<timestamp_t, bool>();
add_control_port<BooleanData>();
}

// Add output ports (same type as input)
Expand All @@ -37,143 +35,62 @@ class Demultiplexer : public Operator {

std::string type_name() const override { return "Demultiplexer"; }

size_t get_num_ports() const { return control_time_tracker_.size(); }
size_t get_num_ports() const { return num_control_ports(); }

Bytes collect() override {
Bytes bytes = Operator::collect(); // First collect base state

// Serialize data time tracker
StateSerializer::serialize_timestamp_set(bytes, data_time_tracker_);

// Serialize control time tracker
StateSerializer::serialize_port_control_map(bytes, control_time_tracker_);

return bytes;
bool equals(const Demultiplexer& other) const {
return Operator::equals(other);
}

void restore(Bytes::const_iterator& it) override {
// First restore base state
Operator::restore(it);

// Clear current state
data_time_tracker_.clear();
control_time_tracker_.clear();

// Restore data time tracker
StateSerializer::deserialize_timestamp_set(it, data_time_tracker_);

// Restore control time tracker
StateSerializer::deserialize_port_control_map(it, control_time_tracker_);

// Validate control port count
StateSerializer::validate_port_count(control_time_tracker_.size(), num_control_ports(), "Control");

bool operator==(const Demultiplexer& other) const {
return equals(other);
}

void reset() override {
Operator::reset();
data_time_tracker_.clear();
control_time_tracker_.clear();
bool operator!=(const Demultiplexer& other) const {
return !(*this == other);
}

void receive_data(std::unique_ptr<BaseMessage> msg, size_t port_index) override {
auto time = msg->time;
Operator::receive_data(std::move(msg), port_index);

data_time_tracker_.insert(time);
}
protected:

void receive_control(std::unique_ptr<BaseMessage> msg, size_t port_index) override {
if (port_index >= num_control_ports()) {
throw std::runtime_error("Invalid control port index");
}

auto* ctrl_msg = dynamic_cast<const Message<BooleanData>*>(msg.get());
if (!ctrl_msg) {
throw std::runtime_error("Invalid control message type");
}

// Update control tracker
control_time_tracker_[port_index][ctrl_msg->time] = ctrl_msg->data.value;

// Add message to queue
get_control_queue(port_index).push_back(std::move(msg));
control_ports_with_new_data_.insert(port_index);
}

protected:
void process_data() override {
while (true) {
// Find oldest common control timestamp
auto common_control_time = TimestampTracker::find_oldest_common_time(control_time_tracker_);
if (!common_control_time) {
break;
}

// Clean up any old input data messages
auto& data_queue = get_data_queue(0);
while (!data_queue.empty()) {
auto* msg = dynamic_cast<const Message<T>*>(data_queue.front().get());
if (msg && msg->time < *common_control_time) {
data_time_tracker_.erase(msg->time);
data_queue.pop_front();
} else {
break;
}
}

// Look for matching data message
bool message_found = false;
if (!data_queue.empty()) {
auto* msg = dynamic_cast<const Message<T>*>(data_queue.front().get());
if (msg && msg->time == *common_control_time) {
// Get active control ports
std::vector<size_t> active_ports;
for (size_t i = 0; i < num_control_ports(); ++i) {
if (control_time_tracker_[i].at(*common_control_time)) {
active_ports.push_back(i);
}
while(true) {

bool is_any_control_empty;
bool are_controls_sync;
do {
is_any_control_empty = false;
are_controls_sync = sync_control_inputs();
for (int i=0; i < num_control_ports(); i++) {
if (get_control_queue(i).empty()) {
is_any_control_empty = true;
break;
}
}
} while (!are_controls_sync && !is_any_control_empty );

// Route message to all active ports
for (size_t port : active_ports) {
get_output_queue(port).push_back(data_queue.front()->clone());
}

data_time_tracker_.erase(msg->time);
data_queue.pop_front();
message_found = true;
}
}

clean_up_control_messages(*common_control_time);
if (!are_controls_sync) return;

if (!message_found) {
break;
}
}
}

private:
void clean_up_control_messages(timestamp_t time) {
for (auto& [port, tracker] : control_time_tracker_) {
tracker.erase(time);
}

for (size_t port = 0; port < num_control_ports(); ++port) {
auto& queue = get_control_queue(port);
while (!queue.empty()) {
auto* msg = dynamic_cast<const Message<BooleanData>*>(queue.front().get());
if (msg && msg->time <= time) {
queue.pop_front();
} else {
break;
auto& data_queue = get_data_queue(0);
if (data_queue.empty()) return;
auto* msg = dynamic_cast<const Message<T>*>(data_queue.front().get());
auto* ctrl_msg = dynamic_cast<const Message<BooleanData>*>(get_control_queue(0).front().get());
if (msg && ctrl_msg && msg->time == ctrl_msg->time) {
for (int i = 0; i < num_control_ports(); i++) {
ctrl_msg = dynamic_cast<const Message<BooleanData>*>(get_control_queue(i).front().get());
if (ctrl_msg->data.value) {
get_output_queue(i).push_back(data_queue.front()->clone());
}
get_control_queue(i).pop_front();
}
data_queue.pop_front();
} else if (msg && ctrl_msg && msg->time < ctrl_msg->time) {
data_queue.pop_front();
} else if (msg && ctrl_msg && ctrl_msg->time < msg->time) {
for (int i = 0; i < num_control_ports(); i++)
get_control_queue(i).pop_front();

}
}
}

std::set<timestamp_t> data_time_tracker_;
std::map<size_t, std::map<timestamp_t, bool>> control_time_tracker_;
};

// Factory functions for common configurations using PortType
Expand Down
Loading
Loading