diff options
-rw-r--r-- | rust/kernel/pci.rs | 74 | ||||
-rw-r--r-- | samples/rust/rust_pci_driver/mod.rs | 4 |
2 files changed, 27 insertions, 51 deletions
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index c3135403e376..fd59d8cd76cc 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -4,11 +4,9 @@ //! //! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h) -use core::cell::UnsafeCell; use core::marker::PhantomData; use kernel::{ - alloc::flags::*, - bindings, + bindings, driver, error::{from_result, to_result}, prelude::*, }; @@ -32,7 +30,8 @@ pub trait Driver { fn remove(pdev: *mut bindings::pci_dev); } -struct Adapter<T: Driver>(PhantomData<T>); +/// PCI abstraction for registering PCI drivers. +pub struct Adapter<T: Driver>(PhantomData<T>); impl<T> Adapter<T> where @@ -53,62 +52,39 @@ where } } -/// Registration structure for a PCI driver. -/// -/// The existance of an instance of this structure implies that the corresponding PCI driver is -/// currently registered. -pub struct Registration<T: Driver> { - driver: Pin<KBox<UnsafeCell<bindings::pci_driver>>>, - _p: PhantomData<T>, -} - -impl<T> Registration<T> +impl<T> driver::RegistrationOps for Adapter<T> where T: Driver, { - /// Register a new PCI driver from `T: Driver`. - pub fn new(name: &'static CStr, module: &'static ThisModule) -> Result<Self> { - let mut driver = KBox::pin(UnsafeCell::new(bindings::pci_driver::default()), GFP_KERNEL)?; - - // Abuse that `bindings::pci_driver` is `Unpin`. - let inner = driver.get_mut(); - inner.name = name.as_char_ptr(); - inner.probe = Some(Adapter::<T>::probe); - inner.remove = Some(Adapter::<T>::remove); - inner.id_table = T::ID_TABLE; - - // SAFETY: `driver` is a valid `struct pci_driver`; `ThisModule` is equivalent to + type RegType = bindings::pci_driver; + + unsafe fn register( + pdrv: *mut Self::RegType, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + // SAFETY: By the safety requirements of this function `pdrv` is valid; we never move out + // of `pdrv`. + let pdrv = unsafe { &mut *pdrv }; + + pdrv.name = name.as_char_ptr(); + pdrv.probe = Some(Self::probe); + pdrv.remove = Some(Self::remove); + pdrv.id_table = T::ID_TABLE; + + // SAFETY: `pdrv` is a valid `struct pci_driver`; `ThisModule` is equivalent to // C's `THIS_MODULE` and hence valid for `__pci_register_driver`. `name` is passed as `NULL` // terminated C string. // // Returns zero when the driver was registered successfully, a non-zero error code // otherwise, which is handled by `to_result`. to_result(unsafe { - bindings::__pci_register_driver(driver.get(), module.as_ptr(), name.as_char_ptr()) - })?; - - Ok(Self { - driver, - _p: PhantomData::<T>, + bindings::__pci_register_driver(pdrv, module.as_ptr(), name.as_char_ptr()) }) } -} -impl<T> Drop for Registration<T> -where - T: Driver, -{ - fn drop(&mut self) { - // SAFETY: `Module::drop` is only ever called when `self.drv` was registered - // successfully. - unsafe { bindings::pci_unregister_driver(self.driver.get()) }; + unsafe fn unregister(pdrv: *mut Self::RegType) { + // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::pci_unregister_driver(pdrv) } } } - -// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to -// share references to it with multiple threads as nothing can be done. -unsafe impl<T> Sync for Registration<T> where T: Driver {} - -// SAFETY: Both registration and unregistration are implemented in C and safe to be performed from -// any thread, so `Registration` is `Send`. -unsafe impl<T> Send for Registration<T> where T: Driver {} diff --git a/samples/rust/rust_pci_driver/mod.rs b/samples/rust/rust_pci_driver/mod.rs index cd40a167e91c..e4e848d7c59d 100644 --- a/samples/rust/rust_pci_driver/mod.rs +++ b/samples/rust/rust_pci_driver/mod.rs @@ -15,13 +15,13 @@ module! { } struct Module { - _reg: pci::Registration<driver::Driver>, + _reg: kernel::driver::Registration<pci::Adapter<driver::Driver>>, } impl kernel::Module for Module { fn init(name: &'static CStr, module: &'static ThisModule) -> Result<Self> { Ok(Module { - _reg: pci::Registration::new(name, module)?, + _reg: kernel::driver::Registration::new(name, module)?, }) } } |