diff --git a/mm/gup.c b/mm/gup.c index 76af4cfeaf68..84dd2063ca3d 100644 --- a/mm/gup.c +++ b/mm/gup.c @@ -1456,32 +1456,48 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr, return 1; } -static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr, +static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr, unsigned long end, struct page **pages, int *nr) { unsigned long fault_pfn; + int nr_start = *nr; - fault_pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT); - return __gup_device_huge(fault_pfn, addr, end, pages, nr); + fault_pfn = pmd_pfn(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT); + if (!__gup_device_huge(fault_pfn, addr, end, pages, nr)) + return 0; + + if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) { + undo_dev_pagemap(nr, nr_start, pages); + return 0; + } + return 1; } -static int __gup_device_huge_pud(pud_t pud, unsigned long addr, +static int __gup_device_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr, unsigned long end, struct page **pages, int *nr) { unsigned long fault_pfn; + int nr_start = *nr; - fault_pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT); - return __gup_device_huge(fault_pfn, addr, end, pages, nr); + fault_pfn = pud_pfn(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT); + if (!__gup_device_huge(fault_pfn, addr, end, pages, nr)) + return 0; + + if (unlikely(pud_val(orig) != pud_val(*pudp))) { + undo_dev_pagemap(nr, nr_start, pages); + return 0; + } + return 1; } #else -static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr, +static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr, unsigned long end, struct page **pages, int *nr) { BUILD_BUG(); return 0; } -static int __gup_device_huge_pud(pud_t pud, unsigned long addr, +static int __gup_device_huge_pud(pud_t pud, pud_t *pudp, unsigned long addr, unsigned long end, struct page **pages, int *nr) { BUILD_BUG(); @@ -1499,7 +1515,7 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr, return 0; if (pmd_devmap(orig)) - return __gup_device_huge_pmd(orig, addr, end, pages, nr); + return __gup_device_huge_pmd(orig, pmdp, addr, end, pages, nr); refs = 0; page = pmd_page(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT); @@ -1537,7 +1553,7 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr, return 0; if (pud_devmap(orig)) - return __gup_device_huge_pud(orig, addr, end, pages, nr); + return __gup_device_huge_pud(orig, pudp, addr, end, pages, nr); refs = 0; page = pud_page(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);