diff --git a/src/connection.c b/src/connection.c index c49ca3d4..06cc66fa 100644 --- a/src/connection.c +++ b/src/connection.c @@ -307,7 +307,7 @@ wl_connection_data(struct wl_connection *connection, uint32_t mask) msg.msg_flags = 0; do { - len = recvmsg(connection->fd, &msg, MSG_CMSG_CLOEXEC); + len = wl_os_recvmsg_cloexec(connection->fd, &msg, 0); } while (len < 0 && errno == EINTR); if (len < 0) { diff --git a/src/wayland-os.c b/src/wayland-os.c index 4a19da6f..eb53eec4 100644 --- a/src/wayland-os.c +++ b/src/wayland-os.c @@ -79,3 +79,48 @@ wl_os_dupfd_cloexec(int fd, long minfd) newfd = fcntl(fd, F_DUPFD, minfd); return set_cloexec_or_close(newfd); } + +static ssize_t +recvmsg_cloexec_fallback(int sockfd, struct msghdr *msg, int flags) +{ + ssize_t len; + struct cmsghdr *cmsg; + unsigned char *data; + int *fd; + int *end; + + len = recvmsg(sockfd, msg, flags); + if (len == -1) + return -1; + + if (!msg->msg_control || msg->msg_controllen == 0) + return len; + + cmsg = CMSG_FIRSTHDR(msg); + for (; cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level != SOL_SOCKET || + cmsg->cmsg_type != SCM_RIGHTS) + continue; + + data = CMSG_DATA(cmsg); + end = (int *)(data + cmsg->cmsg_len - CMSG_LEN(0)); + for (fd = (int *)data; fd < end; ++fd) + *fd = set_cloexec_or_close(*fd); + } + + return len; +} + +ssize_t +wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags) +{ + ssize_t len; + + len = recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC); + if (len >= 0) + return len; + if (errno != EINVAL) + return -1; + + return recvmsg_cloexec_fallback(sockfd, msg, flags); +} diff --git a/src/wayland-os.h b/src/wayland-os.h index 456d8b03..43c317b9 100644 --- a/src/wayland-os.h +++ b/src/wayland-os.h @@ -29,6 +29,10 @@ wl_os_socket_cloexec(int domain, int type, int protocol); int wl_os_dupfd_cloexec(int fd, long minfd); +ssize_t +wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags); + + /* * The following are for wayland-os.c and the unit tests. * Do not use them elsewhere. @@ -44,6 +48,10 @@ wl_os_dupfd_cloexec(int fd, long minfd); #define F_DUPFD_CLOEXEC 1030 #endif +#ifndef MSG_CMSG_CLOEXEC +#define MSG_CMSG_CLOEXEC 0x40000000 +#endif + #endif /* __linux__ */ #endif diff --git a/tests/os-wrappers-test.c b/tests/os-wrappers-test.c index 2272b730..657f1feb 100644 --- a/tests/os-wrappers-test.c +++ b/tests/os-wrappers-test.c @@ -1,5 +1,6 @@ /* * Copyright © 2012 Collabora, Ltd. + * Copyright © 2012 Intel Corporation * * Permission to use, copy, modify, distribute, and sell this software and its * documentation for any purpose is hereby granted without fee, provided that @@ -26,12 +27,15 @@ #include #include #include +#include #include #include #include #include #include +#include +#include "../src/wayland-private.h" #include "test-runner.h" #include "../src/wayland-os.h" @@ -43,12 +47,16 @@ static int wrapped_calls_socket; static int (*real_fcntl)(int, int, ...); static int wrapped_calls_fcntl; +static ssize_t (*real_recvmsg)(int, struct msghdr *, int); +static int wrapped_calls_recvmsg; + static void init_fallbacks(int do_fallbacks) { fall_back = do_fallbacks; real_socket = dlsym(RTLD_NEXT, "socket"); real_fcntl = dlsym(RTLD_NEXT, "fcntl"); + real_recvmsg = dlsym(RTLD_NEXT, "recvmsg"); } __attribute__ ((visibility("default"))) int @@ -84,6 +92,19 @@ fcntl(int fd, int cmd, ...) return real_fcntl(fd, cmd, arg); } +__attribute__ ((visibility("default"))) ssize_t +recvmsg(int sockfd, struct msghdr *msg, int flags) +{ + wrapped_calls_recvmsg++; + + if (fall_back && (flags & MSG_CMSG_CLOEXEC)) { + errno = EINVAL; + return -1; + } + + return real_recvmsg(sockfd, msg, flags); +} + static void do_os_wrappers_socket_cloexec(int n) { @@ -156,3 +177,157 @@ TEST(os_wrappers_dupfd_cloexec_fallback) init_fallbacks(1); do_os_wrappers_dupfd_cloexec(3); } + +struct marshal_data { + struct wl_connection *read_connection; + struct wl_connection *write_connection; + int s[2]; + uint32_t read_mask; + uint32_t write_mask; + union { + int h[3]; + } value; + int nr_fds_begin; + int nr_fds_conn; + int wrapped_calls; +}; + +static int +update_func(struct wl_connection *connection, uint32_t mask, void *data) +{ + uint32_t *m = data; + + *m = mask; + + return 0; +} + +static void +setup_marshal_data(struct marshal_data *data) +{ + assert(socketpair(AF_UNIX, + SOCK_STREAM | SOCK_CLOEXEC, 0, data->s) == 0); + + data->read_connection = + wl_connection_create(data->s[0], + update_func, &data->read_mask); + assert(data->read_connection); + assert(data->read_mask == WL_CONNECTION_READABLE); + + data->write_connection = + wl_connection_create(data->s[1], + update_func, &data->write_mask); + assert(data->write_connection); + assert(data->write_mask == WL_CONNECTION_READABLE); +} + +static void +marshal_demarshal(struct marshal_data *data, + void (*func)(void), int size, const char *format, ...) +{ + struct wl_closure closure; + static const int opcode = 4444; + static struct wl_object sender = { NULL, NULL, 1234 }; + struct wl_message message = { "test", format, NULL }; + struct wl_map objects; + struct wl_object object; + va_list ap; + uint32_t msg[1] = { 1234 }; + int ret; + + va_start(ap, format); + ret = wl_closure_vmarshal(&closure, &sender, opcode, ap, &message); + va_end(ap); + + assert(ret == 0); + assert(wl_closure_send(&closure, data->write_connection) == 0); + wl_closure_destroy(&closure); + assert(data->write_mask == + (WL_CONNECTION_WRITABLE | WL_CONNECTION_READABLE)); + assert(wl_connection_data(data->write_connection, + WL_CONNECTION_WRITABLE) == 0); + assert(data->write_mask == WL_CONNECTION_READABLE); + + assert(wl_connection_data(data->read_connection, + WL_CONNECTION_READABLE) == size); + + wl_map_init(&objects); + object.id = msg[0]; + ret = wl_connection_demarshal(data->read_connection, + &closure, size, &objects, &message); + wl_closure_invoke(&closure, &object, func, data); + wl_closure_destroy(&closure); +} + +static void +validate_recvmsg_h(struct marshal_data *data, + struct wl_object *object, int fd1, int fd2, int fd3) +{ + struct stat buf1, buf2; + + assert(fd1 >= 0); + assert(fd2 >= 0); + assert(fd3 >= 0); + + assert(fd1 != data->value.h[0]); + assert(fd2 != data->value.h[1]); + assert(fd3 != data->value.h[2]); + + assert(fstat(fd3, &buf1) == 0); + assert(fstat(data->value.h[2], &buf2) == 0); + assert(buf1.st_dev == buf2.st_dev); + assert(buf1.st_ino == buf2.st_ino); + + /* close the original file descriptors */ + close(data->value.h[0]); + close(data->value.h[1]); + close(data->value.h[2]); + + /* the dup'd (received) fds should still be open */ + assert(count_open_fds() == data->nr_fds_conn + 3); + + /* + * Must have 2 calls if falling back, but must also allow + * falling back without a forced fallback. + */ + assert(wrapped_calls_recvmsg > data->wrapped_calls); + + if (data->wrapped_calls == 0 && wrapped_calls_recvmsg > 1) + printf("recvmsg fell back unforced.\n"); + + /* all fds opened during the test in any way should be gone on exec */ + exec_fd_leak_check(data->nr_fds_begin); +} + +static void +do_os_wrappers_recvmsg_cloexec(int n) +{ + struct marshal_data data; + + data.nr_fds_begin = count_open_fds(); + data.wrapped_calls = n; + + setup_marshal_data(&data); + data.nr_fds_conn = count_open_fds(); + + assert(pipe(data.value.h) >= 0); + + data.value.h[2] = open("/dev/zero", O_RDONLY); + assert(data.value.h[2] >= 0); + + marshal_demarshal(&data, (void *) validate_recvmsg_h, + 8, "hhh", data.value.h[0], data.value.h[1], + data.value.h[2]); +} + +TEST(os_wrappers_recvmsg_cloexec) +{ + init_fallbacks(0); + do_os_wrappers_recvmsg_cloexec(0); +} + +TEST(os_wrappers_recvmsg_cloexec_fallback) +{ + init_fallbacks(1); + do_os_wrappers_recvmsg_cloexec(1); +}