diff options
-rw-r--r-- | mm/maccess.c | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/mm/maccess.c b/mm/maccess.c index 17b26bceb716..db7cf48d8fed 100644 --- a/mm/maccess.c +++ b/mm/maccess.c @@ -11,6 +11,81 @@ bool __weak probe_kernel_read_allowed(const void *unsafe_src, size_t size) return true; } +#ifdef HAVE_GET_KERNEL_NOFAULT + +#define probe_kernel_read_loop(dst, src, len, type, err_label) \ + while (len >= sizeof(type)) { \ + __get_kernel_nofault(dst, src, type, err_label); \ + dst += sizeof(type); \ + src += sizeof(type); \ + len -= sizeof(type); \ + } + +long probe_kernel_read(void *dst, const void *src, size_t size) +{ + if (!probe_kernel_read_allowed(src, size)) + return -EFAULT; + + pagefault_disable(); + probe_kernel_read_loop(dst, src, size, u64, Efault); + probe_kernel_read_loop(dst, src, size, u32, Efault); + probe_kernel_read_loop(dst, src, size, u16, Efault); + probe_kernel_read_loop(dst, src, size, u8, Efault); + pagefault_enable(); + return 0; +Efault: + pagefault_enable(); + return -EFAULT; +} +EXPORT_SYMBOL_GPL(probe_kernel_read); + +#define probe_kernel_write_loop(dst, src, len, type, err_label) \ + while (len >= sizeof(type)) { \ + __put_kernel_nofault(dst, src, type, err_label); \ + dst += sizeof(type); \ + src += sizeof(type); \ + len -= sizeof(type); \ + } + +long probe_kernel_write(void *dst, const void *src, size_t size) +{ + pagefault_disable(); + probe_kernel_write_loop(dst, src, size, u64, Efault); + probe_kernel_write_loop(dst, src, size, u32, Efault); + probe_kernel_write_loop(dst, src, size, u16, Efault); + probe_kernel_write_loop(dst, src, size, u8, Efault); + pagefault_enable(); + return 0; +Efault: + pagefault_enable(); + return -EFAULT; +} + +long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count) +{ + const void *src = unsafe_addr; + + if (unlikely(count <= 0)) + return 0; + if (!probe_kernel_read_allowed(unsafe_addr, count)) + return -EFAULT; + + pagefault_disable(); + do { + __get_kernel_nofault(dst, src, u8, Efault); + dst++; + src++; + } while (dst[-1] && src - unsafe_addr < count); + pagefault_enable(); + + dst[-1] = '\0'; + return src - unsafe_addr; +Efault: + pagefault_enable(); + dst[-1] = '\0'; + return -EFAULT; +} +#else /* HAVE_GET_KERNEL_NOFAULT */ /** * probe_kernel_read(): safely attempt to read from kernel-space * @dst: pointer to the buffer that shall take the data @@ -113,6 +188,7 @@ long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count) return ret ? -EFAULT : src - unsafe_addr; } +#endif /* HAVE_GET_KERNEL_NOFAULT */ /** * probe_user_read(): safely attempt to read from a user-space location |