diff --git a/drivers/scsi/ufs/ufshcd-pci.c b/drivers/scsi/ufs/ufshcd-pci.c index 8b9531204c2b..c007a7a69c28 100644 --- a/drivers/scsi/ufs/ufshcd-pci.c +++ b/drivers/scsi/ufs/ufshcd-pci.c @@ -134,26 +134,6 @@ static void ufshcd_pci_remove(struct pci_dev *pdev) ufshcd_remove(hba); } -/** - * ufshcd_set_dma_mask - Set dma mask based on the controller - * addressing capability - * @pdev: PCI device structure - * - * Returns 0 for success, non-zero for failure - */ -static int ufshcd_set_dma_mask(struct pci_dev *pdev) -{ - int err; - - if (!pci_set_dma_mask(pdev, DMA_BIT_MASK(64)) - && !pci_set_consistent_dma_mask(pdev, DMA_BIT_MASK(64))) - return 0; - err = pci_set_dma_mask(pdev, DMA_BIT_MASK(32)); - if (!err) - err = pci_set_consistent_dma_mask(pdev, DMA_BIT_MASK(32)); - return err; -} - /** * ufshcd_pci_probe - probe routine of the driver * @pdev: pointer to PCI device handle @@ -184,12 +164,6 @@ ufshcd_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id) mmio_base = pcim_iomap_table(pdev)[0]; - err = ufshcd_set_dma_mask(pdev); - if (err) { - dev_err(&pdev->dev, "set dma mask failed\n"); - return err; - } - err = ufshcd_init(&pdev->dev, &hba, mmio_base, pdev->irq); if (err) { dev_err(&pdev->dev, "Initialization failed\n"); diff --git a/drivers/scsi/ufs/ufshcd.c b/drivers/scsi/ufs/ufshcd.c index af1bffc1eac8..d41233914336 100644 --- a/drivers/scsi/ufs/ufshcd.c +++ b/drivers/scsi/ufs/ufshcd.c @@ -3258,6 +3258,22 @@ void ufshcd_remove(struct ufs_hba *hba) } EXPORT_SYMBOL_GPL(ufshcd_remove); +/** + * ufshcd_set_dma_mask - Set dma mask based on the controller + * addressing capability + * @hba: per adapter instance + * + * Returns 0 for success, non-zero for failure + */ +static int ufshcd_set_dma_mask(struct ufs_hba *hba) +{ + if (hba->capabilities & MASK_64_ADDRESSING_SUPPORT) { + if (!dma_set_mask_and_coherent(hba->dev, DMA_BIT_MASK(64))) + return 0; + } + return dma_set_mask_and_coherent(hba->dev, DMA_BIT_MASK(32)); +} + /** * ufshcd_init - Driver initialization routine * @dev: pointer to device handle @@ -3309,6 +3325,12 @@ int ufshcd_init(struct device *dev, struct ufs_hba **hba_handle, /* Get Interrupt bit mask per version */ hba->intr_mask = ufshcd_get_intr_mask(hba); + err = ufshcd_set_dma_mask(hba); + if (err) { + dev_err(hba->dev, "set dma mask failed\n"); + goto out_disable; + } + /* Allocate memory for host memory space */ err = ufshcd_memory_alloc(hba); if (err) {