diff --git a/include/linux/hmm.h b/include/linux/hmm.h index dee2f8953b2e..e5834082de60 100644 --- a/include/linux/hmm.h +++ b/include/linux/hmm.h @@ -181,10 +181,31 @@ struct hmm_range { const uint64_t *values; uint64_t default_flags; uint64_t pfn_flags_mask; + uint8_t page_shift; uint8_t pfn_shift; bool valid; }; +/* + * hmm_range_page_shift() - return the page shift for the range + * @range: range being queried + * Returns: page shift (page size = 1 << page shift) for the range + */ +static inline unsigned hmm_range_page_shift(const struct hmm_range *range) +{ + return range->page_shift; +} + +/* + * hmm_range_page_size() - return the page size for the range + * @range: range being queried + * Returns: page size for the range in bytes + */ +static inline unsigned long hmm_range_page_size(const struct hmm_range *range) +{ + return 1UL << hmm_range_page_shift(range); +} + /* * hmm_range_wait_until_valid() - wait for range to be valid * @range: range affected by invalidation to wait on @@ -424,7 +445,8 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror); int hmm_range_register(struct hmm_range *range, struct mm_struct *mm, unsigned long start, - unsigned long end); + unsigned long end, + unsigned page_shift); void hmm_range_unregister(struct hmm_range *range); long hmm_range_snapshot(struct hmm_range *range); long hmm_range_fault(struct hmm_range *range, bool block); @@ -462,7 +484,8 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block) range->pfn_flags_mask = -1UL; ret = hmm_range_register(range, range->vma->vm_mm, - range->start, range->end); + range->start, range->end, + PAGE_SHIFT); if (ret) return (int)ret; diff --git a/mm/hmm.c b/mm/hmm.c index 0e21d3594ab6..52e40be56dc7 100644 --- a/mm/hmm.c +++ b/mm/hmm.c @@ -391,11 +391,13 @@ static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end, struct hmm_vma_walk *hmm_vma_walk = walk->private; struct hmm_range *range = hmm_vma_walk->range; uint64_t *pfns = range->pfns; - unsigned long i; + unsigned long i, page_size; hmm_vma_walk->last = addr; - i = (addr - range->start) >> PAGE_SHIFT; - for (; addr < end; addr += PAGE_SIZE, i++) { + page_size = hmm_range_page_size(range); + i = (addr - range->start) >> range->page_shift; + + for (; addr < end; addr += page_size, i++) { pfns[i] = range->values[HMM_PFN_NONE]; if (fault || write_fault) { int ret; @@ -707,6 +709,69 @@ again: return 0; } +static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask, + unsigned long start, unsigned long end, + struct mm_walk *walk) +{ +#ifdef CONFIG_HUGETLB_PAGE + unsigned long addr = start, i, pfn, mask, size, pfn_inc; + struct hmm_vma_walk *hmm_vma_walk = walk->private; + struct hmm_range *range = hmm_vma_walk->range; + struct vm_area_struct *vma = walk->vma; + struct hstate *h = hstate_vma(vma); + uint64_t orig_pfn, cpu_flags; + bool fault, write_fault; + spinlock_t *ptl; + pte_t entry; + int ret = 0; + + size = 1UL << huge_page_shift(h); + mask = size - 1; + if (range->page_shift != PAGE_SHIFT) { + /* Make sure we are looking at full page. */ + if (start & mask) + return -EINVAL; + if (end < (start + size)) + return -EINVAL; + pfn_inc = size >> PAGE_SHIFT; + } else { + pfn_inc = 1; + size = PAGE_SIZE; + } + + + ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte); + entry = huge_ptep_get(pte); + + i = (start - range->start) >> range->page_shift; + orig_pfn = range->pfns[i]; + range->pfns[i] = range->values[HMM_PFN_NONE]; + cpu_flags = pte_to_hmm_pfn_flags(range, entry); + fault = write_fault = false; + hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags, + &fault, &write_fault); + if (fault || write_fault) { + ret = -ENOENT; + goto unlock; + } + + pfn = pte_pfn(entry) + ((start & mask) >> range->page_shift); + for (; addr < end; addr += size, i++, pfn += pfn_inc) + range->pfns[i] = hmm_pfn_from_pfn(range, pfn) | cpu_flags; + hmm_vma_walk->last = end; + +unlock: + spin_unlock(ptl); + + if (ret == -ENOENT) + return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk); + + return ret; +#else /* CONFIG_HUGETLB_PAGE */ + return -EINVAL; +#endif +} + static void hmm_pfns_clear(struct hmm_range *range, uint64_t *pfns, unsigned long addr, @@ -730,6 +795,7 @@ static void hmm_pfns_special(struct hmm_range *range) * @mm: the mm struct for the range of virtual address * @start: start virtual address (inclusive) * @end: end virtual address (exclusive) + * @page_shift: expect page shift for the range * Returns 0 on success, -EFAULT if the address space is no longer valid * * Track updates to the CPU page table see include/linux/hmm.h @@ -737,16 +803,20 @@ static void hmm_pfns_special(struct hmm_range *range) int hmm_range_register(struct hmm_range *range, struct mm_struct *mm, unsigned long start, - unsigned long end) + unsigned long end, + unsigned page_shift) { - range->start = start & PAGE_MASK; - range->end = end & PAGE_MASK; + unsigned long mask = ((1UL << page_shift) - 1UL); + range->valid = false; range->hmm = NULL; - if (range->start >= range->end) + if ((start & mask) || (end & mask)) + return -EINVAL; + if (start >= end) return -EINVAL; + range->page_shift = page_shift; range->start = start; range->end = end; @@ -816,6 +886,7 @@ EXPORT_SYMBOL(hmm_range_unregister); */ long hmm_range_snapshot(struct hmm_range *range) { + const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP; unsigned long start = range->start, end; struct hmm_vma_walk hmm_vma_walk; struct hmm *hmm = range->hmm; @@ -832,15 +903,26 @@ long hmm_range_snapshot(struct hmm_range *range) return -EAGAIN; vma = find_vma(hmm->mm, start); - if (vma == NULL || (vma->vm_flags & VM_SPECIAL)) + if (vma == NULL || (vma->vm_flags & device_vma)) return -EFAULT; - /* FIXME support hugetlb fs/dax */ - if (is_vm_hugetlb_page(vma) || vma_is_dax(vma)) { + /* FIXME support dax */ + if (vma_is_dax(vma)) { hmm_pfns_special(range); return -EINVAL; } + if (is_vm_hugetlb_page(vma)) { + struct hstate *h = hstate_vma(vma); + + if (huge_page_shift(h) != range->page_shift && + range->page_shift != PAGE_SHIFT) + return -EINVAL; + } else { + if (range->page_shift != PAGE_SHIFT) + return -EINVAL; + } + if (!(vma->vm_flags & VM_READ)) { /* * If vma do not allow read access, then assume that it @@ -866,6 +948,7 @@ long hmm_range_snapshot(struct hmm_range *range) mm_walk.hugetlb_entry = NULL; mm_walk.pmd_entry = hmm_vma_walk_pmd; mm_walk.pte_hole = hmm_vma_walk_hole; + mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry; walk_page_range(start, end, &mm_walk); start = end; @@ -884,7 +967,7 @@ EXPORT_SYMBOL(hmm_range_snapshot); * then one of the following values may be returned: * * -EINVAL invalid arguments or mm or virtual address are in an - * invalid vma (ie either hugetlbfs or device file vma). + * invalid vma (for instance device file vma). * -ENOMEM: Out of memory. * -EPERM: Invalid permission (for instance asking for write and * range is read only). @@ -905,6 +988,7 @@ EXPORT_SYMBOL(hmm_range_snapshot); */ long hmm_range_fault(struct hmm_range *range, bool block) { + const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP; unsigned long start = range->start, end; struct hmm_vma_walk hmm_vma_walk; struct hmm *hmm = range->hmm; @@ -924,15 +1008,25 @@ long hmm_range_fault(struct hmm_range *range, bool block) } vma = find_vma(hmm->mm, start); - if (vma == NULL || (vma->vm_flags & VM_SPECIAL)) + if (vma == NULL || (vma->vm_flags & device_vma)) return -EFAULT; - /* FIXME support hugetlb fs/dax */ - if (is_vm_hugetlb_page(vma) || vma_is_dax(vma)) { + /* FIXME support dax */ + if (vma_is_dax(vma)) { hmm_pfns_special(range); return -EINVAL; } + if (is_vm_hugetlb_page(vma)) { + if (huge_page_shift(hstate_vma(vma)) != + range->page_shift && + range->page_shift != PAGE_SHIFT) + return -EINVAL; + } else { + if (range->page_shift != PAGE_SHIFT) + return -EINVAL; + } + if (!(vma->vm_flags & VM_READ)) { /* * If vma do not allow read access, then assume that it @@ -959,6 +1053,7 @@ long hmm_range_fault(struct hmm_range *range, bool block) mm_walk.hugetlb_entry = NULL; mm_walk.pmd_entry = hmm_vma_walk_pmd; mm_walk.pte_hole = hmm_vma_walk_hole; + mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry; do { ret = walk_page_range(start, end, &mm_walk);