diff --git a/include/net/sctp/structs.h b/include/net/sctp/structs.h index 6b2b8df8a1d2..28a7c8e44636 100644 --- a/include/net/sctp/structs.h +++ b/include/net/sctp/structs.h @@ -57,6 +57,7 @@ #include /* This gets us atomic counters. */ #include /* We need sk_buff_head. */ #include /* We need tq_struct. */ +#include /* We need flex_array. */ #include /* We need sctp* header structs. */ #include /* We need auth specific structs */ #include /* For inet_skb_parm */ @@ -1438,8 +1439,8 @@ struct sctp_stream_in { }; struct sctp_stream { - struct sctp_stream_out *out; - struct sctp_stream_in *in; + struct flex_array *out; + struct flex_array *in; __u16 outcnt; __u16 incnt; /* Current stream being sent, if any */ @@ -1465,14 +1466,14 @@ static inline struct sctp_stream_out *sctp_stream_out( const struct sctp_stream *stream, __u16 sid) { - return ((struct sctp_stream_out *)(stream->out)) + sid; + return flex_array_get(stream->out, sid); } static inline struct sctp_stream_in *sctp_stream_in( const struct sctp_stream *stream, __u16 sid) { - return ((struct sctp_stream_in *)(stream->in)) + sid; + return flex_array_get(stream->in, sid); } #define SCTP_SO(s, i) sctp_stream_out((s), (i)) diff --git a/net/sctp/stream.c b/net/sctp/stream.c index 7ca6fe4e7882..ffb940d3b57c 100644 --- a/net/sctp/stream.c +++ b/net/sctp/stream.c @@ -37,6 +37,53 @@ #include #include +static struct flex_array *fa_alloc(size_t elem_size, size_t elem_count, + gfp_t gfp) +{ + struct flex_array *result; + int err; + + result = flex_array_alloc(elem_size, elem_count, gfp); + if (result) { + err = flex_array_prealloc(result, 0, elem_count, gfp); + if (err) { + flex_array_free(result); + result = NULL; + } + } + + return result; +} + +static void fa_free(struct flex_array *fa) +{ + if (fa) + flex_array_free(fa); +} + +static void fa_copy(struct flex_array *fa, struct flex_array *from, + size_t index, size_t count) +{ + void *elem; + + while (count--) { + elem = flex_array_get(from, index); + flex_array_put(fa, index, elem, 0); + index++; + } +} + +static void fa_zero(struct flex_array *fa, size_t index, size_t count) +{ + void *elem; + + while (count--) { + elem = flex_array_get(fa, index); + memset(elem, 0, fa->element_size); + index++; + } +} + /* Migrates chunks from stream queues to new stream queues if needed, * but not across associations. Also, removes those chunks to streams * higher than the new max. @@ -78,34 +125,33 @@ static void sctp_stream_outq_migrate(struct sctp_stream *stream, * sctp_stream_update will swap ->out pointers. */ for (i = 0; i < outcnt; i++) { - kfree(new->out[i].ext); - new->out[i].ext = stream->out[i].ext; - stream->out[i].ext = NULL; + kfree(SCTP_SO(new, i)->ext); + SCTP_SO(new, i)->ext = SCTP_SO(stream, i)->ext; + SCTP_SO(stream, i)->ext = NULL; } } for (i = outcnt; i < stream->outcnt; i++) - kfree(stream->out[i].ext); + kfree(SCTP_SO(stream, i)->ext); } static int sctp_stream_alloc_out(struct sctp_stream *stream, __u16 outcnt, gfp_t gfp) { - struct sctp_stream_out *out; + struct flex_array *out; + size_t elem_size = sizeof(struct sctp_stream_out); - out = kmalloc_array(outcnt, sizeof(*out), gfp); + out = fa_alloc(elem_size, outcnt, gfp); if (!out) return -ENOMEM; if (stream->out) { - memcpy(out, stream->out, min(outcnt, stream->outcnt) * - sizeof(*out)); - kfree(stream->out); + fa_copy(out, stream->out, 0, min(outcnt, stream->outcnt)); + fa_free(stream->out); } if (outcnt > stream->outcnt) - memset(out + stream->outcnt, 0, - (outcnt - stream->outcnt) * sizeof(*out)); + fa_zero(out, stream->outcnt, (outcnt - stream->outcnt)); stream->out = out; @@ -115,22 +161,20 @@ static int sctp_stream_alloc_out(struct sctp_stream *stream, __u16 outcnt, static int sctp_stream_alloc_in(struct sctp_stream *stream, __u16 incnt, gfp_t gfp) { - struct sctp_stream_in *in; - - in = kmalloc_array(incnt, sizeof(*stream->in), gfp); + struct flex_array *in; + size_t elem_size = sizeof(struct sctp_stream_in); + in = fa_alloc(elem_size, incnt, gfp); if (!in) return -ENOMEM; if (stream->in) { - memcpy(in, stream->in, min(incnt, stream->incnt) * - sizeof(*in)); - kfree(stream->in); + fa_copy(in, stream->in, 0, min(incnt, stream->incnt)); + fa_free(stream->in); } if (incnt > stream->incnt) - memset(in + stream->incnt, 0, - (incnt - stream->incnt) * sizeof(*in)); + fa_zero(in, stream->incnt, (incnt - stream->incnt)); stream->in = in; @@ -174,7 +218,7 @@ in: ret = sctp_stream_alloc_in(stream, incnt, gfp); if (ret) { sched->free(stream); - kfree(stream->out); + fa_free(stream->out); stream->out = NULL; stream->outcnt = 0; goto out; @@ -206,8 +250,8 @@ void sctp_stream_free(struct sctp_stream *stream) sched->free(stream); for (i = 0; i < stream->outcnt; i++) kfree(SCTP_SO(stream, i)->ext); - kfree(stream->out); - kfree(stream->in); + fa_free(stream->out); + fa_free(stream->in); } void sctp_stream_clear(struct sctp_stream *stream)