summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rust/kernel/pci.rs74
-rw-r--r--samples/rust/rust_pci_driver/mod.rs4
2 files changed, 27 insertions, 51 deletions
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs
index c3135403e376..fd59d8cd76cc 100644
--- a/rust/kernel/pci.rs
+++ b/rust/kernel/pci.rs
@@ -4,11 +4,9 @@
//!
//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h)
-use core::cell::UnsafeCell;
use core::marker::PhantomData;
use kernel::{
- alloc::flags::*,
- bindings,
+ bindings, driver,
error::{from_result, to_result},
prelude::*,
};
@@ -32,7 +30,8 @@ pub trait Driver {
fn remove(pdev: *mut bindings::pci_dev);
}
-struct Adapter<T: Driver>(PhantomData<T>);
+/// PCI abstraction for registering PCI drivers.
+pub struct Adapter<T: Driver>(PhantomData<T>);
impl<T> Adapter<T>
where
@@ -53,62 +52,39 @@ where
}
}
-/// Registration structure for a PCI driver.
-///
-/// The existance of an instance of this structure implies that the corresponding PCI driver is
-/// currently registered.
-pub struct Registration<T: Driver> {
- driver: Pin<KBox<UnsafeCell<bindings::pci_driver>>>,
- _p: PhantomData<T>,
-}
-
-impl<T> Registration<T>
+impl<T> driver::RegistrationOps for Adapter<T>
where
T: Driver,
{
- /// Register a new PCI driver from `T: Driver`.
- pub fn new(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
- let mut driver = KBox::pin(UnsafeCell::new(bindings::pci_driver::default()), GFP_KERNEL)?;
-
- // Abuse that `bindings::pci_driver` is `Unpin`.
- let inner = driver.get_mut();
- inner.name = name.as_char_ptr();
- inner.probe = Some(Adapter::<T>::probe);
- inner.remove = Some(Adapter::<T>::remove);
- inner.id_table = T::ID_TABLE;
-
- // SAFETY: `driver` is a valid `struct pci_driver`; `ThisModule` is equivalent to
+ type RegType = bindings::pci_driver;
+
+ unsafe fn register(
+ pdrv: *mut Self::RegType,
+ name: &'static CStr,
+ module: &'static ThisModule,
+ ) -> Result {
+ // SAFETY: By the safety requirements of this function `pdrv` is valid; we never move out
+ // of `pdrv`.
+ let pdrv = unsafe { &mut *pdrv };
+
+ pdrv.name = name.as_char_ptr();
+ pdrv.probe = Some(Self::probe);
+ pdrv.remove = Some(Self::remove);
+ pdrv.id_table = T::ID_TABLE;
+
+ // SAFETY: `pdrv` is a valid `struct pci_driver`; `ThisModule` is equivalent to
// C's `THIS_MODULE` and hence valid for `__pci_register_driver`. `name` is passed as `NULL`
// terminated C string.
//
// Returns zero when the driver was registered successfully, a non-zero error code
// otherwise, which is handled by `to_result`.
to_result(unsafe {
- bindings::__pci_register_driver(driver.get(), module.as_ptr(), name.as_char_ptr())
- })?;
-
- Ok(Self {
- driver,
- _p: PhantomData::<T>,
+ bindings::__pci_register_driver(pdrv, module.as_ptr(), name.as_char_ptr())
})
}
-}
-impl<T> Drop for Registration<T>
-where
- T: Driver,
-{
- fn drop(&mut self) {
- // SAFETY: `Module::drop` is only ever called when `self.drv` was registered
- // successfully.
- unsafe { bindings::pci_unregister_driver(self.driver.get()) };
+ unsafe fn unregister(pdrv: *mut Self::RegType) {
+ // SAFETY: `pdrv` is guaranteed to be a valid `RegType`.
+ unsafe { bindings::pci_unregister_driver(pdrv) }
}
}
-
-// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to
-// share references to it with multiple threads as nothing can be done.
-unsafe impl<T> Sync for Registration<T> where T: Driver {}
-
-// SAFETY: Both registration and unregistration are implemented in C and safe to be performed from
-// any thread, so `Registration` is `Send`.
-unsafe impl<T> Send for Registration<T> where T: Driver {}
diff --git a/samples/rust/rust_pci_driver/mod.rs b/samples/rust/rust_pci_driver/mod.rs
index cd40a167e91c..e4e848d7c59d 100644
--- a/samples/rust/rust_pci_driver/mod.rs
+++ b/samples/rust/rust_pci_driver/mod.rs
@@ -15,13 +15,13 @@ module! {
}
struct Module {
- _reg: pci::Registration<driver::Driver>,
+ _reg: kernel::driver::Registration<pci::Adapter<driver::Driver>>,
}
impl kernel::Module for Module {
fn init(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
Ok(Module {
- _reg: pci::Registration::new(name, module)?,
+ _reg: kernel::driver::Registration::new(name, module)?,
})
}
}