diff --git a/drivers/iommu/arm-smmu-v3.c b/drivers/iommu/arm-smmu-v3.c index 591bb96047c9..b47a88757c18 100644 --- a/drivers/iommu/arm-smmu-v3.c +++ b/drivers/iommu/arm-smmu-v3.c @@ -1837,6 +1837,9 @@ static int arm_smmu_domain_get_attr(struct iommu_domain *domain, { struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain); + if (domain->type != IOMMU_DOMAIN_UNMANAGED) + return -EINVAL; + switch (attr) { case DOMAIN_ATTR_NESTING: *(int *)data = (smmu_domain->stage == ARM_SMMU_DOMAIN_NESTED); @@ -1852,6 +1855,9 @@ static int arm_smmu_domain_set_attr(struct iommu_domain *domain, int ret = 0; struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain); + if (domain->type != IOMMU_DOMAIN_UNMANAGED) + return -EINVAL; + mutex_lock(&smmu_domain->init_mutex); switch (attr) { diff --git a/drivers/iommu/arm-smmu.c b/drivers/iommu/arm-smmu.c index 6560c4a34b96..099215bbff89 100644 --- a/drivers/iommu/arm-smmu.c +++ b/drivers/iommu/arm-smmu.c @@ -1603,6 +1603,9 @@ static int arm_smmu_domain_get_attr(struct iommu_domain *domain, { struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain); + if (domain->type != IOMMU_DOMAIN_UNMANAGED) + return -EINVAL; + switch (attr) { case DOMAIN_ATTR_NESTING: *(int *)data = (smmu_domain->stage == ARM_SMMU_DOMAIN_NESTED); @@ -1618,6 +1621,9 @@ static int arm_smmu_domain_set_attr(struct iommu_domain *domain, int ret = 0; struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain); + if (domain->type != IOMMU_DOMAIN_UNMANAGED) + return -EINVAL; + mutex_lock(&smmu_domain->init_mutex); switch (attr) {