diff --git a/spa/include/spa/utils/string.h b/spa/include/spa/utils/string.h index 728ba9f2e..393f272cb 100644 --- a/spa/include/spa/utils/string.h +++ b/spa/include/spa/utils/string.h @@ -29,6 +29,7 @@ extern "C" { #endif +#include #include #include @@ -178,6 +179,50 @@ static inline bool spa_atob(const char *str) return spa_streq(str, "true") || spa_streq(str, "1"); } +/** + * "Safe" version of vsnprintf. Exactly the same as vsnprintf but the + * returned value is clipped to `size - 1` and a negative or zero size + * will abort() the program. + * + * \return The number of bytes printed, capped to `size-1`, or a negative + * number on error. + */ +SPA_PRINTF_FUNC(3, 0) +static inline int spa_vscnprintf(char *buffer, size_t size, const char *format, va_list args) +{ + int r; + + spa_assert((ssize_t)size > 0); + + r = vsnprintf(buffer, size, format, args); + if (SPA_UNLIKELY(r < 0)) + buffer[0] = '\0'; + if (SPA_LIKELY(r < (ssize_t)size)) + return r; + return size - 1; +} + +/** + * "Safe" version of snprintf. Exactly the same as snprintf but the + * returned value is clipped to `size - 1` and a negative or zero size + * will abort() the program. + * + * \return The number of bytes printed, capped to `size-1`, or a negative + * number on error. + */ +SPA_PRINTF_FUNC(3, 4) +static inline int spa_scnprintf(char *buffer, size_t size, const char *format, ...) +{ + int r; + va_list args; + + va_start(args, format); + r = spa_vscnprintf(buffer, size, format, args); + va_end(args); + + return r; +} + /** * Convert \a str to a float and store the result in \a val. * diff --git a/spa/tests/test-utils.c b/spa/tests/test-utils.c index ef3245031..64a386c75 100644 --- a/spa/tests/test-utils.c +++ b/spa/tests/test-utils.c @@ -23,6 +23,8 @@ */ #include +#include +#include #include #include @@ -752,6 +754,65 @@ static void test_ansi(void) SPA_ANSI_ITALIC, SPA_ANSI_BOLD_YELLOW, SPA_ANSI_RESET); } +static void test_snprintf(void) +{ + char dest[8]; + pid_t pid; + int len; + + /* Basic printf */ + spa_assert(spa_scnprintf(dest, sizeof(dest), "foo%d%s", 10, "2") == 6); + spa_assert(spa_streq(dest, "foo102")); + /* Print a few strings, make sure dest is truncated and return value + * is the length of the returned string */ + spa_assert(spa_scnprintf(dest, sizeof(dest), "1234567") == 7); + spa_assert(spa_streq(dest, "1234567")); + spa_assert(spa_scnprintf(dest, sizeof(dest), "12345678") == 7); + spa_assert(spa_streq(dest, "1234567")); + spa_assert(spa_scnprintf(dest, sizeof(dest), "123456789") == 7); + spa_assert(spa_streq(dest, "1234567")); + /* Same as above, but with printf %s expansion */ + spa_assert(spa_scnprintf(dest, sizeof(dest), "%s", "1234567") == 7); + spa_assert(spa_streq(dest, "1234567")); + spa_assert(spa_scnprintf(dest, sizeof(dest), "%s", "12345678") == 7); + spa_assert(spa_streq(dest, "1234567")); + spa_assert(spa_scnprintf(dest, sizeof(dest), "%s", "123456789") == 7); + spa_assert(spa_streq(dest, "1234567")); + + spa_assert(spa_scnprintf(dest, 2, "1234567") == 1); + spa_assert(spa_streq(dest, "1")); + spa_assert(spa_scnprintf(dest, 1, "1234567") == 0); + spa_assert(spa_streq(dest, "")); + + /* Check for abort on negative/zero size */ + for (int i = -2; i <= 0; i++) { + pid = fork(); + if (pid == 0) { + close(STDOUT_FILENO); + close(STDERR_FILENO); + spa_assert(spa_scnprintf(dest, (size_t)i, "1234")); + exit(0); + } else { + int r; + int status; + + r = waitpid(pid, &status, 0); + spa_assert(r == pid); + spa_assert(WIFSIGNALED(status)); + spa_assert(WTERMSIG(status) == SIGABRT); + } + } + + /* The "append until buffer is full" use-case */ + len = 0; + while ((size_t)len < sizeof(dest) - 1) + len += spa_scnprintf(dest + len, sizeof(dest) - len, "123"); + /* and once more for good measure, this should print 0 characters */ + len = spa_scnprintf(dest + len, sizeof(dest) - len, "abc"); + spa_assert(len == 0); + spa_assert(spa_streq(dest, "1231231")); +} + int main(int argc, char *argv[]) { setlocale(LC_NUMERIC, "C"); /* For decimal number parsing */ @@ -768,6 +829,7 @@ int main(int argc, char *argv[]) test_strtof(); test_strtod(); test_streq(); + test_snprintf(); test_atob(); test_ansi(); return 0;