Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/objc2/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
* Allow cloning `Id<AnyObject>`.
* **BREAKING**: Restrict message sending to `&mut` references to things that
implement `IsAllowedMutable`.
* Disallow the ability to use non-`Self`-like types as the receiver in
`declare_class!`.

### Removed
* **BREAKING**: Removed `ProtocolType` implementation for `NSObject`.
Expand Down
249 changes: 247 additions & 2 deletions crates/objc2/src/__macro_helpers/declare_class.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
#[cfg(all(debug_assertions, feature = "verify"))]
use alloc::vec::Vec;
use core::marker::PhantomData;
#[cfg(all(debug_assertions, feature = "verify"))]
use std::collections::HashSet;

#[cfg(all(debug_assertions, feature = "verify"))]
use crate::runtime::{AnyProtocol, MethodDescription};

use objc2_encode::Encoding;

use crate::declare::{ClassBuilder, IvarType};
use crate::encode::Encode;
use crate::rc::{Allocated, Id};
use crate::runtime::{AnyClass, MethodImplementation, Sel};
use crate::runtime::{AnyObject, MessageReceiver};
use crate::{ClassType, Message};
use crate::{ClassType, Message, ProtocolType};

use super::{CopyOrMutCopy, Init, MaybeUnwrap, New, Other};
use crate::mutability;
Expand Down Expand Up @@ -52,7 +63,7 @@ where
// restrict it here to only be when the selector is `init`.
//
// Additionally, the receiver and return type must have the same generic
// generic parameter `T`.
// parameter `T`.
impl<Ret, T> MessageRecieveId<Allocated<T>, Ret> for Init
where
T: Message,
Expand Down Expand Up @@ -190,3 +201,237 @@ where
{
// Noop
}

#[derive(Debug)]
pub struct ClassBuilderHelper<T: ?Sized> {
builder: ClassBuilder,
p: PhantomData<T>,
}

#[track_caller]
fn failed_declaring_class(name: &str) -> ! {
panic!("could not create new class {name}. Perhaps a class with that name already exists?")
}

impl<T: ?Sized + ClassType> ClassBuilderHelper<T> {
#[inline]
#[track_caller]
#[allow(clippy::new_without_default)]
pub fn new() -> Self
where
T::Super: ClassType,
{
let builder = match ClassBuilder::new(T::NAME, <T::Super as ClassType>::class()) {
Some(builder) => builder,
None => failed_declaring_class(T::NAME),
};

Self {
builder,
p: PhantomData,
}
}

#[inline]
pub fn add_protocol_methods<P>(&mut self) -> ClassProtocolMethodsBuilder<'_, T>
where
P: ?Sized + ProtocolType,
{
let protocol = P::protocol();

if let Some(protocol) = protocol {
self.builder.add_protocol(protocol);
}

#[cfg(all(debug_assertions, feature = "verify"))]
{
ClassProtocolMethodsBuilder {
builder: self,
protocol,
required_instance_methods: protocol
.map(|p| p.method_descriptions(true))
.unwrap_or_default(),
optional_instance_methods: protocol
.map(|p| p.method_descriptions(false))
.unwrap_or_default(),
registered_instance_methods: HashSet::new(),
required_class_methods: protocol
.map(|p| p.class_method_descriptions(true))
.unwrap_or_default(),
optional_class_methods: protocol
.map(|p| p.class_method_descriptions(false))
.unwrap_or_default(),
registered_class_methods: HashSet::new(),
}
}

#[cfg(not(all(debug_assertions, feature = "verify")))]
{
ClassProtocolMethodsBuilder { builder: self }
}
}

// Addition: This restricts to callee `T`
#[inline]
pub unsafe fn add_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = T>,
{
// SAFETY: Checked by caller
unsafe { self.builder.add_method(sel, func) }
}

#[inline]
pub unsafe fn add_class_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = AnyClass>,
{
// SAFETY: Checked by caller
unsafe { self.builder.add_class_method(sel, func) }
}

#[inline]
pub fn add_static_ivar<I: IvarType>(&mut self) {
self.builder.add_static_ivar::<I>()
}

#[inline]
pub fn register(self) -> &'static AnyClass {
self.builder.register()
}
}

