diff --git a/crypto/dh.c b/crypto/dh.c index 5e960fe28681..9d19360e7189 100644 --- a/crypto/dh.c +++ b/crypto/dh.c @@ -129,7 +129,7 @@ static int dh_compute_value(struct kpp_request *req) if (ret) goto err_free_base; - ret = mpi_write_to_sgl(val, req->dst, &req->dst_len, &sign); + ret = mpi_write_to_sgl(val, req->dst, req->dst_len, &sign); if (ret) goto err_free_base; diff --git a/crypto/rsa.c b/crypto/rsa.c index dc692d43b666..4c280b6a3ea9 100644 --- a/crypto/rsa.c +++ b/crypto/rsa.c @@ -108,7 +108,7 @@ static int rsa_enc(struct akcipher_request *req) if (ret) goto err_free_m; - ret = mpi_write_to_sgl(c, req->dst, &req->dst_len, &sign); + ret = mpi_write_to_sgl(c, req->dst, req->dst_len, &sign); if (ret) goto err_free_m; @@ -147,7 +147,7 @@ static int rsa_dec(struct akcipher_request *req) if (ret) goto err_free_c; - ret = mpi_write_to_sgl(m, req->dst, &req->dst_len, &sign); + ret = mpi_write_to_sgl(m, req->dst, req->dst_len, &sign); if (ret) goto err_free_c; @@ -185,7 +185,7 @@ static int rsa_sign(struct akcipher_request *req) if (ret) goto err_free_m; - ret = mpi_write_to_sgl(s, req->dst, &req->dst_len, &sign); + ret = mpi_write_to_sgl(s, req->dst, req->dst_len, &sign); if (ret) goto err_free_m; @@ -226,7 +226,7 @@ static int rsa_verify(struct akcipher_request *req) if (ret) goto err_free_s; - ret = mpi_write_to_sgl(m, req->dst, &req->dst_len, &sign); + ret = mpi_write_to_sgl(m, req->dst, req->dst_len, &sign); if (ret) goto err_free_s; diff --git a/include/linux/mpi.h b/include/linux/mpi.h index f219559e5e80..1cc5ffb769af 100644 --- a/include/linux/mpi.h +++ b/include/linux/mpi.h @@ -80,7 +80,7 @@ void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign); int mpi_read_buffer(MPI a, uint8_t *buf, unsigned buf_len, unsigned *nbytes, int *sign); void *mpi_get_secure_buffer(MPI a, unsigned *nbytes, int *sign); -int mpi_write_to_sgl(MPI a, struct scatterlist *sg, unsigned *nbytes, +int mpi_write_to_sgl(MPI a, struct scatterlist *sg, unsigned nbytes, int *sign); #define log_mpidump g10_log_mpidump diff --git a/lib/mpi/mpicoder.c b/lib/mpi/mpicoder.c index 823cf5f5196b..7150e5c23604 100644 --- a/lib/mpi/mpicoder.c +++ b/lib/mpi/mpicoder.c @@ -237,16 +237,13 @@ EXPORT_SYMBOL_GPL(mpi_get_buffer); * @a: a multi precision integer * @sgl: scatterlist to write to. Needs to be at least * mpi_get_size(a) long. - * @nbytes: in/out param - it has the be set to the maximum number of - * bytes that can be written to sgl. This has to be at least - * the size of the integer a. On return it receives the actual - * length of the data written on success or the data that would - * be written if buffer was too small. + * @nbytes: the number of bytes to write. Leading bytes will be + * filled with zero. * @sign: if not NULL, it will be set to the sign of a. * * Return: 0 on success or error code in case of error */ -int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes, +int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes, int *sign) { u8 *p, *p2; @@ -258,43 +255,44 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes, #error please implement for this limb size. #endif unsigned int n = mpi_get_size(a); - int i, x, y = 0, lzeros, buf_len; - - if (!nbytes) - return -EINVAL; + int i, x, buf_len; if (sign) *sign = a->sign; - lzeros = count_lzeros(a); - - if (*nbytes < n - lzeros) { - *nbytes = n - lzeros; + if (nbytes < n) return -EOVERFLOW; - } - *nbytes = n - lzeros; buf_len = sgl->length; p2 = sg_virt(sgl); - for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB, - lzeros %= BYTES_PER_MPI_LIMB; - i >= 0; i--) { + while (nbytes > n) { + if (!buf_len) { + sgl = sg_next(sgl); + if (!sgl) + return -EINVAL; + buf_len = sgl->length; + p2 = sg_virt(sgl); + } + + i = min_t(unsigned, nbytes - n, buf_len); + memset(p2, 0, i); + p2 += i; + buf_len -= i; + nbytes -= i; + } + + for (i = a->nlimbs - 1; i >= 0; i--) { #if BYTES_PER_MPI_LIMB == 4 - alimb = cpu_to_be32(a->d[i]); + alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0; #elif BYTES_PER_MPI_LIMB == 8 - alimb = cpu_to_be64(a->d[i]); + alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0; #else #error please implement for this limb size. #endif - if (lzeros) { - y = lzeros; - lzeros = 0; - } + p = (u8 *)&alimb; - p = (u8 *)&alimb + y; - - for (x = 0; x < sizeof(alimb) - y; x++) { + for (x = 0; x < sizeof(alimb); x++) { if (!buf_len) { sgl = sg_next(sgl); if (!sgl) @@ -305,7 +303,6 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes, *p2++ = *p++; buf_len--; } - y = 0; } return 0; }