diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c index 8a2be4e40f22..98231d10890c 100644 --- a/drivers/vfio/vfio_iommu_type1.c +++ b/drivers/vfio/vfio_iommu_type1.c @@ -370,6 +370,9 @@ static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, struct vfio_dma *split; int ret; + if (!*size) + return 0; + /* * Existing dma region is completely covered, unmap all. This is * the likely case since userspace tends to map and unmap buffers @@ -411,7 +414,9 @@ static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, dma->vaddr += overlap; dma->size -= overlap; vfio_insert_dma(iommu, dma); - } + } else + kfree(dma); + *size = overlap; return 0; } @@ -425,48 +430,41 @@ static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, if (ret) return ret; - /* - * We may have unmapped the entire vfio_dma if the user is - * trying to unmap a sub-region of what was originally - * mapped. If anything left, we can resize in place since - * iova is unchanged. - */ - if (overlap < dma->size) - dma->size -= overlap; - else - vfio_remove_dma(iommu, dma); - + dma->size -= overlap; *size = overlap; return 0; } /* Split existing */ + split = kzalloc(sizeof(*split), GFP_KERNEL); + if (!split) + return -ENOMEM; + offset = start - dma->iova; ret = vfio_unmap_unpin(iommu, dma, start, size); if (ret) return ret; - WARN_ON(!*size); + if (!*size) { + kfree(split); + return -EINVAL; + } + tmp = dma->size; - /* - * Resize the lower vfio_dma in place, insert new for remaining - * upper segment. - */ + /* Resize the lower vfio_dma in place, before the below insert */ dma->size = offset; - if (offset + *size < tmp) { - split = kzalloc(sizeof(*split), GFP_KERNEL); - if (!split) - return -ENOMEM; - + /* Insert new for remainder, assuming it didn't all get unmapped */ + if (likely(offset + *size < tmp)) { split->size = tmp - offset - *size; split->iova = dma->iova + offset + *size; split->vaddr = dma->vaddr + offset + *size; split->prot = dma->prot; vfio_insert_dma(iommu, split); - } + } else + kfree(split); return 0; } @@ -483,7 +481,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu, if (unmap->iova & mask) return -EINVAL; - if (unmap->size & mask) + if (!unmap->size || unmap->size & mask) return -EINVAL; WARN_ON(mask & PAGE_MASK); @@ -493,7 +491,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu, while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) { size = unmap->size; ret = vfio_remove_dma_overlap(iommu, unmap->iova, &size, dma); - if (ret) + if (ret || !size) break; unmapped += size; } @@ -635,7 +633,6 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, if (tmp && tmp->prot == prot && tmp->vaddr + tmp->size == vaddr) { tmp->size += size; - iova = tmp->iova; size = tmp->size; vaddr = tmp->vaddr; @@ -643,19 +640,28 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, } } - /* Check if we abut a region above - nothing above ~0 + 1 */ + /* + * Check if we abut a region above - nothing above ~0 + 1. + * If we abut above and below, remove and free. If only + * abut above, remove, modify, reinsert. + */ if (likely(iova + size)) { struct vfio_dma *tmp; - tmp = vfio_find_dma(iommu, iova + size, 1); if (tmp && tmp->prot == prot && tmp->vaddr == vaddr + size) { vfio_remove_dma(iommu, tmp); - if (dma) + if (dma) { dma->size += tmp->size; - else + kfree(tmp); + } else { size += tmp->size; - kfree(tmp); + tmp->size = size; + tmp->iova = iova; + tmp->vaddr = vaddr; + vfio_insert_dma(iommu, tmp); + dma = tmp; + } } } @@ -681,11 +687,10 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, iova = map->iova; size = map->size; while ((tmp = vfio_find_dma(iommu, iova, size))) { - if (vfio_remove_dma_overlap(iommu, iova, &size, tmp)) { - pr_warn("%s: Error rolling back failed map\n", - __func__); + int r = vfio_remove_dma_overlap(iommu, iova, + &size, tmp); + if (WARN_ON(r || !size)) break; - } } } @@ -813,6 +818,8 @@ static void vfio_iommu_type1_release(void *iommu_data) struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node); size_t size = dma->size; vfio_remove_dma_overlap(iommu, dma->iova, &size, dma); + if (WARN_ON(!size)) + break; } iommu_domain_free(iommu->domain);