diff --git a/client.c b/client.c index e90312f9..c818c185 100644 --- a/client.c +++ b/client.c @@ -107,7 +107,26 @@ main(int argc, char *const *argv) } } - uint16_t term_len = strlen(term); + const uint16_t term_len = strlen(term) + 1; + uint32_t total_len = 0; + + /* Calculate total length */ + total_len += sizeof(term_len) + term_len; + total_len += sizeof(argc); + + for (int i = 0; i < argc; i++) { + uint16_t len = strlen(argv[i]) + 1; + total_len += sizeof(len) + len; + } + + LOG_DBG("term-len: %hu, argc: %d, total-len: %u", + term_len, argc, total_len); + + if (send(fd, &total_len, sizeof(total_len), 0) != sizeof(total_len)) { + LOG_ERRNO("failed to send total length to server"); + goto err; + } + if (send(fd, &term_len, sizeof(term_len), 0) != sizeof(term_len) || send(fd, term, term_len, 0) != term_len) { @@ -122,7 +141,7 @@ main(int argc, char *const *argv) } for (int i = 0; i < argc; i++) { - uint16_t len = strlen(argv[i]); + uint16_t len = strlen(argv[i]) + 1; LOG_DBG("argv[%d] = %s (%hu)", i, argv[i], len); diff --git a/server.c b/server.c index 498ea2f3..4de645b3 100644 --- a/server.c +++ b/server.c @@ -34,6 +34,12 @@ struct client { struct server *server; int fd; + struct { + uint8_t *data; + size_t left; + size_t idx; + } buffer; + struct terminal *term; }; @@ -60,6 +66,7 @@ client_destroy(struct client *client) } } + free(client->buffer.data); free(client); } @@ -93,7 +100,6 @@ fdm_client(struct fdm *fdm, int fd, int events, void *data) struct client *client = data; struct server *server = client->server; - char *term_env = NULL; char **argv = NULL; int argc = 0; @@ -104,47 +110,107 @@ fdm_client(struct fdm *fdm, int fd, int events, void *data) if (client->term != NULL) { uint8_t dummy[128]; - read(fd, dummy, sizeof(dummy)); + ssize_t count = read(fd, dummy, sizeof(dummy)); + LOG_WARN("client unexpectedly sent %zd bytes", count); + return true; /* TODO: shutdown instead? */ + } + + if (client->buffer.data == NULL) { + /* + * We haven't received any data yet - the first thing the + * client sends is the total size of the initialization + * data. + */ + uint32_t total_len; + if (recv(fd, &total_len, sizeof(total_len), 0) != sizeof(total_len)) { + LOG_ERRNO("failed to read total length"); + goto shutdown; + } + + LOG_DBG("total len: %u", total_len); + client->buffer.data = malloc(total_len + 1); + client->buffer.left = total_len; + client->buffer.idx = 0; + + /* Prevent our strlen() calls to run outside */ + client->buffer.data[total_len] = '\0'; + return true; /* Let FDM trigger again when we have more data */ + } + + /* Keep filling our buffer of initialization data */ + ssize_t count = recv( + fd, &client->buffer.data[client->buffer.idx], client->buffer.left, 0); + + if (count < 0) { + LOG_ERRNO("failed to read"); + goto shutdown; + } + + client->buffer.idx += count; + client->buffer.left -= count; + + if (client->buffer.left > 0) { + /* Not done yet */ return true; } - uint16_t term_env_len; - if (recv(fd, &term_env_len, sizeof(term_env_len), 0) != sizeof(term_env_len)) - goto shutdown; - - term_env = malloc(term_env_len + 1); - term_env[term_env_len] = '\0'; - if (term_env_len > 0) { - if (recv(fd, term_env, term_env_len, 0) != term_env_len) - goto shutdown; - } - - if (recv(fd, &argc, sizeof(argc), 0) != sizeof(argc)) - goto shutdown; - - LOG_DBG("argc = %d", argc); - - argv = calloc(argc + 1, sizeof(argv[0])); - for (int i = 0; i < argc; i++) { - uint16_t len; - if (recv(fd, &len, sizeof(len), 0) != sizeof(len)) - goto shutdown; - - argv[i] = malloc(len + 1); - argv[i][len] = '\0'; - if (len == 0) - continue; - - if (recv(fd, argv[i], len, 0) != len) - goto shutdown; - - LOG_DBG("argv[%d] = %s (%hu)", i, argv[i], len); - } + /* All initialization data received - time to instantiate a terminal! */ assert(client->term == NULL); + assert(client->buffer.data != NULL); + assert(client->buffer.left == 0); + + /* + * Parse the received buffer, verifying lengths etc + */ + +#define CHECK_BUF(sz) do { \ + if (p + (sz) > end) \ + goto shutdown; \ + } while (0) + + uint8_t *p = client->buffer.data; + const uint8_t *end = &client->buffer.data[client->buffer.idx]; + + CHECK_BUF(sizeof(uint16_t)); + uint16_t term_env_len = *(uint16_t *)p; p += sizeof(term_env_len); + + CHECK_BUF(term_env_len); + const char *term_env = (const char *)p; p += strlen(term_env) + 1; + LOG_DBG("TERM = %.*s", term_env_len, term_env); + + if (term_env_len != strlen(term_env) + 1) { + LOG_ERR("TERM length mismatch: indicated = %hu, actual = %zu", + term_env_len - 1, strlen(term_env)); + goto shutdown; + } + + CHECK_BUF(sizeof(argc)); + argc = *(int *)p; p += sizeof(argc); + argv = calloc(argc + 1, sizeof(argv[0])); + LOG_DBG("argc = %d", argc); + + for (int i = 0; i < argc; i++) { + CHECK_BUF(sizeof(uint16_t)); + uint16_t len = *(uint16_t *)p; p += sizeof(len); + + CHECK_BUF(len); + argv[i] = (char *)p; p += strlen(argv[i]) + 1; + LOG_DBG("argv[%d] = %s", i, argv[i]); + + if (len != strlen(argv[i]) + 1) { + LOG_ERR("argv[%d] length mismatch: indicated = %hu, actual = %zu", + i, len - 1, strlen(argv[i])); + goto shutdown; + } + } + argv[argc] = NULL; + +#undef CHECK_BUF + client->term = term_init( server->conf, server->fdm, server->wayl, - term_env_len > 0 ? term_env : server->conf->term, + strlen(term_env) > 0 ? term_env : server->conf->term, argc, argv, &term_shutdown_handler, client); if (client->term == NULL) { @@ -152,20 +218,13 @@ fdm_client(struct fdm *fdm, int fd, int events, void *data) goto shutdown; } - for (int i = 0; i < argc; i++) - free(argv[i]); free(argv); - free(term_env); return true; shutdown: LOG_DBG("client FD=%d: disconnected", client->fd); - for (int i = 0; i < argc; i++) - free(argv[i]); free(argv); - free(term_env); - fdm_del(fdm, fd); client->fd = -1;