vfio: Use down_reads to protect iommu disconnects

If a group or device is released or a container is unset from a group
it can race against file ops on the container.  Protect these with
down_reads to allow concurrent users.

Signed-off-by: Alex Williamson <alex.williamson@redhat.com>
Reported-by: Michael S. Tsirkin <mst@redhat.com>
This commit is contained in:
Alex Williamson 2013-04-29 08:41:36 -06:00
parent 9587f44aa6
commit 0b43c08233

View file

@ -704,9 +704,13 @@ EXPORT_SYMBOL_GPL(vfio_del_group_dev);
static long vfio_ioctl_check_extension(struct vfio_container *container, static long vfio_ioctl_check_extension(struct vfio_container *container,
unsigned long arg) unsigned long arg)
{ {
struct vfio_iommu_driver *driver = container->iommu_driver; struct vfio_iommu_driver *driver;
long ret = 0; long ret = 0;
down_read(&container->group_lock);
driver = container->iommu_driver;
switch (arg) { switch (arg) {
/* No base extensions yet */ /* No base extensions yet */
default: default:
@ -736,6 +740,8 @@ static long vfio_ioctl_check_extension(struct vfio_container *container,
VFIO_CHECK_EXTENSION, arg); VFIO_CHECK_EXTENSION, arg);
} }
up_read(&container->group_lock);
return ret; return ret;
} }
@ -844,9 +850,6 @@ static long vfio_fops_unl_ioctl(struct file *filep,
if (!container) if (!container)
return ret; return ret;
driver = container->iommu_driver;
data = container->iommu_data;
switch (cmd) { switch (cmd) {
case VFIO_GET_API_VERSION: case VFIO_GET_API_VERSION:
ret = VFIO_API_VERSION; ret = VFIO_API_VERSION;
@ -858,8 +861,15 @@ static long vfio_fops_unl_ioctl(struct file *filep,
ret = vfio_ioctl_set_iommu(container, arg); ret = vfio_ioctl_set_iommu(container, arg);
break; break;
default: default:
down_read(&container->group_lock);
driver = container->iommu_driver;
data = container->iommu_data;
if (driver) /* passthrough all unrecognized ioctls */ if (driver) /* passthrough all unrecognized ioctls */
ret = driver->ops->ioctl(data, cmd, arg); ret = driver->ops->ioctl(data, cmd, arg);
up_read(&container->group_lock);
} }
return ret; return ret;
@ -910,35 +920,55 @@ static ssize_t vfio_fops_read(struct file *filep, char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct vfio_container *container = filep->private_data; struct vfio_container *container = filep->private_data;
struct vfio_iommu_driver *driver = container->iommu_driver; struct vfio_iommu_driver *driver;
ssize_t ret = -EINVAL;
if (unlikely(!driver || !driver->ops->read)) down_read(&container->group_lock);
return -EINVAL;
return driver->ops->read(container->iommu_data, buf, count, ppos); driver = container->iommu_driver;
if (likely(driver && driver->ops->read))
ret = driver->ops->read(container->iommu_data,
buf, count, ppos);
up_read(&container->group_lock);
return ret;
} }
static ssize_t vfio_fops_write(struct file *filep, const char __user *buf, static ssize_t vfio_fops_write(struct file *filep, const char __user *buf,
size_t count, loff_t *ppos) size_t count, loff_t *ppos)
{ {
struct vfio_container *container = filep->private_data; struct vfio_container *container = filep->private_data;
struct vfio_iommu_driver *driver = container->iommu_driver; struct vfio_iommu_driver *driver;
ssize_t ret = -EINVAL;
if (unlikely(!driver || !driver->ops->write)) down_read(&container->group_lock);
return -EINVAL;
return driver->ops->write(container->iommu_data, buf, count, ppos); driver = container->iommu_driver;
if (likely(driver && driver->ops->write))
ret = driver->ops->write(container->iommu_data,
buf, count, ppos);
up_read(&container->group_lock);
return ret;
} }
static int vfio_fops_mmap(struct file *filep, struct vm_area_struct *vma) static int vfio_fops_mmap(struct file *filep, struct vm_area_struct *vma)
{ {
struct vfio_container *container = filep->private_data; struct vfio_container *container = filep->private_data;
struct vfio_iommu_driver *driver = container->iommu_driver; struct vfio_iommu_driver *driver;
int ret = -EINVAL;
if (unlikely(!driver || !driver->ops->mmap)) down_read(&container->group_lock);
return -EINVAL;
return driver->ops->mmap(container->iommu_data, vma); driver = container->iommu_driver;
if (likely(driver && driver->ops->mmap))
ret = driver->ops->mmap(container->iommu_data, vma);
up_read(&container->group_lock);
return ret;
} }
static const struct file_operations vfio_fops = { static const struct file_operations vfio_fops = {