diff --git a/src/modules/module-protocol-pulse/message.c b/src/modules/module-protocol-pulse/message.c index 6c897465d..48327e181 100644 --- a/src/modules/module-protocol-pulse/message.c +++ b/src/modules/module-protocol-pulse/message.c @@ -364,9 +364,22 @@ static int message_get(struct message *m, ...) return 0; } +static int ensure_size(struct message *m, uint32_t size) +{ + uint32_t alloc; + if (m->length + size <= m->allocated) + return size; + + alloc = SPA_ROUND_UP_N(SPA_MAX(m->allocated + size, 4096u), 4096u); + if ((m->data = realloc(m->data, alloc)) == NULL) + return -errno; + m->allocated = alloc; + return size; +} + static void write_8(struct message *m, uint8_t val) { - if (m->length < m->allocated) + if (ensure_size(m, 1) > 0) m->data[m->length] = val; m->length++; } @@ -374,7 +387,7 @@ static void write_8(struct message *m, uint8_t val) static void write_32(struct message *m, uint32_t val) { val = htonl(val); - if (m->length + 4 <= m->allocated) + if (ensure_size(m, 4) > 0) memcpy(m->data + m->length, &val, 4); m->length += 4; } @@ -384,7 +397,7 @@ static void write_string(struct message *m, const char *s) write_8(m, s ? TAG_STRING : TAG_STRING_NULL); if (s != NULL) { int len = strlen(s) + 1; - if (m->length + len <= m->allocated) + if (ensure_size(m, len) > 0) strcpy(&m->data[m->length], s); m->length += len; } @@ -420,7 +433,7 @@ static void write_arbitrary(struct message *m, const void *p, size_t length) { write_8(m, TAG_ARBITRARY); write_32(m, length); - if (length > 0 && m->length + length <= m->allocated) + if (ensure_size(m, length) > 0) memcpy(m->data + m->length, p, length); m->length += length; } diff --git a/src/modules/module-protocol-pulse/pulse-server.c b/src/modules/module-protocol-pulse/pulse-server.c index f4ed9c3f7..fe4e2830c 100644 --- a/src/modules/module-protocol-pulse/pulse-server.c +++ b/src/modules/module-protocol-pulse/pulse-server.c @@ -169,13 +169,14 @@ static const struct command commands[COMMAND_MAX]; static void message_free(struct client *client, struct message *msg, bool destroy) { spa_list_remove(&msg->link); - if (destroy) + if (destroy) { + free(msg->data); free(msg); - else + } else spa_list_append(&client->free_messages, &msg->link); } -static struct message *message_alloc(struct client *client, uint32_t size, uint32_t channel) +static struct message *message_alloc(struct client *client, uint32_t channel, uint32_t size) { struct message *msg = NULL; @@ -183,12 +184,11 @@ static struct message *message_alloc(struct client *client, uint32_t size, uint3 msg = spa_list_first(&client->free_messages, struct message, link); spa_list_remove(&msg->link); } - if (msg == NULL || msg->allocated < size) { - uint32_t alloc = SPA_ROUND_UP_N(SPA_MAX(size, 4096u), 4096u); - msg = realloc(msg, sizeof(struct message) + alloc); - msg->allocated = alloc; - msg->data = SPA_MEMBER(msg, sizeof(struct message), void); - } + if (msg == NULL) + msg = calloc(1, sizeof(struct message)); + if (msg == NULL) + return NULL; + ensure_size(msg, size); msg->channel = channel; msg->offset = 0; msg->length = size; @@ -265,7 +265,7 @@ static int send_message(struct client *client, struct message *m) static struct message *reply_new(struct client *client, uint32_t tag) { struct message *reply; - reply = message_alloc(client, 0, -1); + reply = message_alloc(client, -1, 0); pw_log_debug(NAME" %p: REPLY tag:%u", client, tag); message_put(reply, TAG_U32, COMMAND_REPLY, @@ -286,7 +286,7 @@ static int reply_error(struct client *client, uint32_t tag, uint32_t error) pw_log_debug(NAME" %p: ERROR tag:%u error:%u", client, tag, error); - reply = message_alloc(client, 0, -1); + reply = message_alloc(client, -1, 0); message_put(reply, TAG_U32, COMMAND_ERROR, TAG_U32, tag, @@ -462,7 +462,7 @@ static int send_command_request(struct stream *stream) pw_log_trace(NAME" %p: REQUEST channel:%d %u", stream, stream->channel, size); - msg = message_alloc(client, 0, -1); + msg = message_alloc(client, -1, 0); message_put(msg, TAG_U32, COMMAND_REQUEST, TAG_U32, -1, @@ -722,7 +722,7 @@ static void stream_process_record(struct stream *stream) size = buf->datas[0].chunk->size; - msg = message_alloc(client, size, stream->channel); + msg = message_alloc(client, stream->channel, size); if (msg != NULL) { memcpy(msg->data, SPA_MEMBER(p, buf->datas[0].chunk->offset, void), @@ -2367,7 +2367,7 @@ static int do_read(struct client *client) } if (client->message) message_free(client, client->message, false); - client->message = message_alloc(client, length, channel); + client->message = message_alloc(client, channel, length); } else if (client->message && client->in_index >= client->message->length + sizeof(client->desc)) { struct message *msg = client->message;