diff --git a/drivers/net/ethernet/intel/i40e/i40e_virtchnl_pf.c b/drivers/net/ethernet/intel/i40e/i40e_virtchnl_pf.c index 9f361e810990..d7fcc4ffa393 100644 --- a/drivers/net/ethernet/intel/i40e/i40e_virtchnl_pf.c +++ b/drivers/net/ethernet/intel/i40e/i40e_virtchnl_pf.c @@ -2549,10 +2549,6 @@ static int i40e_vc_validate_vf_msg(struct i40e_vf *vf, u32 v_opcode, bool err_msg_format = false; int valid_len = 0; - /* Check if VF is disabled. */ - if (test_bit(I40E_VF_STATE_DISABLED, &vf->vf_states)) - return I40E_ERR_PARAM; - /* Validate message length. */ switch (v_opcode) { case VIRTCHNL_OP_VERSION: @@ -2657,10 +2653,6 @@ static int i40e_vc_validate_vf_msg(struct i40e_vf *vf, u32 v_opcode, if (msglen >= valid_len) { struct virtchnl_rss_key *vrk = (struct virtchnl_rss_key *)msg; - if (vrk->key_len != I40E_HKEY_ARRAY_SIZE) { - err_msg_format = true; - break; - } valid_len += vrk->key_len - 1; } break; @@ -2669,10 +2661,6 @@ static int i40e_vc_validate_vf_msg(struct i40e_vf *vf, u32 v_opcode, if (msglen >= valid_len) { struct virtchnl_rss_lut *vrl = (struct virtchnl_rss_lut *)msg; - if (vrl->lut_entries != I40E_VF_HLUT_ARRAY_SIZE) { - err_msg_format = true; - break; - } valid_len += vrl->lut_entries - 1; } break; @@ -2719,9 +2707,27 @@ int i40e_vc_process_vf_msg(struct i40e_pf *pf, s16 vf_id, u32 v_opcode, if (local_vf_id >= pf->num_alloc_vfs) return -EINVAL; vf = &(pf->vf[local_vf_id]); + + /* Check if VF is disabled. */ + if (test_bit(I40E_VF_STATE_DISABLED, &vf->vf_states)) + return I40E_ERR_PARAM; + /* perform basic checks on the msg */ ret = i40e_vc_validate_vf_msg(vf, v_opcode, v_retval, msg, msglen); + /* perform additional checks specific to this driver */ + if (v_opcode == VIRTCHNL_OP_CONFIG_RSS_KEY) { + struct virtchnl_rss_key *vrk = (struct virtchnl_rss_key *)msg; + + if (vrk->key_len != I40E_HKEY_ARRAY_SIZE) + ret = -EINVAL; + } else if (v_opcode == VIRTCHNL_OP_CONFIG_RSS_LUT) { + struct virtchnl_rss_lut *vrl = (struct virtchnl_rss_lut *)msg; + + if (vrl->lut_entries != I40E_VF_HLUT_ARRAY_SIZE) + ret = -EINVAL; + } + if (ret) { dev_err(&pf->pdev->dev, "Invalid message from VF %d, opcode %d, len %d\n", local_vf_id, v_opcode, msglen);