diff --git a/drivers/vfio/mdev/mdev_sysfs.c b/drivers/vfio/mdev/mdev_sysfs.c index 249472f05509..ce5dd219f2c8 100644 --- a/drivers/vfio/mdev/mdev_sysfs.c +++ b/drivers/vfio/mdev/mdev_sysfs.c @@ -92,8 +92,8 @@ static struct kobj_type mdev_type_ktype = { .release = mdev_type_release, }; -struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent, - struct attribute_group *group) +static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent, + struct attribute_group *group) { struct mdev_type *type; int ret; diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c index a89fa5d4e877..ff60bd1ea587 100644 --- a/drivers/vfio/pci/vfio_pci.c +++ b/drivers/vfio/pci/vfio_pci.c @@ -56,8 +56,6 @@ module_param(disable_idle_d3, bool, S_IRUGO | S_IWUSR); MODULE_PARM_DESC(disable_idle_d3, "Disable using the PCI D3 low power state for idle, unused devices"); -static DEFINE_MUTEX(driver_lock); - static inline bool vfio_vga_disabled(void) { #ifdef CONFIG_VFIO_PCI_VGA @@ -416,14 +414,14 @@ static void vfio_pci_release(void *device_data) { struct vfio_pci_device *vdev = device_data; - mutex_lock(&driver_lock); + mutex_lock(&vdev->reflck->lock); if (!(--vdev->refcnt)) { vfio_spapr_pci_eeh_release(vdev->pdev); vfio_pci_disable(vdev); } - mutex_unlock(&driver_lock); + mutex_unlock(&vdev->reflck->lock); module_put(THIS_MODULE); } @@ -436,7 +434,7 @@ static int vfio_pci_open(void *device_data) if (!try_module_get(THIS_MODULE)) return -ENODEV; - mutex_lock(&driver_lock); + mutex_lock(&vdev->reflck->lock); if (!vdev->refcnt) { ret = vfio_pci_enable(vdev); @@ -447,7 +445,7 @@ static int vfio_pci_open(void *device_data) } vdev->refcnt++; error: - mutex_unlock(&driver_lock); + mutex_unlock(&vdev->reflck->lock); if (ret) module_put(THIS_MODULE); return ret; @@ -1225,6 +1223,9 @@ static const struct vfio_device_ops vfio_pci_ops = { .request = vfio_pci_request, }; +static int vfio_pci_reflck_attach(struct vfio_pci_device *vdev); +static void vfio_pci_reflck_put(struct vfio_pci_reflck *reflck); + static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id) { struct vfio_pci_device *vdev; @@ -1271,6 +1272,14 @@ static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id) return ret; } + ret = vfio_pci_reflck_attach(vdev); + if (ret) { + vfio_del_group_dev(&pdev->dev); + vfio_iommu_group_put(group, &pdev->dev); + kfree(vdev); + return ret; + } + if (vfio_pci_is_vga(pdev)) { vga_client_register(pdev, vdev, NULL, vfio_pci_set_vga_decode); vga_set_legacy_decoding(pdev, @@ -1302,6 +1311,8 @@ static void vfio_pci_remove(struct pci_dev *pdev) if (!vdev) return; + vfio_pci_reflck_put(vdev->reflck); + vfio_iommu_group_put(pdev->dev.iommu_group, &pdev->dev); kfree(vdev->region); mutex_destroy(&vdev->ioeventfds_lock); @@ -1358,16 +1369,97 @@ static struct pci_driver vfio_pci_driver = { .err_handler = &vfio_err_handlers, }; +static DEFINE_MUTEX(reflck_lock); + +static struct vfio_pci_reflck *vfio_pci_reflck_alloc(void) +{ + struct vfio_pci_reflck *reflck; + + reflck = kzalloc(sizeof(*reflck), GFP_KERNEL); + if (!reflck) + return ERR_PTR(-ENOMEM); + + kref_init(&reflck->kref); + mutex_init(&reflck->lock); + + return reflck; +} + +static void vfio_pci_reflck_get(struct vfio_pci_reflck *reflck) +{ + kref_get(&reflck->kref); +} + +static int vfio_pci_reflck_find(struct pci_dev *pdev, void *data) +{ + struct vfio_pci_reflck **preflck = data; + struct vfio_device *device; + struct vfio_pci_device *vdev; + + device = vfio_device_get_from_dev(&pdev->dev); + if (!device) + return 0; + + if (pci_dev_driver(pdev) != &vfio_pci_driver) { + vfio_device_put(device); + return 0; + } + + vdev = vfio_device_data(device); + + if (vdev->reflck) { + vfio_pci_reflck_get(vdev->reflck); + *preflck = vdev->reflck; + vfio_device_put(device); + return 1; + } + + vfio_device_put(device); + return 0; +} + +static int vfio_pci_reflck_attach(struct vfio_pci_device *vdev) +{ + bool slot = !pci_probe_reset_slot(vdev->pdev->slot); + + mutex_lock(&reflck_lock); + + if (pci_is_root_bus(vdev->pdev->bus) || + vfio_pci_for_each_slot_or_bus(vdev->pdev, vfio_pci_reflck_find, + &vdev->reflck, slot) <= 0) + vdev->reflck = vfio_pci_reflck_alloc(); + + mutex_unlock(&reflck_lock); + + return PTR_ERR_OR_ZERO(vdev->reflck); +} + +static void vfio_pci_reflck_release(struct kref *kref) +{ + struct vfio_pci_reflck *reflck = container_of(kref, + struct vfio_pci_reflck, + kref); + + kfree(reflck); + mutex_unlock(&reflck_lock); +} + +static void vfio_pci_reflck_put(struct vfio_pci_reflck *reflck) +{ + kref_put_mutex(&reflck->kref, vfio_pci_reflck_release, &reflck_lock); +} + struct vfio_devices { struct vfio_device **devices; int cur_index; int max_index; }; -static int vfio_pci_get_devs(struct pci_dev *pdev, void *data) +static int vfio_pci_get_unused_devs(struct pci_dev *pdev, void *data) { struct vfio_devices *devs = data; struct vfio_device *device; + struct vfio_pci_device *vdev; if (devs->cur_index == devs->max_index) return -ENOSPC; @@ -1381,16 +1473,28 @@ static int vfio_pci_get_devs(struct pci_dev *pdev, void *data) return -EBUSY; } + vdev = vfio_device_data(device); + + /* Fault if the device is not unused */ + if (vdev->refcnt) { + vfio_device_put(device); + return -EBUSY; + } + devs->devices[devs->cur_index++] = device; return 0; } /* - * Attempt to do a bus/slot reset if there are devices affected by a reset for - * this device that are needs_reset and all of the affected devices are unused - * (!refcnt). Callers are required to hold driver_lock when calling this to - * prevent device opens and concurrent bus reset attempts. We prevent device - * unbinds by acquiring and holding a reference to the vfio_device. + * If a bus or slot reset is available for the provided device and: + * - All of the devices affected by that bus or slot reset are unused + * (!refcnt) + * - At least one of the affected devices is marked dirty via + * needs_reset (such as by lack of FLR support) + * Then attempt to perform that bus or slot reset. Callers are required + * to hold vdev->reflck->lock, protecting the bus/slot reset group from + * concurrent opens. A vfio_device reference is acquired for each device + * to prevent unbinds during the reset operation. * * NB: vfio-core considers a group to be viable even if some devices are * bound to drivers like pci-stub or pcieport. Here we require all devices @@ -1401,7 +1505,7 @@ static void vfio_pci_try_bus_reset(struct vfio_pci_device *vdev) { struct vfio_devices devs = { .cur_index = 0 }; int i = 0, ret = -EINVAL; - bool needs_reset = false, slot = false; + bool slot = false; struct vfio_pci_device *tmp; if (!pci_probe_reset_slot(vdev->pdev->slot)) @@ -1419,28 +1523,36 @@ static void vfio_pci_try_bus_reset(struct vfio_pci_device *vdev) return; if (vfio_pci_for_each_slot_or_bus(vdev->pdev, - vfio_pci_get_devs, &devs, slot)) + vfio_pci_get_unused_devs, + &devs, slot)) goto put_devs; + /* Does at least one need a reset? */ for (i = 0; i < devs.cur_index; i++) { tmp = vfio_device_data(devs.devices[i]); - if (tmp->needs_reset) - needs_reset = true; - if (tmp->refcnt) - goto put_devs; + if (tmp->needs_reset) { + ret = pci_reset_bus(vdev->pdev); + break; + } } - if (needs_reset) - ret = pci_reset_bus(vdev->pdev); - put_devs: for (i = 0; i < devs.cur_index; i++) { tmp = vfio_device_data(devs.devices[i]); - if (!ret) + + /* + * If reset was successful, affected devices no longer need + * a reset and we should return all the collateral devices + * to low power. If not successful, we either didn't reset + * the bus or timed out waiting for it, so let's not touch + * the power state. + */ + if (!ret) { tmp->needs_reset = false; - if (!tmp->refcnt && !disable_idle_d3) - pci_set_power_state(tmp->pdev, PCI_D3hot); + if (tmp != vdev && !disable_idle_d3) + pci_set_power_state(tmp->pdev, PCI_D3hot); + } vfio_device_put(devs.devices[i]); } diff --git a/drivers/vfio/pci/vfio_pci_private.h b/drivers/vfio/pci/vfio_pci_private.h index 127071b84dd7..8c0009f00818 100644 --- a/drivers/vfio/pci/vfio_pci_private.h +++ b/drivers/vfio/pci/vfio_pci_private.h @@ -82,6 +82,11 @@ struct vfio_pci_dummy_resource { struct list_head res_next; }; +struct vfio_pci_reflck { + struct kref kref; + struct mutex lock; +}; + struct vfio_pci_device { struct pci_dev *pdev; void __iomem *barmap[PCI_STD_RESOURCE_END + 1]; @@ -110,6 +115,7 @@ struct vfio_pci_device { bool needs_reset; bool nointx; struct pci_saved_state *pci_saved_state; + struct vfio_pci_reflck *reflck; int refcnt; int ioeventfds_nr; struct eventfd_ctx *err_trigger; diff --git a/samples/vfio-mdev/mtty.c b/samples/vfio-mdev/mtty.c index 7abb79d8313d..f6732aa16bb1 100644 --- a/samples/vfio-mdev/mtty.c +++ b/samples/vfio-mdev/mtty.c @@ -171,7 +171,7 @@ static struct mdev_state *find_mdev_state_by_uuid(uuid_le uuid) return NULL; } -void dump_buffer(char *buf, uint32_t count) +void dump_buffer(u8 *buf, uint32_t count) { #if defined(DEBUG) int i; @@ -250,7 +250,7 @@ static void mtty_create_config_space(struct mdev_state *mdev_state) } static void handle_pci_cfg_write(struct mdev_state *mdev_state, u16 offset, - char *buf, u32 count) + u8 *buf, u32 count) { u32 cfg_addr, bar_mask, bar_index = 0; @@ -304,7 +304,7 @@ static void handle_pci_cfg_write(struct mdev_state *mdev_state, u16 offset, } static void handle_bar_write(unsigned int index, struct mdev_state *mdev_state, - u16 offset, char *buf, u32 count) + u16 offset, u8 *buf, u32 count) { u8 data = *buf; @@ -475,7 +475,7 @@ static void handle_bar_write(unsigned int index, struct mdev_state *mdev_state, } static void handle_bar_read(unsigned int index, struct mdev_state *mdev_state, - u16 offset, char *buf, u32 count) + u16 offset, u8 *buf, u32 count) { /* Handle read requests by guest */ switch (offset) { @@ -650,7 +650,7 @@ static void mdev_read_base(struct mdev_state *mdev_state) } } -static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, +static ssize_t mdev_access(struct mdev_device *mdev, u8 *buf, size_t count, loff_t pos, bool is_write) { struct mdev_state *mdev_state; @@ -698,7 +698,7 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, #if defined(DEBUG_REGS) pr_info("%s: BAR%d WR @0x%llx %s val:0x%02x dlab:%d\n", __func__, index, offset, wr_reg[offset], - (u8)*buf, mdev_state->s[index].dlab); + *buf, mdev_state->s[index].dlab); #endif handle_bar_write(index, mdev_state, offset, buf, count); } else { @@ -708,7 +708,7 @@ static ssize_t mdev_access(struct mdev_device *mdev, char *buf, size_t count, #if defined(DEBUG_REGS) pr_info("%s: BAR%d RD @0x%llx %s val:0x%02x dlab:%d\n", __func__, index, offset, rd_reg[offset], - (u8)*buf, mdev_state->s[index].dlab); + *buf, mdev_state->s[index].dlab); #endif } break; @@ -827,7 +827,7 @@ ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, size_t count, if (count >= 4 && !(*ppos % 4)) { u32 val; - ret = mdev_access(mdev, (char *)&val, sizeof(val), + ret = mdev_access(mdev, (u8 *)&val, sizeof(val), *ppos, false); if (ret <= 0) goto read_err; @@ -839,7 +839,7 @@ ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, size_t count, } else if (count >= 2 && !(*ppos % 2)) { u16 val; - ret = mdev_access(mdev, (char *)&val, sizeof(val), + ret = mdev_access(mdev, (u8 *)&val, sizeof(val), *ppos, false); if (ret <= 0) goto read_err; @@ -851,7 +851,7 @@ ssize_t mtty_read(struct mdev_device *mdev, char __user *buf, size_t count, } else { u8 val; - ret = mdev_access(mdev, (char *)&val, sizeof(val), + ret = mdev_access(mdev, (u8 *)&val, sizeof(val), *ppos, false); if (ret <= 0) goto read_err; @@ -889,7 +889,7 @@ ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, if (copy_from_user(&val, buf, sizeof(val))) goto write_err; - ret = mdev_access(mdev, (char *)&val, sizeof(val), + ret = mdev_access(mdev, (u8 *)&val, sizeof(val), *ppos, true); if (ret <= 0) goto write_err; @@ -901,7 +901,7 @@ ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, if (copy_from_user(&val, buf, sizeof(val))) goto write_err; - ret = mdev_access(mdev, (char *)&val, sizeof(val), + ret = mdev_access(mdev, (u8 *)&val, sizeof(val), *ppos, true); if (ret <= 0) goto write_err; @@ -913,7 +913,7 @@ ssize_t mtty_write(struct mdev_device *mdev, const char __user *buf, if (copy_from_user(&val, buf, sizeof(val))) goto write_err; - ret = mdev_access(mdev, (char *)&val, sizeof(val), + ret = mdev_access(mdev, (u8 *)&val, sizeof(val), *ppos, true); if (ret <= 0) goto write_err;