diff --git a/examples/mutable_chaining.rs b/examples/mutable_chaining.rs new file mode 100644 index 0000000..db53b4b --- /dev/null +++ b/examples/mutable_chaining.rs @@ -0,0 +1,93 @@ +//! Demonstrates chaining mutable packet views across Ethernet/IPv4/UDP layers. + +use nex::net::mac::MacAddr; +use nex::packet::ethernet::{ + EtherType, EthernetPacket, MutableEthernetPacket, ETHERNET_HEADER_LEN, +}; +use nex::packet::ip::IpNextProtocol; +use nex::packet::ipv4::{self, Ipv4Packet, MutableIpv4Packet, IPV4_HEADER_LEN}; +use nex::packet::packet::{MutablePacket, Packet}; +use nex::packet::udp::{self, MutableUdpPacket, UdpPacket, UDP_HEADER_LEN}; +use std::net::Ipv4Addr; + +fn main() { + // Build a simple Ethernet/IPv4/UDP frame in-place. + let payload = b"hello mutable packets"; + let frame_len = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN + payload.len(); + let mut frame = vec![0u8; frame_len]; + + { + let mut ethernet = MutableEthernetPacket::new(&mut frame).expect("ethernet"); + ethernet.set_source(MacAddr::new(0x00, 0x11, 0x22, 0x33, 0x44, 0x55)); + ethernet.set_destination(MacAddr::new(0x08, 0x00, 0x27, 0xaa, 0xbb, 0xcc)); + ethernet.set_ethertype(EtherType::Ipv4); + + let ipv4_len = (IPV4_HEADER_LEN + UDP_HEADER_LEN + payload.len()) as u16; + { + // Use `new_unchecked` because the buffer starts zeroed; we will + // populate all required header fields before freezing it back into + // an immutable packet for validation. + let mut ipv4 = MutableIpv4Packet::new_unchecked(ethernet.payload_mut()); + ipv4.set_version(4); + ipv4.set_header_length(5); + ipv4.set_total_length(ipv4_len); + ipv4.set_ttl(64); + ipv4.set_next_level_protocol(IpNextProtocol::Udp); + ipv4.set_source(Ipv4Addr::new(192, 0, 2, 1)); + ipv4.set_destination(Ipv4Addr::new(198, 51, 100, 1)); + ipv4.set_identification(0x1337); + ipv4.set_checksum(0); + + { + let mut udp = MutableUdpPacket::new(ipv4.payload_mut()).expect("udp"); + udp.set_source(5353); + udp.set_destination(8080); + udp.set_length((UDP_HEADER_LEN + payload.len()) as u16); + udp.set_checksum(0); + + let udp_payload = udp.payload_mut(); + udp_payload[..payload.len()].copy_from_slice(payload); + } + + let snapshot = ipv4.freeze().expect("snapshot ipv4"); + let udp_snapshot = UdpPacket::from_buf(&snapshot.payload).expect("snapshot udp"); + let udp_checksum = udp::ipv4_checksum( + &udp_snapshot, + &snapshot.header.source, + &snapshot.header.destination, + ); + MutableUdpPacket::new(ipv4.payload_mut()) + .expect("udp checksum") + .set_checksum(udp_checksum); + let ipv4_checksum = ipv4::checksum(&snapshot); + ipv4.set_checksum(ipv4_checksum); + } + } + + // Inspect immutable packet views to confirm changes persisted across layers. + let ethernet_packet = EthernetPacket::from_buf(&frame).expect("immutable ethernet"); + let ipv4_packet = Ipv4Packet::from_buf(ðernet_packet.payload).expect("immutable ipv4"); + let udp_packet = UdpPacket::from_buf(&ipv4_packet.payload).expect("immutable udp"); + + println!( + "Ethernet: {} -> {} ({:?})", + ethernet_packet.header.source, + ethernet_packet.header.destination, + ethernet_packet.header.ethertype + ); + println!( + "IPv4: {} -> {} ttl={} checksum=0x{:04x}", + ipv4_packet.header.source, + ipv4_packet.header.destination, + ipv4_packet.header.ttl, + ipv4_packet.header.checksum + ); + println!( + "UDP: {} -> {} len={} checksum=0x{:04x}", + udp_packet.header.source, + udp_packet.header.destination, + udp_packet.header.length, + udp_packet.header.checksum + ); + println!("Payload: {}", String::from_utf8_lossy(&udp_packet.payload)); +} diff --git a/nex-packet/src/arp.rs b/nex-packet/src/arp.rs index 67f2bff..e3bf32a 100644 --- a/nex-packet/src/arp.rs +++ b/nex-packet/src/arp.rs @@ -2,7 +2,7 @@ use crate::{ ethernet::{EtherType, ETHERNET_HEADER_LEN}, - packet::Packet, + packet::{MutablePacket, Packet}, }; use bytes::{Bytes, BytesMut}; @@ -448,6 +448,146 @@ impl fmt::Display for ArpPacket { } } +/// Represents a mutable ARP Packet. +pub struct MutableArpPacket<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableArpPacket<'a> { + type Packet = ArpPacket; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < ARP_HEADER_LEN { + None + } else { + Some(Self { buffer }) + } + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..ARP_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(ARP_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + &self.packet()[ARP_HEADER_LEN..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (_, payload) = (&mut *self.buffer).split_at_mut(ARP_HEADER_LEN); + payload + } +} + +impl<'a> MutableArpPacket<'a> { + /// Create a packet without performing length checks. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + pub fn get_hardware_type(&self) -> ArpHardwareType { + ArpHardwareType::new(u16::from_be_bytes([self.raw()[0], self.raw()[1]])) + } + + pub fn set_hardware_type(&mut self, ty: ArpHardwareType) { + self.raw_mut()[0..2].copy_from_slice(&ty.value().to_be_bytes()); + } + + pub fn get_protocol_type(&self) -> EtherType { + EtherType::new(u16::from_be_bytes([self.raw()[2], self.raw()[3]])) + } + + pub fn set_protocol_type(&mut self, ty: EtherType) { + self.raw_mut()[2..4].copy_from_slice(&ty.value().to_be_bytes()); + } + + pub fn get_hw_addr_len(&self) -> u8 { + self.raw()[4] + } + + pub fn set_hw_addr_len(&mut self, len: u8) { + self.raw_mut()[4] = len; + } + + pub fn get_proto_addr_len(&self) -> u8 { + self.raw()[5] + } + + pub fn set_proto_addr_len(&mut self, len: u8) { + self.raw_mut()[5] = len; + } + + pub fn get_operation(&self) -> ArpOperation { + ArpOperation::new(u16::from_be_bytes([self.raw()[6], self.raw()[7]])) + } + + pub fn set_operation(&mut self, op: ArpOperation) { + self.raw_mut()[6..8].copy_from_slice(&op.value().to_be_bytes()); + } + + pub fn get_sender_hw_addr(&self) -> MacAddr { + MacAddr::from_octets(self.raw()[8..14].try_into().unwrap()) + } + + pub fn set_sender_hw_addr(&mut self, addr: MacAddr) { + self.raw_mut()[8..14].copy_from_slice(&addr.octets()); + } + + pub fn get_sender_proto_addr(&self) -> Ipv4Addr { + Ipv4Addr::new( + self.raw()[14], + self.raw()[15], + self.raw()[16], + self.raw()[17], + ) + } + + pub fn set_sender_proto_addr(&mut self, addr: Ipv4Addr) { + self.raw_mut()[14..18].copy_from_slice(&addr.octets()); + } + + pub fn get_target_hw_addr(&self) -> MacAddr { + MacAddr::from_octets(self.raw()[18..24].try_into().unwrap()) + } + + pub fn set_target_hw_addr(&mut self, addr: MacAddr) { + self.raw_mut()[18..24].copy_from_slice(&addr.octets()); + } + + pub fn get_target_proto_addr(&self) -> Ipv4Addr { + Ipv4Addr::new( + self.raw()[24], + self.raw()[25], + self.raw()[26], + self.raw()[27], + ) + } + + pub fn set_target_proto_addr(&mut self, addr: Ipv4Addr) { + self.raw_mut()[24..28].copy_from_slice(&addr.octets()); + } +} + #[cfg(test)] mod tests { use super::*; @@ -543,4 +683,31 @@ mod tests { _ => panic!("Expected unknown operation"), } } + + #[test] + fn test_mutable_arp_packet_updates() { + let mut raw = [ + 0x00, 0x01, // Hardware Type: Ethernet + 0x08, 0x00, // Protocol Type: IPv4 + 0x06, // HW Addr Len + 0x04, // Proto Addr Len + 0x00, 0x01, // Operation: Request + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Sender MAC + 192, 168, 1, 1, // Sender IP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Target MAC + 192, 168, 1, 2, // Target IP + 0xde, 0xad, 0xbe, 0xef, // payload + ]; + + let mut packet = MutableArpPacket::new(&mut raw).expect("mutable arp"); + assert_eq!(packet.get_operation(), ArpOperation::Request); + packet.set_operation(ArpOperation::Reply); + packet.set_sender_proto_addr(Ipv4Addr::new(10, 0, 0, 1)); + packet.payload_mut()[0] = 0xaa; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.operation, ArpOperation::Reply); + assert_eq!(frozen.header.sender_proto_addr, Ipv4Addr::new(10, 0, 0, 1)); + assert_eq!(frozen.payload[0], 0xaa); + } } diff --git a/nex-packet/src/dhcp.rs b/nex-packet/src/dhcp.rs index 35b93cd..3c634cf 100644 --- a/nex-packet/src/dhcp.rs +++ b/nex-packet/src/dhcp.rs @@ -2,7 +2,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use nex_core::mac::MacAddr; use std::net::Ipv4Addr; -use crate::packet::Packet; +use crate::packet::{GenericMutablePacket, Packet}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -406,9 +406,13 @@ impl Packet for DhcpPacket { } } +/// Represents a mutable DHCP packet. +pub type MutableDhcpPacket<'a> = GenericMutablePacket<'a, DhcpPacket>; + #[cfg(test)] mod tests { use super::*; + use crate::packet::MutablePacket; use nex_core::mac::MacAddr; #[test] @@ -447,4 +451,20 @@ mod tests { let rebuilt = packet.to_bytes(); assert_eq!(rebuilt, raw); } + + #[test] + fn test_mutable_dhcp_packet_alias() { + let mut raw = [0u8; DHCP_MIN_PACKET_SIZE + 4]; + raw[0] = DhcpOperation::Request.value(); + raw[1] = DhcpHardwareType::Ethernet.value(); + raw[2] = 6; // hardware length + + let mut packet = ::new(&mut raw).expect("mutable dhcp"); + packet.header_mut()[0] = DhcpOperation::Reply.value(); + packet.payload_mut()[0] = 0xaa; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.op, DhcpOperation::Reply); + assert_eq!(frozen.payload[0], 0xaa); + } } diff --git a/nex-packet/src/dns.rs b/nex-packet/src/dns.rs index d52e604..33d214f 100644 --- a/nex-packet/src/dns.rs +++ b/nex-packet/src/dns.rs @@ -1,4 +1,4 @@ -use crate::packet::Packet; +use crate::packet::{GenericMutablePacket, Packet}; use bytes::{BufMut, Bytes, BytesMut}; use core::str; use nex_core::bitfield::{u1, u16be, u32be}; @@ -1208,9 +1208,13 @@ impl std::fmt::Display for DnsName { } } +/// Represents a mutable DNS packet. +pub type MutableDnsPacket<'a> = GenericMutablePacket<'a, DnsPacket>; + #[cfg(test)] mod tests { use super::*; + use crate::packet::MutablePacket; #[test] fn test_dns_query() { @@ -1294,4 +1298,19 @@ mod tests { assert_eq!(packet.responses[0].data_len, 4); assert_eq!(packet.responses[0].data, vec![192, 168, 122, 189]); } + + #[test] + fn test_mutable_dns_packet_header_edit() { + let mut raw = [0u8; 16]; + raw[1] = 0x01; // id + + let mut packet = ::new(&mut raw).expect("mutable dns"); + packet.header_mut()[0] = 0x12; + packet.header_mut()[1] = 0x34; + packet.payload_mut()[0] = 0xaa; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.id, 0x1234); + assert_eq!(frozen.payload[0], 0xaa); + } } diff --git a/nex-packet/src/ethernet.rs b/nex-packet/src/ethernet.rs index 58f15e3..08a157f 100644 --- a/nex-packet/src/ethernet.rs +++ b/nex-packet/src/ethernet.rs @@ -7,7 +7,7 @@ use nex_core::mac::MacAddr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::packet::Packet; +use crate::packet::{MutablePacket, Packet}; /// Represents the Ethernet header length. pub const ETHERNET_HEADER_LEN: usize = 14; @@ -284,11 +284,97 @@ impl fmt::Display for EthernetPacket { } } +/// Represents a mutable Ethernet packet. +pub struct MutableEthernetPacket<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableEthernetPacket<'a> { + type Packet = EthernetPacket; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < ETHERNET_HEADER_LEN { + None + } else { + Some(Self { buffer }) + } + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..ETHERNET_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(ETHERNET_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + &self.packet()[ETHERNET_HEADER_LEN..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (_, payload) = (&mut *self.buffer).split_at_mut(ETHERNET_HEADER_LEN); + payload + } +} + +impl<'a> MutableEthernetPacket<'a> { + /// Create a mutable packet without performing size checks. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + /// Retrieve the destination MAC address. + pub fn get_destination(&self) -> MacAddr { + MacAddr::from_octets(self.header()[0..MAC_ADDR_LEN].try_into().unwrap()) + } + + /// Update the destination MAC address. + pub fn set_destination(&mut self, addr: MacAddr) { + self.header_mut()[0..MAC_ADDR_LEN].copy_from_slice(&addr.octets()); + } + + /// Retrieve the source MAC address. + pub fn get_source(&self) -> MacAddr { + MacAddr::from_octets( + self.header()[MAC_ADDR_LEN..2 * MAC_ADDR_LEN] + .try_into() + .unwrap(), + ) + } + + /// Update the source MAC address. + pub fn set_source(&mut self, addr: MacAddr) { + self.header_mut()[MAC_ADDR_LEN..2 * MAC_ADDR_LEN].copy_from_slice(&addr.octets()); + } + + /// Retrieve the EtherType. + pub fn get_ethertype(&self) -> EtherType { + EtherType::new(u16::from_be_bytes([self.header()[12], self.header()[13]])) + } + + /// Update the EtherType. + pub fn set_ethertype(&mut self, ty: EtherType) { + let bytes = ty.value().to_be_bytes(); + self.header_mut()[12..14].copy_from_slice(&bytes); + } +} + #[cfg(test)] mod tests { use super::*; use bytes::Bytes; use nex_core::mac::MacAddr; + use std::net::Ipv4Addr; #[test] fn test_ethernet_parse_basic() { @@ -361,4 +447,42 @@ mod tests { _ => panic!("Expected unknown EtherType"), } } + + #[test] + fn test_mutable_chaining_updates_in_place() { + let mut raw = [ + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // dst + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, // src + 0x08, 0x00, // IPv4 EtherType + 0x45, 0x00, 0x00, 0x1c, // IPv4 header start (20 bytes header + 8 bytes payload) + 0x1c, 0x46, 0x40, 0x00, 0x40, 0x11, 0x00, 0x00, // rest of IPv4 header + 0xc0, 0xa8, 0x00, 0x01, // src IP + 0xc0, 0xa8, 0x00, 0xc7, // dst IP + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, // payload + ]; + + let mut ethernet = MutableEthernetPacket::new(&mut raw).expect("mutable ethernet"); + assert_eq!(ethernet.get_ethertype(), EtherType::Ipv4); + + use crate::ipv4::MutableIpv4Packet; + + { + let mut ipv4 = MutableIpv4Packet::new(ethernet.payload_mut()).expect("mutable ipv4"); + ipv4.set_ttl(99); + ipv4.set_source(Ipv4Addr::new(10, 0, 0, 1)); + ipv4.payload_mut()[0] = 0xaa; + } + + { + let packet_view = ethernet.packet(); + assert_eq!(packet_view[22], 99); + assert_eq!(&packet_view[26..30], &[10, 0, 0, 1]); + assert_eq!(packet_view[34], 0xaa); + } + + drop(ethernet); + assert_eq!(raw[22], 99); + assert_eq!(&raw[26..30], &[10, 0, 0, 1]); + assert_eq!(raw[34], 0xaa); + } } diff --git a/nex-packet/src/flowcontrol.rs b/nex-packet/src/flowcontrol.rs index 1df265a..bb6b7a1 100644 --- a/nex-packet/src/flowcontrol.rs +++ b/nex-packet/src/flowcontrol.rs @@ -4,7 +4,7 @@ use core::fmt; use bytes::{Buf, BufMut, Bytes}; use nex_core::bitfield::u16be; -use crate::packet::Packet; +use crate::packet::{GenericMutablePacket, Packet}; /// Represents the opcode field in an Ethernet Flow Control packet. /// @@ -117,9 +117,13 @@ impl Packet for FlowControlPacket { } } +/// Represents a mutable Ethernet Flow Control packet. +pub type MutableFlowControlPacket<'a> = GenericMutablePacket<'a, FlowControlPacket>; + #[cfg(test)] mod tests { use super::*; + use crate::packet::MutablePacket; #[test] fn flowcontrol_pause_test() { @@ -134,4 +138,23 @@ mod tests { assert_eq!(fc_packet.quanta, 0x1234); assert_eq!(fc_packet.to_bytes(), packet); } + + #[test] + fn flowcontrol_mutable_packet() { + let mut raw = [ + 0x00, 0x01, // Opcode: Pause + 0x12, 0x34, // Quanta: 0x1234 + 0xaa, 0xbb, + ]; + + let mut packet = ::new(&mut raw) + .expect("mutable flowcontrol"); + packet.header_mut()[0] = 0x00; + packet.header_mut()[1] = 0x02; + packet.payload_mut()[0] = 0xff; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.command, FlowControlOpcode::Unknown(2)); + assert_eq!(frozen.payload[0], 0xff); + } } diff --git a/nex-packet/src/gre.rs b/nex-packet/src/gre.rs index 356a39e..0bab76a 100644 --- a/nex-packet/src/gre.rs +++ b/nex-packet/src/gre.rs @@ -1,6 +1,6 @@ //! GRE Packet abstraction. -use crate::packet::Packet; +use crate::packet::{GenericMutablePacket, Packet}; use bytes::{Buf, Bytes}; use nex_core::bitfield::{u1, u16be, u3, u32be, u5}; @@ -253,9 +253,14 @@ impl GrePacket { } } +/// Represents a mutable GRE packet. +pub type MutableGrePacket<'a> = GenericMutablePacket<'a, GrePacket>; + #[cfg(test)] mod tests { use super::*; + use crate::packet::MutablePacket; + #[test] fn gre_packet_test() { let packet = Bytes::from_static(&[ @@ -285,4 +290,22 @@ mod tests { assert_eq!(&gre_packet.to_bytes(), &packet); } + + #[test] + fn test_mutable_gre_packet_alias() { + let mut raw = [ + 0x00, 0x00, // flags + 0x08, 0x00, // protocol type + 0xaa, 0xbb, + ]; + + let mut packet = ::new(&mut raw).expect("mutable gre"); + packet.header_mut()[2] = 0x86; + packet.header_mut()[3] = 0xdd; // IPv6 protocol + packet.payload_mut()[0] = 0xff; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.protocol_type, 0x86dd); + assert_eq!(frozen.payload[0], 0xff); + } } diff --git a/nex-packet/src/icmp.rs b/nex-packet/src/icmp.rs index 8f178da..ae2154d 100644 --- a/nex-packet/src/icmp.rs +++ b/nex-packet/src/icmp.rs @@ -1,7 +1,9 @@ //! An ICMP packet abstraction. use crate::ipv4::IPV4_HEADER_LEN; -use crate::{ethernet::ETHERNET_HEADER_LEN, packet::Packet}; - +use crate::{ + ethernet::ETHERNET_HEADER_LEN, + packet::{GenericMutablePacket, Packet}, +}; use bytes::{BufMut, Bytes, BytesMut}; use nex_core::bitfield::u16be; #[cfg(feature = "serde")] @@ -244,6 +246,9 @@ impl IcmpPacket { } } +/// Represents a mutable ICMP packet. +pub type MutableIcmpPacket<'a> = GenericMutablePacket<'a, IcmpPacket>; + /// Calculates a checksum of an ICMP packet. pub fn checksum(packet: &IcmpPacket) -> u16be { use crate::util; @@ -522,6 +527,7 @@ pub mod time_exceeded { #[cfg(test)] mod tests { use super::*; + use crate::packet::MutablePacket; #[test] fn test_echo_request_from_bytes() { @@ -631,4 +637,21 @@ mod tests { assert_eq!(exceeded.unused, unused); assert_eq!(exceeded.payload, payload); } + + #[test] + fn test_mutable_icmp_packet_alias() { + let mut raw = [ + 8, 0, 0, 0, // type, code, checksum + 0, 1, 0, 1, // identifier, sequence + b'p', b'i', + ]; + + let mut packet = ::new(&mut raw).expect("mutable icmp"); + packet.header_mut()[0] = IcmpType::EchoReply.value(); + packet.payload_mut()[0] = b'x'; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.icmp_type, IcmpType::EchoReply); + assert_eq!(frozen.payload[0], b'x'); + } } diff --git a/nex-packet/src/icmpv6.rs b/nex-packet/src/icmpv6.rs index 6f3c0e9..55142ea 100644 --- a/nex-packet/src/icmpv6.rs +++ b/nex-packet/src/icmpv6.rs @@ -1,7 +1,10 @@ //! An ICMPv6 packet abstraction. use crate::ipv6::IPV6_HEADER_LEN; -use crate::{ethernet::ETHERNET_HEADER_LEN, packet::Packet}; +use crate::{ + ethernet::ETHERNET_HEADER_LEN, + packet::{GenericMutablePacket, Packet}, +}; use std::net::Ipv6Addr; use bytes::Bytes; @@ -290,6 +293,40 @@ impl Packet for Icmpv6Packet { } } +/// Represents a mutable ICMPv6 packet. +pub type MutableIcmpv6Packet<'a> = GenericMutablePacket<'a, Icmpv6Packet>; + +#[cfg(test)] +mod tests { + use super::*; + use crate::packet::MutablePacket; + + #[test] + fn test_mutable_icmpv6_packet_alias() { + let mut raw = [ + Icmpv6Type::EchoRequest.value(), + 0, + 0, + 0, + 0, + 1, + 0, + 1, + b'p', + b'i', + ]; + + let mut packet = + ::new(&mut raw).expect("mutable icmpv6"); + packet.header_mut()[0] = Icmpv6Type::EchoReply.value(); + packet.payload_mut()[0] = b'x'; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.icmpv6_type, Icmpv6Type::EchoReply); + assert_eq!(frozen.payload[0], b'x'); + } +} + /// Calculates a checksum of an ICMPv6 packet. pub fn checksum(packet: &Icmpv6Packet, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16 { use crate::util; diff --git a/nex-packet/src/ipv4.rs b/nex-packet/src/ipv4.rs index a55fad4..ec64165 100644 --- a/nex-packet/src/ipv4.rs +++ b/nex-packet/src/ipv4.rs @@ -1,6 +1,9 @@ //! An IPv4 packet abstraction. -use crate::{ip::IpNextProtocol, packet::Packet}; +use crate::{ + ip::IpNextProtocol, + packet::{MutablePacket, Packet}, +}; use bytes::{BufMut, Bytes, BytesMut}; use nex_core::bitfield::*; use std::net::Ipv4Addr; @@ -414,6 +417,255 @@ impl Ipv4Packet { } } +/// Represents a mutable IPv4 packet. +pub struct MutableIpv4Packet<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableIpv4Packet<'a> { + type Packet = Ipv4Packet; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < IPV4_HEADER_LEN { + return None; + } + + let ihl = (buffer[0] & 0x0F) as usize; + if ihl < 5 { + return None; + } + + let header_len = ihl * IPV4_HEADER_LENGTH_BYTE_UNITS; + if header_len > buffer.len() { + return None; + } + + let total_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize; + if total_len != 0 && total_len < header_len { + return None; + } + + Some(Self { buffer }) + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + let header_len = self.header_len(); + &self.packet()[..header_len] + } + + fn header_mut(&mut self) -> &mut [u8] { + let header_len = self.header_len(); + let (header, _) = (&mut *self.buffer).split_at_mut(header_len); + header + } + + fn payload(&self) -> &[u8] { + let start = self.header_len(); + let end = start + self.payload_len(); + &self.packet()[start..end] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let header_len = self.header_len(); + let payload_len = self.payload_len(); + let (_, payload) = (&mut *self.buffer).split_at_mut(header_len); + &mut payload[..payload_len] + } +} + +impl<'a> MutableIpv4Packet<'a> { + /// Create a mutable packet without validating the header fields. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + /// Returns the header length in bytes. + pub fn header_len(&self) -> usize { + let ihl = (self.raw()[0] & 0x0F) as usize; + let header_len = ihl * IPV4_HEADER_LENGTH_BYTE_UNITS; + header_len.max(IPV4_HEADER_LEN).min(self.raw().len()) + } + + /// Returns the payload length based on the total length field. + pub fn payload_len(&self) -> usize { + let total = self.total_len(); + total.saturating_sub(self.header_len()) + } + + /// Returns the effective total length of the packet. + pub fn total_len(&self) -> usize { + let total = u16::from_be_bytes([self.raw()[2], self.raw()[3]]) as usize; + if total == 0 { + self.raw().len() + } else { + total.min(self.raw().len()) + } + } + + /// Retrieve the version field. + pub fn get_version(&self) -> u8 { + self.raw()[0] >> 4 + } + + /// Update the version field. + pub fn set_version(&mut self, version: u8) { + let buffer = self.raw_mut(); + buffer[0] = (buffer[0] & 0x0F) | ((version & 0x0F) << 4); + } + + /// Retrieve the header length in 32-bit words. + pub fn get_header_length(&self) -> u8 { + self.raw()[0] & 0x0F + } + + /// Update the header length in 32-bit words. + pub fn set_header_length(&mut self, ihl: u8) { + let buffer = self.raw_mut(); + buffer[0] = (buffer[0] & 0xF0) | (ihl & 0x0F); + } + + /// Retrieve the DSCP field. + pub fn get_dscp(&self) -> u8 { + self.raw()[1] >> 2 + } + + /// Update the DSCP field. + pub fn set_dscp(&mut self, dscp: u8) { + let buffer = self.raw_mut(); + buffer[1] = (buffer[1] & 0x03) | ((dscp & 0x3F) << 2); + } + + /// Retrieve the ECN field. + pub fn get_ecn(&self) -> u8 { + self.raw()[1] & 0x03 + } + + /// Update the ECN field. + pub fn set_ecn(&mut self, ecn: u8) { + let buffer = self.raw_mut(); + buffer[1] = (buffer[1] & 0xFC) | (ecn & 0x03); + } + + /// Retrieve the total length field. + pub fn get_total_length(&self) -> u16 { + u16::from_be_bytes([self.raw()[2], self.raw()[3]]) + } + + /// Update the total length field. + pub fn set_total_length(&mut self, len: u16) { + self.raw_mut()[2..4].copy_from_slice(&len.to_be_bytes()); + } + + /// Retrieve the identification field. + pub fn get_identification(&self) -> u16 { + u16::from_be_bytes([self.raw()[4], self.raw()[5]]) + } + + /// Update the identification field. + pub fn set_identification(&mut self, id: u16) { + self.raw_mut()[4..6].copy_from_slice(&id.to_be_bytes()); + } + + /// Retrieve the flags field. + pub fn get_flags(&self) -> u8 { + (self.raw()[6] & 0xE0) >> 5 + } + + /// Update the flags field. + pub fn set_flags(&mut self, flags: u8) { + let buffer = self.raw_mut(); + buffer[6] = (buffer[6] & 0x1F) | ((flags & 0x07) << 5); + } + + /// Retrieve the fragment offset field. + pub fn get_fragment_offset(&self) -> u16 { + u16::from_be_bytes([self.raw()[6], self.raw()[7]]) & 0x1FFF + } + + /// Update the fragment offset field. + pub fn set_fragment_offset(&mut self, offset: u16) { + let buffer = self.raw_mut(); + let combined = (u16::from_be_bytes([buffer[6], buffer[7]]) & 0xE000) | (offset & 0x1FFF); + buffer[6..8].copy_from_slice(&combined.to_be_bytes()); + } + + /// Retrieve the TTL field. + pub fn get_ttl(&self) -> u8 { + self.raw()[8] + } + + /// Update the TTL field. + pub fn set_ttl(&mut self, ttl: u8) { + self.raw_mut()[8] = ttl; + } + + /// Retrieve the next-level protocol field. + pub fn get_next_level_protocol(&self) -> IpNextProtocol { + IpNextProtocol::new(self.raw()[9]) + } + + /// Update the next-level protocol field. + pub fn set_next_level_protocol(&mut self, proto: IpNextProtocol) { + self.raw_mut()[9] = proto.value(); + } + + /// Retrieve the checksum field. + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes([self.raw()[10], self.raw()[11]]) + } + + /// Update the checksum field. + pub fn set_checksum(&mut self, checksum: u16) { + self.raw_mut()[10..12].copy_from_slice(&checksum.to_be_bytes()); + } + + /// Retrieve the source address. + pub fn get_source(&self) -> Ipv4Addr { + Ipv4Addr::new( + self.raw()[12], + self.raw()[13], + self.raw()[14], + self.raw()[15], + ) + } + + /// Update the source address. + pub fn set_source(&mut self, addr: Ipv4Addr) { + self.raw_mut()[12..16].copy_from_slice(&addr.octets()); + } + + /// Retrieve the destination address. + pub fn get_destination(&self) -> Ipv4Addr { + Ipv4Addr::new( + self.raw()[16], + self.raw()[17], + self.raw()[18], + self.raw()[19], + ) + } + + /// Update the destination address. + pub fn set_destination(&mut self, addr: Ipv4Addr) { + self.raw_mut()[16..20].copy_from_slice(&addr.octets()); + } +} + /// Calculates a checksum of an IPv4 packet header. /// The checksum field of the packet is regarded as zeros during the calculation. pub fn checksum(packet: &Ipv4Packet) -> u16be { @@ -580,4 +832,42 @@ mod tests { raw_copy[11] = (computed & 0xff) as u8; assert_eq!(&packet.to_bytes()[..], &raw_copy[..]); } + + #[test] + fn test_mutable_ipv4_packet_updates() { + let mut raw = [ + 0x45, 0x00, 0x00, 0x1c, // Version + IHL, DSCP/ECN, Total Length + 0x1c, 0x46, 0x40, 0x00, // Identification, Flags/Fragment offset + 0x40, 0x06, 0x00, 0x00, // TTL, Protocol, Header checksum + 0xc0, 0xa8, 0x00, 0x01, // Source + 0xc0, 0xa8, 0x00, 0xc7, // Destination + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, // Payload + ]; + + let mut packet = MutableIpv4Packet::new(&mut raw).expect("mutable ipv4"); + assert_eq!(packet.get_version(), 4); + assert_eq!(packet.get_ttl(), 0x40); + + packet.set_ttl(128); + packet.set_destination(Ipv4Addr::new(192, 0, 2, 1)); + packet.payload_mut()[0] = 0x11; + + { + let packet_view = packet.packet(); + assert_eq!(packet_view[8], 128); + assert_eq!(&packet_view[16..20], &[192, 0, 2, 1]); + assert_eq!(packet_view[20], 0x11); + } + + let frozen = packet.freeze().expect("freeze mutable packet"); + drop(packet); + + assert_eq!(raw[8], 128); + assert_eq!(&raw[16..20], &[192, 0, 2, 1]); + assert_eq!(raw[20], 0x11); + + assert_eq!(frozen.header.ttl, 128); + assert_eq!(frozen.header.destination, Ipv4Addr::new(192, 0, 2, 1)); + assert_eq!(frozen.payload[0], 0x11); + } } diff --git a/nex-packet/src/ipv6.rs b/nex-packet/src/ipv6.rs index ae23372..568c600 100644 --- a/nex-packet/src/ipv6.rs +++ b/nex-packet/src/ipv6.rs @@ -1,5 +1,5 @@ use crate::ip::IpNextProtocol; -use crate::packet::Packet; +use crate::packet::{MutablePacket, Packet}; use bytes::{BufMut, Bytes, BytesMut}; use std::net::Ipv6Addr; @@ -299,6 +299,142 @@ impl Ipv6Packet { } } +/// Represents a mutable IPv6 packet. +pub struct MutableIpv6Packet<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableIpv6Packet<'a> { + type Packet = Ipv6Packet; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < IPV6_HEADER_LEN { + None + } else { + Some(Self { buffer }) + } + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..IPV6_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(IPV6_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + &self.packet()[IPV6_HEADER_LEN..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (_, payload) = (&mut *self.buffer).split_at_mut(IPV6_HEADER_LEN); + payload + } +} + +impl<'a> MutableIpv6Packet<'a> { + /// Create a new packet without checking length. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + pub fn payload_len(&self) -> usize { + self.raw().len().saturating_sub(IPV6_HEADER_LEN) + } + + pub fn get_version(&self) -> u8 { + self.raw()[0] >> 4 + } + + pub fn set_version(&mut self, version: u8) { + let buf = self.raw_mut(); + buf[0] = (buf[0] & 0x0F) | ((version & 0x0F) << 4); + } + + pub fn get_traffic_class(&self) -> u8 { + ((self.raw()[0] & 0x0F) << 4) | (self.raw()[1] >> 4) + } + + pub fn set_traffic_class(&mut self, class: u8) { + let buf = self.raw_mut(); + buf[0] = (buf[0] & 0xF0) | ((class >> 4) & 0x0F); + buf[1] = (buf[1] & 0x0F) | ((class & 0x0F) << 4); + } + + pub fn get_flow_label(&self) -> u32 { + let buf = self.raw(); + let high = (buf[1] as u32 & 0x0F) << 16; + let mid = (buf[2] as u32) << 8; + let low = buf[3] as u32; + high | mid | low + } + + pub fn set_flow_label(&mut self, label: u32) { + let buf = self.raw_mut(); + buf[1] = (buf[1] & 0xF0) | (((label >> 16) as u8) & 0x0F); + buf[2] = (label >> 8) as u8; + buf[3] = label as u8; + } + + pub fn get_payload_length(&self) -> u16 { + u16::from_be_bytes([self.raw()[4], self.raw()[5]]) + } + + pub fn set_payload_length(&mut self, length: u16) { + self.raw_mut()[4..6].copy_from_slice(&length.to_be_bytes()); + } + + pub fn get_next_header(&self) -> IpNextProtocol { + IpNextProtocol::new(self.raw()[6]) + } + + pub fn set_next_header(&mut self, proto: IpNextProtocol) { + self.raw_mut()[6] = proto.value(); + } + + pub fn get_hop_limit(&self) -> u8 { + self.raw()[7] + } + + pub fn set_hop_limit(&mut self, value: u8) { + self.raw_mut()[7] = value; + } + + pub fn get_source(&self) -> Ipv6Addr { + Ipv6Addr::from(<[u8; 16]>::try_from(&self.raw()[8..24]).unwrap()) + } + + pub fn set_source(&mut self, addr: Ipv6Addr) { + self.raw_mut()[8..24].copy_from_slice(&addr.octets()); + } + + pub fn get_destination(&self) -> Ipv6Addr { + Ipv6Addr::from(<[u8; 16]>::try_from(&self.raw()[24..40]).unwrap()) + } + + pub fn set_destination(&mut self, addr: Ipv6Addr) { + self.raw_mut()[24..40].copy_from_slice(&addr.octets()); + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum ExtensionHeaderType { HopByHop, @@ -581,4 +717,33 @@ mod tests { }; assert_eq!(raw_unknown.kind(), ExtensionHeaderType::Unknown(250)); } + + #[test] + fn test_mutable_ipv6_packet_mutations() { + let mut raw = [ + 0x60, 0x00, 0x00, 0x00, // version, traffic class, flow label + 0x00, 0x04, // payload length + 0x11, // next header (UDP) + 0x40, // hop limit + // source + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // destination + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // payload + 0xde, 0xad, 0xbe, 0xef, + ]; + + let mut packet = MutableIpv6Packet::new(&mut raw).expect("mutable ipv6"); + assert_eq!(packet.get_version(), 6); + packet.set_hop_limit(0x7f); + packet.set_next_header(IpNextProtocol::Tcp); + packet.set_flow_label(0x12345); + packet.set_source(Ipv6Addr::LOCALHOST); + packet.payload_mut()[0] = 0xaa; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.hop_limit, 0x7f); + assert_eq!(frozen.header.next_header, IpNextProtocol::Tcp); + assert_eq!(frozen.header.flow_label, 0x12345); + assert_eq!(frozen.header.source, Ipv6Addr::LOCALHOST); + assert_eq!(frozen.payload[0], 0xaa); + } } diff --git a/nex-packet/src/packet.rs b/nex-packet/src/packet.rs index 889313f..665aa39 100644 --- a/nex-packet/src/packet.rs +++ b/nex-packet/src/packet.rs @@ -1,4 +1,5 @@ use bytes::{Bytes, BytesMut}; +use std::marker::PhantomData; /// Represents a generic network packet. pub trait Packet: Sized { @@ -47,3 +48,107 @@ pub trait Packet: Sized { fn into_parts(self) -> (Self::Header, Bytes); } + +/// Represents a mutable network packet that can be parsed and modified in place. +/// +/// Types implementing this trait work on top of the same backing buffer and allow +/// layered packet parsing to be chained without additional allocations. +pub trait MutablePacket<'a>: Sized { + /// The immutable packet type associated with this mutable view. + type Packet: Packet; + + /// Construct a mutable packet from the provided buffer. + fn new(buffer: &'a mut [u8]) -> Option; + + /// Get a shared view over the entire packet buffer. + fn packet(&self) -> &[u8]; + + /// Get a mutable view over the entire packet buffer. + fn packet_mut(&mut self) -> &mut [u8]; + + /// Get the serialized header bytes of the packet. + fn header(&self) -> &[u8]; + + /// Get a mutable view over the serialized header bytes of the packet. + fn header_mut(&mut self) -> &mut [u8]; + + /// Get the payload bytes of the packet. + fn payload(&self) -> &[u8]; + + /// Get a mutable view over the payload bytes of the packet. + fn payload_mut(&mut self) -> &mut [u8]; + + /// Convert the mutable packet into its immutable counterpart. + fn freeze(&self) -> Option { + Self::Packet::from_buf(self.packet()) + } +} + +/// A generic mutable packet wrapper that validates using the immutable packet +/// parser and exposes the raw buffer for in-place mutation. +pub struct GenericMutablePacket<'a, P: Packet> { + buffer: &'a mut [u8], + _marker: PhantomData

