diff --git a/lib/mpi/mpicoder.c b/lib/mpi/mpicoder.c index 350abaf4bee7..9c6f6b986682 100644 --- a/lib/mpi/mpicoder.c +++ b/lib/mpi/mpicoder.c @@ -81,7 +81,7 @@ MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread) { const uint8_t *buffer = xbuffer; int i, j; - unsigned nbits, nbytes, nlimbs, nread = 0; + unsigned nbits, nbytes, nlimbs; mpi_limb_t a; MPI val = NULL; @@ -94,9 +94,14 @@ MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread) return ERR_PTR(-EINVAL); } buffer += 2; - nread = 2; nbytes = DIV_ROUND_UP(nbits, 8); + if (nbytes + 2 > *ret_nread) { + printk("MPI: mpi larger than buffer nread=%d ret_nread=%d\n", + *ret_nread + 1, *ret_nread); + return ERR_PTR(-EINVAL); + } + nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); val = mpi_alloc(nlimbs); if (!val) @@ -109,12 +114,6 @@ MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread) for (; j > 0; j--) { a = 0; for (; i < BYTES_PER_MPI_LIMB; i++) { - if (++nread > *ret_nread) { - printk - ("MPI: mpi larger than buffer nread=%d ret_nread=%d\n", - nread, *ret_nread); - goto leave; - } a <<= 8; a |= *buffer++; } @@ -122,8 +121,7 @@ MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread) val->d[j - 1] = a; } -leave: - *ret_nread = nread; + *ret_nread = nbytes + 2; return val; } EXPORT_SYMBOL_GPL(mpi_read_from_buffer);