/// Helper for ensuring that:
/// - Only methods on the protocol are overriden.
/// - TODO: The methods have the correct signature.
/// - All required methods are overridden.
#[derive(Debug)]
pub struct ClassProtocolMethodsBuilder<'a, T: ?Sized> {
builder: &'a mut ClassBuilderHelper<T>,
#[cfg(all(debug_assertions, feature = "verify"))]
protocol: Option<&'static AnyProtocol>,
#[cfg(all(debug_assertions, feature = "verify"))]
required_instance_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
optional_instance_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
registered_instance_methods: HashSet<Sel>,
#[cfg(all(debug_assertions, feature = "verify"))]
required_class_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
optional_class_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
registered_class_methods: HashSet<Sel>,
}

impl<T: ?Sized + ClassType> ClassProtocolMethodsBuilder<'_, T> {
// Addition: This restricts to callee `T`
#[inline]
pub unsafe fn add_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = T>,
{
#[cfg(all(debug_assertions, feature = "verify"))]
if let Some(protocol) = self.protocol {
let _types = self
.required_instance_methods
.iter()
.chain(&self.optional_instance_methods)
.find(|desc| desc.sel == sel)
.map(|desc| desc.types)
.unwrap_or_else(|| {
panic!(
"failed overriding protocol method -[{protocol} {sel}]: method not found"
)
});
}

// SAFETY: Checked by caller
unsafe { self.builder.add_method(sel, func) };

#[cfg(all(debug_assertions, feature = "verify"))]
if !self.registered_instance_methods.insert(sel) {
unreachable!("already added")
}
}

#[inline]
pub unsafe fn add_class_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = AnyClass>,
{
#[cfg(all(debug_assertions, feature = "verify"))]
if let Some(protocol) = self.protocol {
let _types = self
.required_class_methods
.iter()
.chain(&self.optional_class_methods)
.find(|desc| desc.sel == sel)
.map(|desc| desc.types)
.unwrap_or_else(|| {
panic!(
"failed overriding protocol method +[{protocol} {sel}]: method not found"
)
});
}

// SAFETY: Checked by caller
unsafe { self.builder.add_class_method(sel, func) };

#[cfg(all(debug_assertions, feature = "verify"))]
if !self.registered_class_methods.insert(sel) {
unreachable!("already added")
}
}

#[cfg(all(debug_assertions, feature = "verify"))]
pub fn finish(self) {
let superclass = self.builder.builder.superclass();

if let Some(protocol) = self.protocol {
for desc in &self.required_instance_methods {
if self.registered_instance_methods.contains(&desc.sel) {
continue;
}

// TODO: Don't do this when `NS_PROTOCOL_REQUIRES_EXPLICIT_IMPLEMENTATION`
if superclass
.and_then(|superclass| superclass.instance_method(desc.sel))
.is_some()
{
continue;
}

panic!(
"must implement required protocol method -[{protocol} {}]",
desc.sel
)
}
}

if let Some(protocol) = self.protocol {
for desc in &self.required_class_methods {
if self.registered_class_methods.contains(&desc.sel) {
continue;
}

// TODO: Don't do this when `NS_PROTOCOL_REQUIRES_EXPLICIT_IMPLEMENTATION`
if superclass
.and_then(|superclass| superclass.class_method(desc.sel))
.is_some()
{
continue;
}

panic!(
"must implement required protocol method +[{protocol} {}]",
desc.sel
);
}
}
}

#[inline]
#[cfg(not(all(debug_assertions, feature = "verify")))]
pub fn finish(self) {}
}
Loading