, +} + +impl<'a, P: Packet> MutablePacket<'a> for GenericMutablePacket<'a, P> { + type Packet = P; + + fn new(buffer: &'a mut [u8]) -> Option { + P::from_buf(buffer)?; + Some(Self { + buffer, + _marker: PhantomData, + }) + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + let (header_len, _) = self.lengths(); + &self.packet()[..header_len] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header_len, _) = self.lengths(); + let (header, _) = (&mut *self.buffer).split_at_mut(header_len); + header + } + + fn payload(&self) -> &[u8] { + let (header_len, payload_len) = self.lengths(); + &self.packet()[header_len..header_len + payload_len] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (header_len, payload_len) = self.lengths(); + let (_, payload) = (&mut *self.buffer).split_at_mut(header_len); + &mut payload[..payload_len] + } +} + +impl<'a, P: Packet> GenericMutablePacket<'a, P> { + /// Construct a mutable packet without running additional validation. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { + buffer, + _marker: PhantomData, + } + } + + fn lengths(&self) -> (usize, usize) { + if let Some(packet) = P::from_buf(self.packet()) { + let header_len = packet.header_len(); + let payload_len = packet.payload_len(); + (header_len, payload_len) + } else { + (self.buffer.len(), 0) + } + } +} diff --git a/nex-packet/src/tcp.rs b/nex-packet/src/tcp.rs index bcb034a..0550c4d 100644 --- a/nex-packet/src/tcp.rs +++ b/nex-packet/src/tcp.rs @@ -1,7 +1,7 @@ //! A TCP packet abstraction. use crate::ip::IpNextProtocol; -use crate::packet::Packet; +use crate::packet::{MutablePacket, Packet}; use crate::util::{self, Octets}; use std::net::Ipv6Addr; @@ -665,6 +665,182 @@ impl TcpPacket { } } +/// Represents a mutable TCP packet. +pub struct MutableTcpPacket<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableTcpPacket<'a> { + type Packet = TcpPacket; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < TCP_HEADER_LEN { + return None; + } + + let data_offset = buffer[12] >> 4; + if data_offset < TCP_MIN_DATA_OFFSET { + return None; + } + + let header_len = (data_offset as usize) * 4; + if header_len > buffer.len() { + return None; + } + + Some(Self { buffer }) + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + let len = self.header_len(); + &self.packet()[..len] + } + + fn header_mut(&mut self) -> &mut [u8] { + let len = self.header_len(); + let (header, _) = (&mut *self.buffer).split_at_mut(len); + header + } + + fn payload(&self) -> &[u8] { + let len = self.header_len(); + &self.packet()[len..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let len = self.header_len(); + let (_, payload) = (&mut *self.buffer).split_at_mut(len); + payload + } +} + +impl<'a> MutableTcpPacket<'a> { + /// Create a packet without validating the header fields. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + /// Returns the header length in bytes. + pub fn header_len(&self) -> usize { + let offset = (self.raw()[12] >> 4).max(TCP_MIN_DATA_OFFSET); + let len = (offset as usize) * 4; + len.min(self.raw().len()) + } + + /// Returns the payload length of the packet. + pub fn payload_len(&self) -> usize { + self.raw().len().saturating_sub(self.header_len()) + } + + pub fn get_source(&self) -> u16 { + u16::from_be_bytes([self.raw()[0], self.raw()[1]]) + } + + pub fn set_source(&mut self, value: u16) { + self.raw_mut()[0..2].copy_from_slice(&value.to_be_bytes()); + } + + pub fn get_destination(&self) -> u16 { + u16::from_be_bytes([self.raw()[2], self.raw()[3]]) + } + + pub fn set_destination(&mut self, value: u16) { + self.raw_mut()[2..4].copy_from_slice(&value.to_be_bytes()); + } + + pub fn get_sequence(&self) -> u32 { + u32::from_be_bytes([self.raw()[4], self.raw()[5], self.raw()[6], self.raw()[7]]) + } + + pub fn set_sequence(&mut self, value: u32) { + self.raw_mut()[4..8].copy_from_slice(&value.to_be_bytes()); + } + + pub fn get_acknowledgement(&self) -> u32 { + u32::from_be_bytes([self.raw()[8], self.raw()[9], self.raw()[10], self.raw()[11]]) + } + + pub fn set_acknowledgement(&mut self, value: u32) { + self.raw_mut()[8..12].copy_from_slice(&value.to_be_bytes()); + } + + pub fn get_data_offset(&self) -> u8 { + self.raw()[12] >> 4 + } + + pub fn set_data_offset(&mut self, offset: u8) { + let buf = self.raw_mut(); + buf[12] = (buf[12] & 0x0F) | ((offset & 0x0F) << 4); + } + + pub fn get_reserved(&self) -> u8 { + self.raw()[12] & 0x0F + } + + pub fn set_reserved(&mut self, value: u8) { + let buf = self.raw_mut(); + buf[12] = (buf[12] & 0xF0) | (value & 0x0F); + } + + pub fn get_flags(&self) -> u8 { + self.raw()[13] + } + + pub fn set_flags(&mut self, flags: u8) { + self.raw_mut()[13] = flags; + } + + pub fn get_window(&self) -> u16 { + u16::from_be_bytes([self.raw()[14], self.raw()[15]]) + } + + pub fn set_window(&mut self, value: u16) { + self.raw_mut()[14..16].copy_from_slice(&value.to_be_bytes()); + } + + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes([self.raw()[16], self.raw()[17]]) + } + + pub fn set_checksum(&mut self, value: u16) { + self.raw_mut()[16..18].copy_from_slice(&value.to_be_bytes()); + } + + pub fn get_urgent_ptr(&self) -> u16 { + u16::from_be_bytes([self.raw()[18], self.raw()[19]]) + } + + pub fn set_urgent_ptr(&mut self, value: u16) { + self.raw_mut()[18..20].copy_from_slice(&value.to_be_bytes()); + } + + pub fn options(&self) -> &[u8] { + let len = self.header_len(); + &self.raw()[TCP_HEADER_LEN..len] + } + + pub fn options_mut(&mut self) -> &mut [u8] { + let len = self.header_len(); + &mut self.raw_mut()[TCP_HEADER_LEN..len] + } +} + pub fn checksum(packet: &TcpPacket, source: &IpAddr, destination: &IpAddr) -> u16 { match (source, destination) { (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst), @@ -810,4 +986,35 @@ mod tests { (0x2c57cda5, 0x02a04192) ); } + + #[test] + fn test_mutable_tcp_packet_round_trip() { + let mut raw = [ + 0x00, 0x50, // source + 0x01, 0xbb, // destination + 0x00, 0x00, 0x00, 0x01, // seq + 0x00, 0x00, 0x00, 0x00, // ack + 0x50, // data offset/reserved + 0x18, // flags + 0x40, 0x00, // window + 0x12, 0x34, // checksum + 0x00, 0x00, // urgent pointer + b'h', b'e', b'l', b'l', b'o', + ]; + + let mut packet = MutableTcpPacket::new(&mut raw).expect("mutable tcp"); + assert_eq!(packet.get_source(), 80); + packet.set_source(1234); + packet.set_destination(4321); + packet.set_sequence(0xfeedbeef); + packet.set_flags(0x11); + packet.payload_mut()[0] = b'H'; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.source, 1234); + assert_eq!(frozen.header.destination, 4321); + assert_eq!(frozen.header.sequence, 0xfeedbeef); + assert_eq!(frozen.header.flags, 0x11); + assert_eq!(frozen.payload[0], b'H'); + } } diff --git a/nex-packet/src/udp.rs b/nex-packet/src/udp.rs index 8826772..da59aff 100644 --- a/nex-packet/src/udp.rs +++ b/nex-packet/src/udp.rs @@ -1,7 +1,7 @@ //! A UDP packet abstraction. use crate::ip::IpNextProtocol; -use crate::packet::Packet; +use crate::packet::{MutablePacket, Packet}; use crate::util; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -106,6 +106,124 @@ impl Packet for UdpPacket { } } +/// Represents a mutable UDP packet. +pub struct MutableUdpPacket<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> { + type Packet = UdpPacket; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < UDP_HEADER_LEN { + return None; + } + + let length = u16::from_be_bytes([buffer[4], buffer[5]]); + if length != 0 { + if length < UDP_HEADER_LEN as u16 { + return None; + } + + if length as usize > buffer.len() { + return None; + } + } + + Some(Self { buffer }) + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..UDP_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(UDP_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + let length = self.total_len(); + &self.packet()[UDP_HEADER_LEN..length] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let total_len = self.total_len(); + let (_, payload) = (&mut *self.buffer).split_at_mut(UDP_HEADER_LEN); + &mut payload[..total_len.saturating_sub(UDP_HEADER_LEN)] + } +} + +impl<'a> MutableUdpPacket<'a> { + /// Create a new packet without validating length fields. + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + /// Returns the total length derived from the UDP length field. + pub fn total_len(&self) -> usize { + let field = u16::from_be_bytes([self.raw()[4], self.raw()[5]]); + if field == 0 { + self.raw().len() + } else { + field as usize + } + } + + /// Returns the payload length. + pub fn payload_len(&self) -> usize { + self.total_len().saturating_sub(UDP_HEADER_LEN) + } + + pub fn get_source(&self) -> u16 { + u16::from_be_bytes([self.raw()[0], self.raw()[1]]) + } + + pub fn set_source(&mut self, port: u16) { + self.raw_mut()[0..2].copy_from_slice(&port.to_be_bytes()); + } + + pub fn get_destination(&self) -> u16 { + u16::from_be_bytes([self.raw()[2], self.raw()[3]]) + } + + pub fn set_destination(&mut self, port: u16) { + self.raw_mut()[2..4].copy_from_slice(&port.to_be_bytes()); + } + + pub fn get_length(&self) -> u16 { + u16::from_be_bytes([self.raw()[4], self.raw()[5]]) + } + + pub fn set_length(&mut self, length: u16) { + self.raw_mut()[4..6].copy_from_slice(&length.to_be_bytes()); + } + + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes([self.raw()[6], self.raw()[7]]) + } + + pub fn set_checksum(&mut self, checksum: u16) { + self.raw_mut()[6..8].copy_from_slice(&checksum.to_be_bytes()); + } +} + pub fn checksum(packet: &UdpPacket, source: &IpAddr, destination: &IpAddr) -> u16 { match (source, destination) { (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst), @@ -216,4 +334,28 @@ mod tests { assert_eq!(packet.payload(), payload); assert_eq!(packet.header_len(), UDP_HEADER_LEN); } + #[test] + fn test_mutable_udp_packet_updates_in_place() { + let mut raw = [ + 0x12, 0x34, // source + 0xab, 0xcd, // destination + 0x00, 0x0c, // length + 0x55, 0xaa, // checksum + b'd', b'a', b't', b'a', // payload + 0, 0, // trailing capacity + ]; + + let mut packet = MutableUdpPacket::new(&mut raw).expect("mutable udp"); + assert_eq!(packet.get_source(), 0x1234); + packet.set_source(0x4321); + packet.set_destination(0x0102); + packet.payload_mut()[0] = b'x'; + packet.set_checksum(0xffff); + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.source, 0x4321); + assert_eq!(frozen.header.destination, 0x0102); + assert_eq!(frozen.header.checksum, 0xffff); + assert_eq!(&raw[UDP_HEADER_LEN], &b'x'); + } } diff --git a/nex-packet/src/vlan.rs b/nex-packet/src/vlan.rs index 840ef8d..d8a018e 100644 --- a/nex-packet/src/vlan.rs +++ b/nex-packet/src/vlan.rs @@ -1,6 +1,9 @@ //! A VLAN (802.1Q) packet abstraction. //! -use crate::{ethernet::EtherType, packet::Packet}; +use crate::{ + ethernet::EtherType, + packet::{MutablePacket, Packet}, +}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use nex_core::bitfield::{u1, u12be}; #[cfg(feature = "serde")] @@ -168,6 +171,102 @@ impl Packet for VlanPacket { } } +/// Represents a mutable VLAN packet. +pub struct MutableVlanPacket<'a> { + buffer: &'a mut [u8], +} + +impl<'a> MutablePacket<'a> for MutableVlanPacket<'a> { + type Packet = VlanPacket; + + fn new(buffer: &'a mut [u8]) -> Option { + if buffer.len() < VLAN_HEADER_LEN { + None + } else { + Some(Self { buffer }) + } + } + + fn packet(&self) -> &[u8] { + &*self.buffer + } + + fn packet_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + fn header(&self) -> &[u8] { + &self.packet()[..VLAN_HEADER_LEN] + } + + fn header_mut(&mut self) -> &mut [u8] { + let (header, _) = (&mut *self.buffer).split_at_mut(VLAN_HEADER_LEN); + header + } + + fn payload(&self) -> &[u8] { + &self.packet()[VLAN_HEADER_LEN..] + } + + fn payload_mut(&mut self) -> &mut [u8] { + let (_, payload) = (&mut *self.buffer).split_at_mut(VLAN_HEADER_LEN); + payload + } +} + +impl<'a> MutableVlanPacket<'a> { + pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + fn raw(&self) -> &[u8] { + &*self.buffer + } + + fn raw_mut(&mut self) -> &mut [u8] { + &mut *self.buffer + } + + pub fn get_priority_code_point(&self) -> ClassOfService { + let first = self.raw()[0]; + ClassOfService::new(first >> 5) + } + + pub fn set_priority_code_point(&mut self, class: ClassOfService) { + let buf = self.raw_mut(); + buf[0] = (buf[0] & 0x1F) | ((class.value() & 0x07) << 5); + } + + pub fn get_drop_eligible_id(&self) -> u1 { + ((self.raw()[0] >> 4) & 0x01) as u1 + } + + pub fn set_drop_eligible_id(&mut self, dei: u1) { + let buf = self.raw_mut(); + buf[0] = (buf[0] & !(1 << 4)) | (((dei & 0x1) as u8) << 4); + } + + pub fn get_vlan_id(&self) -> u16 { + let first = self.raw()[0] as u16 & 0x0F; + let second = self.raw()[1] as u16; + (first << 8) | second + } + + pub fn set_vlan_id(&mut self, id: u16) { + let buf = self.raw_mut(); + buf[0] = (buf[0] & 0xF0) | ((id >> 8) as u8 & 0x0F); + buf[1] = id as u8; + } + + pub fn get_ethertype(&self) -> EtherType { + EtherType::new(u16::from_be_bytes([self.raw()[2], self.raw()[3]])) + } + + pub fn set_ethertype(&mut self, ty: EtherType) { + self.raw_mut()[2..4].copy_from_slice(&ty.value().to_be_bytes()); + } +} + #[cfg(test)] mod tests { use super::*; @@ -189,6 +288,7 @@ mod tests { assert_eq!(packet.payload, Bytes::from_static(b"xyz")); assert_eq!(packet.to_bytes(), raw); } + #[test] fn test_vlan_parse_2() { let raw = Bytes::from_static(&[ @@ -206,4 +306,26 @@ mod tests { assert_eq!(packet.payload, Bytes::from_static(b"xyz")); assert_eq!(packet.to_bytes(), raw); } + + #[test] + fn test_mutable_vlan_packet_changes() { + let mut raw = [ + 0x00, 0x01, // TCI + 0x08, 0x00, // EtherType: IPv4 + b'a', b'b', + ]; + + let mut packet = MutableVlanPacket::new(&mut raw).expect("mutable vlan"); + assert_eq!(packet.get_vlan_id(), 1); + packet.set_priority_code_point(ClassOfService::VO); + packet.set_vlan_id(0x0abc); + packet.set_ethertype(EtherType::Ipv6); + packet.payload_mut()[0] = b'z'; + + let frozen = packet.freeze().expect("freeze"); + assert_eq!(frozen.header.priority_code_point, ClassOfService::VO); + assert_eq!(frozen.header.vlan_id, 0x0abc); + assert_eq!(frozen.header.ethertype, EtherType::Ipv6); + assert_eq!(frozen.payload[0], b'z'); + } } diff --git a/nex-packet/src/vxlan.rs b/nex-packet/src/vxlan.rs index a6a5fb9..56e41e8 100644 --- a/nex-packet/src/vxlan.rs +++ b/nex-packet/src/vxlan.rs @@ -2,7 +2,7 @@ use bytes::{Buf, Bytes}; use nex_core::bitfield::{self, u24be}; -use crate::packet::Packet; +use crate::packet::{GenericMutablePacket, Packet}; /// Virtual eXtensible Local Area Network (VXLAN) /// @@ -14,7 +14,7 @@ use crate::packet::Packet; /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ /// | VXLAN Network Identifier (VNI) | Reserved | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -pub struct Vxlan { +pub struct VxlanPacket { pub flags: u8, pub reserved1: u24be, pub vni: u24be, @@ -22,7 +22,7 @@ pub struct Vxlan { pub payload: Bytes, } -impl Packet for Vxlan { +impl Packet for VxlanPacket { type Header = (); fn from_buf(mut bytes: &[u8]) -> Option { @@ -108,14 +108,36 @@ impl Packet for Vxlan { } } -#[test] -fn vxlan_packet_test() { - let packet = Bytes::from_static(&[ - 0x08, // I flag - 0x00, 0x00, 0x00, // Reserved - 0x12, 0x34, 0x56, // VNI - 0x00, // Reserved - ]); - let vxlan_packet = Vxlan::from_bytes(packet.clone()).unwrap(); - assert_eq!(vxlan_packet.to_bytes(), packet); +/// Represents a mutable VXLAN packet. +pub type MutableVxlanPacket<'a> = GenericMutablePacket<'a, VxlanPacket>; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn vxlan_packet_test() { + let packet = Bytes::from_static(&[ + 0x08, // I flag + 0x00, 0x00, 0x00, // Reserved + 0x12, 0x34, 0x56, // VNI + 0x00, // Reserved + ]); + let vxlan_packet = VxlanPacket::from_bytes(packet.clone()).unwrap(); + assert_eq!(vxlan_packet.to_bytes(), packet); + } + + #[test] + fn mutable_vxlan_packet_test() { + let mut raw = [0x08, 0x00, 0x00, 0x00, 0x12, 0x34, 0x56, 0x00, 0xaa]; + + use crate::packet::MutablePacket; + let mut packet = + ::new(&mut raw).expect("mutable vxlan"); + packet.header_mut()[0] = 0x0c; + packet.payload_mut()[0] = 0xff; + + let frozen = MutablePacket::freeze(&packet).expect("freeze"); + assert_eq!(frozen.payload[0], 0xff); + } } diff --git a/nex/Cargo.toml b/nex/Cargo.toml index 960be4e..10ab2c8 100644 --- a/nex/Cargo.toml +++ b/nex/Cargo.toml @@ -90,3 +90,7 @@ path = "../examples/async_datalink.rs" [[example]] name = "async_dump" path = "../examples/async_dump.rs" + +[[example]] +name = "mutable_chaining" +path = "../examples/mutable_chaining.rs"