diff options
Diffstat (limited to 'include/linux/hmm.h')
-rw-r--r-- | include/linux/hmm.h | 310 |
1 files changed, 250 insertions, 60 deletions
diff --git a/include/linux/hmm.h b/include/linux/hmm.h index ad50b7b4f141..51ec27a84668 100644 --- a/include/linux/hmm.h +++ b/include/linux/hmm.h @@ -77,8 +77,34 @@ #include <linux/migrate.h> #include <linux/memremap.h> #include <linux/completion.h> +#include <linux/mmu_notifier.h> -struct hmm; + +/* + * struct hmm - HMM per mm struct + * + * @mm: mm struct this HMM struct is bound to + * @lock: lock protecting ranges list + * @ranges: list of range being snapshotted + * @mirrors: list of mirrors for this mm + * @mmu_notifier: mmu notifier to track updates to CPU page table + * @mirrors_sem: read/write semaphore protecting the mirrors list + * @wq: wait queue for user waiting on a range invalidation + * @notifiers: count of active mmu notifiers + * @dead: is the mm dead ? + */ +struct hmm { + struct mm_struct *mm; + struct kref kref; + struct mutex lock; + struct list_head ranges; + struct list_head mirrors; + struct mmu_notifier mmu_notifier; + struct rw_semaphore mirrors_sem; + wait_queue_head_t wq; + long notifiers; + bool dead; +}; /* * hmm_pfn_flag_e - HMM flag enums @@ -131,6 +157,7 @@ enum hmm_pfn_value_e { /* * struct hmm_range - track invalidation lock on virtual address range * + * @hmm: the core HMM structure this range is active against * @vma: the vm area struct for the range * @list: all range lock are on a list * @start: range virtual start address (inclusive) @@ -138,10 +165,13 @@ enum hmm_pfn_value_e { * @pfns: array of pfns (big enough for the range) * @flags: pfn flags to match device driver page table * @values: pfn value for some special case (none, special, error, ...) + * @default_flags: default flags for the range (write, read, ... see hmm doc) + * @pfn_flags_mask: allows to mask pfn flags so that only default_flags matter * @pfn_shifts: pfn shift value (should be <= PAGE_SHIFT) * @valid: pfns array did not change since it has been fill by an HMM function */ struct hmm_range { + struct hmm *hmm; struct vm_area_struct *vma; struct list_head list; unsigned long start; @@ -149,41 +179,96 @@ struct hmm_range { uint64_t *pfns; const uint64_t *flags; const uint64_t *values; + uint64_t default_flags; + uint64_t pfn_flags_mask; + uint8_t page_shift; uint8_t pfn_shift; bool valid; }; /* - * hmm_pfn_to_page() - return struct page pointed to by a valid HMM pfn - * @range: range use to decode HMM pfn value - * @pfn: HMM pfn value to get corresponding struct page from - * Returns: struct page pointer if pfn is a valid HMM pfn, NULL otherwise + * hmm_range_page_shift() - return the page shift for the range + * @range: range being queried + * Returns: page shift (page size = 1 << page shift) for the range + */ +static inline unsigned hmm_range_page_shift(const struct hmm_range *range) +{ + return range->page_shift; +} + +/* + * hmm_range_page_size() - return the page size for the range + * @range: range being queried + * Returns: page size for the range in bytes + */ +static inline unsigned long hmm_range_page_size(const struct hmm_range *range) +{ + return 1UL << hmm_range_page_shift(range); +} + +/* + * hmm_range_wait_until_valid() - wait for range to be valid + * @range: range affected by invalidation to wait on + * @timeout: time out for wait in ms (ie abort wait after that period of time) + * Returns: true if the range is valid, false otherwise. + */ +static inline bool hmm_range_wait_until_valid(struct hmm_range *range, + unsigned long timeout) +{ + /* Check if mm is dead ? */ + if (range->hmm == NULL || range->hmm->dead || range->hmm->mm == NULL) { + range->valid = false; + return false; + } + if (range->valid) + return true; + wait_event_timeout(range->hmm->wq, range->valid || range->hmm->dead, + msecs_to_jiffies(timeout)); + /* Return current valid status just in case we get lucky */ + return range->valid; +} + +/* + * hmm_range_valid() - test if a range is valid or not + * @range: range + * Returns: true if the range is valid, false otherwise. + */ +static inline bool hmm_range_valid(struct hmm_range *range) +{ + return range->valid; +} + +/* + * hmm_device_entry_to_page() - return struct page pointed to by a device entry + * @range: range use to decode device entry value + * @entry: device entry value to get corresponding struct page from + * Returns: struct page pointer if entry is a valid, NULL otherwise * - * If the HMM pfn is valid (ie valid flag set) then return the struct page - * matching the pfn value stored in the HMM pfn. Otherwise return NULL. + * If the device entry is valid (ie valid flag set) then return the struct page + * matching the entry value. Otherwise return NULL. */ -static inline struct page *hmm_pfn_to_page(const struct hmm_range *range, - uint64_t pfn) +static inline struct page *hmm_device_entry_to_page(const struct hmm_range *range, + uint64_t entry) { - if (pfn == range->values[HMM_PFN_NONE]) + if (entry == range->values[HMM_PFN_NONE]) return NULL; - if (pfn == range->values[HMM_PFN_ERROR]) + if (entry == range->values[HMM_PFN_ERROR]) return NULL; - if (pfn == range->values[HMM_PFN_SPECIAL]) + if (entry == range->values[HMM_PFN_SPECIAL]) return NULL; - if (!(pfn & range->flags[HMM_PFN_VALID])) + if (!(entry & range->flags[HMM_PFN_VALID])) return NULL; - return pfn_to_page(pfn >> range->pfn_shift); + return pfn_to_page(entry >> range->pfn_shift); } /* - * hmm_pfn_to_pfn() - return pfn value store in a HMM pfn - * @range: range use to decode HMM pfn value - * @pfn: HMM pfn value to extract pfn from - * Returns: pfn value if HMM pfn is valid, -1UL otherwise + * hmm_device_entry_to_pfn() - return pfn value store in a device entry + * @range: range use to decode device entry value + * @entry: device entry to extract pfn from + * Returns: pfn value if device entry is valid, -1UL otherwise */ -static inline unsigned long hmm_pfn_to_pfn(const struct hmm_range *range, - uint64_t pfn) +static inline unsigned long +hmm_device_entry_to_pfn(const struct hmm_range *range, uint64_t pfn) { if (pfn == range->values[HMM_PFN_NONE]) return -1UL; @@ -197,31 +282,66 @@ static inline unsigned long hmm_pfn_to_pfn(const struct hmm_range *range, } /* - * hmm_pfn_from_page() - create a valid HMM pfn value from struct page + * hmm_device_entry_from_page() - create a valid device entry for a page * @range: range use to encode HMM pfn value - * @page: struct page pointer for which to create the HMM pfn - * Returns: valid HMM pfn for the page + * @page: page for which to create the device entry + * Returns: valid device entry for the page */ -static inline uint64_t hmm_pfn_from_page(const struct hmm_range *range, - struct page *page) +static inline uint64_t hmm_device_entry_from_page(const struct hmm_range *range, + struct page *page) { return (page_to_pfn(page) << range->pfn_shift) | range->flags[HMM_PFN_VALID]; } /* - * hmm_pfn_from_pfn() - create a valid HMM pfn value from pfn + * hmm_device_entry_from_pfn() - create a valid device entry value from pfn * @range: range use to encode HMM pfn value - * @pfn: pfn value for which to create the HMM pfn - * Returns: valid HMM pfn for the pfn + * @pfn: pfn value for which to create the device entry + * Returns: valid device entry for the pfn */ -static inline uint64_t hmm_pfn_from_pfn(const struct hmm_range *range, - unsigned long pfn) +static inline uint64_t hmm_device_entry_from_pfn(const struct hmm_range *range, + unsigned long pfn) { return (pfn << range->pfn_shift) | range->flags[HMM_PFN_VALID]; } +/* + * Old API: + * hmm_pfn_to_page() + * hmm_pfn_to_pfn() + * hmm_pfn_from_page() + * hmm_pfn_from_pfn() + * + * This are the OLD API please use new API, it is here to avoid cross-tree + * merge painfullness ie we convert things to new API in stages. + */ +static inline struct page *hmm_pfn_to_page(const struct hmm_range *range, + uint64_t pfn) +{ + return hmm_device_entry_to_page(range, pfn); +} + +static inline unsigned long hmm_pfn_to_pfn(const struct hmm_range *range, + uint64_t pfn) +{ + return hmm_device_entry_to_pfn(range, pfn); +} + +static inline uint64_t hmm_pfn_from_page(const struct hmm_range *range, + struct page *page) +{ + return hmm_device_entry_from_page(range, page); +} + +static inline uint64_t hmm_pfn_from_pfn(const struct hmm_range *range, + unsigned long pfn) +{ + return hmm_device_entry_from_pfn(range, pfn); +} + + #if IS_ENABLED(CONFIG_HMM_MIRROR) /* @@ -353,43 +473,113 @@ struct hmm_mirror { int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm); void hmm_mirror_unregister(struct hmm_mirror *mirror); - /* - * To snapshot the CPU page table, call hmm_vma_get_pfns(), then take a device - * driver lock that serializes device page table updates, then call - * hmm_vma_range_done(), to check if the snapshot is still valid. The same - * device driver page table update lock must also be used in the - * hmm_mirror_ops.sync_cpu_device_pagetables() callback, so that CPU page - * table invalidation serializes on it. - * - * YOU MUST CALL hmm_vma_range_done() ONCE AND ONLY ONCE EACH TIME YOU CALL - * hmm_vma_get_pfns() WITHOUT ERROR ! - * - * IF YOU DO NOT FOLLOW THE ABOVE RULE THE SNAPSHOT CONTENT MIGHT BE INVALID ! + * hmm_mirror_mm_is_alive() - test if mm is still alive + * @mirror: the HMM mm mirror for which we want to lock the mmap_sem + * Returns: false if the mm is dead, true otherwise + * + * This is an optimization it will not accurately always return -EINVAL if the + * mm is dead ie there can be false negative (process is being kill but HMM is + * not yet inform of that). It is only intented to be use to optimize out case + * where driver is about to do something time consuming and it would be better + * to skip it if the mm is dead. */ -int hmm_vma_get_pfns(struct hmm_range *range); -bool hmm_vma_range_done(struct hmm_range *range); +static inline bool hmm_mirror_mm_is_alive(struct hmm_mirror *mirror) +{ + struct mm_struct *mm; + + if (!mirror || !mirror->hmm) + return false; + mm = READ_ONCE(mirror->hmm->mm); + if (mirror->hmm->dead || !mm) + return false; + + return true; +} /* - * Fault memory on behalf of device driver. Unlike handle_mm_fault(), this will - * not migrate any device memory back to system memory. The HMM pfn array will - * be updated with the fault result and current snapshot of the CPU page table - * for the range. - * - * The mmap_sem must be taken in read mode before entering and it might be - * dropped by the function if the block argument is false. In that case, the - * function returns -EAGAIN. - * - * Return value does not reflect if the fault was successful for every single - * address or not. Therefore, the caller must to inspect the HMM pfn array to - * determine fault status for each address. - * - * Trying to fault inside an invalid vma will result in -EINVAL. + * Please see Documentation/vm/hmm.rst for how to use the range API. + */ +int hmm_range_register(struct hmm_range *range, + struct mm_struct *mm, + unsigned long start, + unsigned long end, + unsigned page_shift); +void hmm_range_unregister(struct hmm_range *range); +long hmm_range_snapshot(struct hmm_range *range); +long hmm_range_fault(struct hmm_range *range, bool block); +long hmm_range_dma_map(struct hmm_range *range, + struct device *device, + dma_addr_t *daddrs, + bool block); +long hmm_range_dma_unmap(struct hmm_range *range, + struct vm_area_struct *vma, + struct device *device, + dma_addr_t *daddrs, + bool dirty); + +/* + * HMM_RANGE_DEFAULT_TIMEOUT - default timeout (ms) when waiting for a range * - * See the function description in mm/hmm.c for further documentation. + * When waiting for mmu notifiers we need some kind of time out otherwise we + * could potentialy wait for ever, 1000ms ie 1s sounds like a long time to + * wait already. */ -int hmm_vma_fault(struct hmm_range *range, bool block); +#define HMM_RANGE_DEFAULT_TIMEOUT 1000 + +/* This is a temporary helper to avoid merge conflict between trees. */ +static inline bool hmm_vma_range_done(struct hmm_range *range) +{ + bool ret = hmm_range_valid(range); + + hmm_range_unregister(range); + return ret; +} + +/* This is a temporary helper to avoid merge conflict between trees. */ +static inline int hmm_vma_fault(struct hmm_range *range, bool block) +{ + long ret; + + /* + * With the old API the driver must set each individual entries with + * the requested flags (valid, write, ...). So here we set the mask to + * keep intact the entries provided by the driver and zero out the + * default_flags. + */ + range->default_flags = 0; + range->pfn_flags_mask = -1UL; + + ret = hmm_range_register(range, range->vma->vm_mm, + range->start, range->end, + PAGE_SHIFT); + if (ret) + return (int)ret; + + if (!hmm_range_wait_until_valid(range, HMM_RANGE_DEFAULT_TIMEOUT)) { + /* + * The mmap_sem was taken by driver we release it here and + * returns -EAGAIN which correspond to mmap_sem have been + * drop in the old API. + */ + up_read(&range->vma->vm_mm->mmap_sem); + return -EAGAIN; + } + + ret = hmm_range_fault(range, block); + if (ret <= 0) { + if (ret == -EBUSY || !ret) { + /* Same as above drop mmap_sem to match old API. */ + up_read(&range->vma->vm_mm->mmap_sem); + ret = -EBUSY; + } else if (ret == -EAGAIN) + ret = -EBUSY; + hmm_range_unregister(range); + return ret; + } + return 0; +} /* Below are for HMM internal use only! Not to be used by device driver! */ void hmm_mm_destroy(struct mm_struct *mm); |