From fe1a3de4c8c86f2dd7deec5cb3c76bd03ea17c50 Mon Sep 17 00:00:00 2001 From: Tyler Neely Date: Wed, 6 May 2020 10:15:51 +0200 Subject: [PATCH] Use unsafe pointers for values that the kernel may update behind our backs --- Cargo.toml | 2 +- src/completion.rs | 11 ++++++-- src/io_uring/cq.rs | 46 +++++++++++++++++++------------- src/io_uring/sq.rs | 62 ++++++++++++++++++++++++------------------- src/io_uring/uring.rs | 37 ++++++++++++++++---------- src/lib.rs | 4 +-- 6 files changed, 96 insertions(+), 66 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f08989f..f29035d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rio" -version = "0.9.2" +version = "0.9.3" authors = ["Tyler Neely "] edition = "2018" description = "GPL-3.0 nice bindings for io_uring. MIT/Apache-2.0 license is available for spacejam's github sponsors." diff --git a/src/completion.rs b/src/completion.rs index 2049f29..51b8a0f 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -29,6 +29,13 @@ impl Default for CompletionState { } /// A Future value which may or may not be filled +/// +/// # Safety +/// +/// Never call `std::mem::forget` on this value. +/// It can lead to a use-after-free bug. The fact +/// that `std::mem::forget` is not marked unsafe +/// is a bug in the Rust standard library. #[derive(Debug)] pub struct Completion<'a, C: FromCqe> { lifetime: PhantomData<&'a C>, @@ -97,9 +104,9 @@ impl<'a, C: FromCqe> Completion<'a, C> { inner = self.cv.wait(inner).unwrap(); } - return inner.item.take().map(|io_result| { + inner.item.take().map(|io_result| { io_result.map(FromCqe::from_cqe) - }); + }) } } diff --git a/src/io_uring/cq.rs b/src/io_uring/cq.rs index 0a61a25..606798b 100644 --- a/src/io_uring/cq.rs +++ b/src/io_uring/cq.rs @@ -1,3 +1,5 @@ +#![allow(unsafe_code)] + use std::slice::from_raw_parts_mut; use super::*; @@ -5,11 +7,11 @@ use super::*; /// Consumes uring completions. #[derive(Debug)] pub struct Cq { - khead: &'static AtomicU32, - ktail: &'static AtomicU32, - kring_mask: &'static u32, - koverflow: &'static AtomicU32, - cqes: &'static mut [io_uring_cqe], + khead: *mut AtomicU32, + ktail: *mut AtomicU32, + kring_mask: *mut u32, + koverflow: *mut AtomicU32, + cqes: *mut [io_uring_cqe], ticket_queue: Arc, in_flight: Arc, ring_ptr: *const libc::c_void, @@ -54,18 +56,18 @@ impl Cq { Cq { ring_ptr: cq_ring_ptr, ring_mmap_sz: cq_ring_mmap_sz, - khead: &*(cq_ring_ptr + khead: cq_ring_ptr .add(params.cq_off.head as usize) - as *const AtomicU32), - ktail: &*(cq_ring_ptr + as *mut AtomicU32, + ktail: cq_ring_ptr .add(params.cq_off.tail as usize) - as *const AtomicU32), - kring_mask: &*(cq_ring_ptr + as *mut AtomicU32, + kring_mask: cq_ring_ptr .add(params.cq_off.ring_mask as usize) - as *const u32), - koverflow: &*(cq_ring_ptr + as *mut u32, + koverflow: cq_ring_ptr .add(params.cq_off.overflow as usize) - as *const AtomicU32), + as *mut AtomicU32, cqes: from_raw_parts_mut( cq_ring_ptr .add(params.cq_off.cqes as usize) @@ -95,7 +97,12 @@ impl Cq { if let Err(e) = block_for_cqe(ring_fd) { panic!("error in cqe reaper: {:?}", e); } else { - assert_eq!(self.koverflow.load(Relaxed), 0); + assert_eq!( + unsafe { + (*self.koverflow).load(Relaxed) + }, + 0 + ); if self.reap_ready_cqes().is_none() { // poison pill detected, time to shut down return; @@ -106,8 +113,9 @@ impl Cq { fn reap_ready_cqes(&mut self) -> Option { let _ = Measure::new(&M.reap_ready); - let mut head = self.khead.load(Acquire); - let tail = self.ktail.load(Acquire); + let mut head = + unsafe { &*self.khead }.load(Acquire); + let tail = unsafe { &*self.ktail }.load(Acquire); let count = tail - head; // hack to get around mutable usage in loop @@ -119,8 +127,8 @@ impl Cq { while head != tail { let cq = cq_opt.take().unwrap(); - let index = head & cq.kring_mask; - let cqe = &cq.cqes[index as usize]; + let index = head & unsafe { *cq.kring_mask }; + let cqe = &unsafe { &*cq.cqes }[index as usize]; // we detect a poison pill by seeing if // the user_data is really big, which it @@ -148,7 +156,7 @@ impl Cq { completion_filler.fill(result); - cq.khead.fetch_add(1, Release); + unsafe { &*cq.khead }.fetch_add(1, Release); cq_opt = Some(cq); head += 1; diff --git a/src/io_uring/sq.rs b/src/io_uring/sq.rs index 6a494e1..141d921 100644 --- a/src/io_uring/sq.rs +++ b/src/io_uring/sq.rs @@ -1,3 +1,5 @@ +#![allow(unsafe_code)] + use std::slice::from_raw_parts_mut; use super::*; @@ -5,11 +7,11 @@ use super::*; /// Sprays uring submissions. #[derive(Debug)] pub(crate) struct Sq { - khead: &'static AtomicU32, - ktail: &'static AtomicU32, - kring_mask: &'static u32, - kflags: &'static AtomicU32, - kdropped: &'static AtomicU32, + khead: *mut AtomicU32, + ktail: *mut AtomicU32, + kring_mask: *mut u32, + kflags: *mut AtomicU32, + kdropped: *mut AtomicU32, array: &'static mut [AtomicU32], sqes: &'static mut [io_uring_sqe], sqe_head: u32, @@ -64,7 +66,6 @@ impl Sq { IORING_OFF_SQES, )? as _; - #[allow(unsafe_code)] Ok(unsafe { Sq { sqe_head: 0, @@ -72,31 +73,31 @@ impl Sq { ring_ptr: sq_ring_ptr, ring_mmap_sz: sq_ring_mmap_sz, sqes_mmap_sz, - sqes: from_raw_parts_mut( - sqes_ptr, - params.sq_entries as usize, - ), - khead: &*(sq_ring_ptr + khead: sq_ring_ptr .add(params.sq_off.head as usize) - as *const AtomicU32), - ktail: &*(sq_ring_ptr + as *mut AtomicU32, + ktail: sq_ring_ptr .add(params.sq_off.tail as usize) - as *const AtomicU32), - kring_mask: &*(sq_ring_ptr + as *mut AtomicU32, + kring_mask: sq_ring_ptr .add(params.sq_off.ring_mask as usize) - as *const u32), - kflags: &*(sq_ring_ptr + as *mut u32, + kflags: sq_ring_ptr .add(params.sq_off.flags as usize) - as *const AtomicU32), - kdropped: &*(sq_ring_ptr + as *mut AtomicU32, + kdropped: sq_ring_ptr .add(params.sq_off.dropped as usize) - as *const AtomicU32), + as *mut AtomicU32, array: from_raw_parts_mut( sq_ring_ptr .add(params.sq_off.array as usize) as _, params.sq_entries as usize, ), + sqes: from_raw_parts_mut( + sqes_ptr, + params.sq_entries as usize, + ), } }) } @@ -113,11 +114,12 @@ impl Sq { self.sqe_head } else { // polling mode - self.khead.load(Acquire) + unsafe { &*self.khead }.load(Acquire) }; if next - head <= self.sqes.len() as u32 { - let idx = self.sqe_tail & self.kring_mask; + let idx = + self.sqe_tail & unsafe { *self.kring_mask }; let sqe = &mut self.sqes[idx as usize]; self.sqe_tail = next; @@ -129,10 +131,11 @@ impl Sq { // sets sq.array to point to current sq.sqe_head fn flush(&mut self) -> u32 { - let mask: u32 = *self.kring_mask; + let mask: u32 = unsafe { *self.kring_mask }; let to_submit = self.sqe_tail - self.sqe_head; - let mut ktail = self.ktail.load(Acquire); + let mut ktail = + unsafe { &*self.ktail }.load(Acquire); for _ in 0..to_submit { let index = ktail & mask; @@ -142,7 +145,9 @@ impl Sq { self.sqe_head += 1; } - let swapped = self.ktail.swap(ktail, Release); + let swapped = + unsafe { &*self.ktail }.swap(ktail, Release); + assert_eq!(swapped, ktail - to_submit); to_submit @@ -180,7 +185,7 @@ impl Sq { to_submit -= u32::try_from(ret).unwrap(); } flushed - } else if self.kflags.load(Acquire) + } else if unsafe { &*self.kflags }.load(Acquire) & IORING_SQ_NEED_WAKEUP != 0 { @@ -205,7 +210,10 @@ impl Sq { } else { 0 }; - assert_eq!(self.kdropped.load(Relaxed), 0); + assert_eq!( + unsafe { &*self.kdropped }.load(Relaxed), + 0 + ); u64::from(submitted) } } diff --git a/src/io_uring/uring.rs b/src/io_uring/uring.rs index a480c5d..3e53a74 100644 --- a/src/io_uring/uring.rs +++ b/src/io_uring/uring.rs @@ -1,5 +1,4 @@ use super::*; -use std::os::unix::io::{AsRawFd, IntoRawFd}; /// Nice bindings for the shiny new linux IO system #[derive(Debug, Clone)] @@ -71,8 +70,8 @@ impl Uring { ring_fd, sq: Mutex::new(sq), config, - in_flight: in_flight, - ticket_queue: ticket_queue, + in_flight, + ticket_queue, loaded: 0.into(), submitted: 0.into(), } @@ -148,7 +147,9 @@ impl Uring { address: &std::net::SocketAddr, order: Ordering, ) -> Completion<'a, ()> - where F: AsRawFd { + where + F: AsRawFd, + { let (addr, len) = addr2raw(address); self.with_sqe(None, false, |sqe| { sqe.prep_rw( @@ -729,15 +730,23 @@ impl Uring { } } -fn addr2raw(addr: &std::net::SocketAddr) -> (*const libc::sockaddr, libc::socklen_t) { - match *addr { - std::net::SocketAddr::V4(ref a) => { - let b: *const std::net::SocketAddrV4 = a; - (b as *const _, std::mem::size_of_val(a) as libc::socklen_t) - } - std::net::SocketAddr::V6(ref a) => { - let b: *const std::net::SocketAddrV6 = a; - (b as *const _, std::mem::size_of_val(a) as libc::socklen_t) +fn addr2raw( + addr: &std::net::SocketAddr, +) -> (*const libc::sockaddr, libc::socklen_t) { + match *addr { + std::net::SocketAddr::V4(ref a) => { + let b: *const std::net::SocketAddrV4 = a; + ( + b as *const _, + std::mem::size_of_val(a) as libc::socklen_t, + ) + } + std::net::SocketAddr::V6(ref a) => { + let b: *const std::net::SocketAddrV6 = a; + ( + b as *const _, + std::mem::size_of_val(a) as libc::socklen_t, + ) + } } - } } diff --git a/src/lib.rs b/src/lib.rs index 5c71fd3..de2e570 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -307,7 +307,5 @@ impl FromCqe for usize { } impl FromCqe for () { - fn from_cqe(_: io_uring::io_uring_cqe) -> () { - () - } + fn from_cqe(_: io_uring::io_uring_cqe) {} }