diff --git a/include/jemalloc/internal/buf_writer.h b/include/jemalloc/internal/buf_writer.h index b64c966a..55b18ab2 100644 --- a/include/jemalloc/internal/buf_writer.h +++ b/include/jemalloc/internal/buf_writer.h @@ -27,4 +27,8 @@ void buf_writer_flush(buf_writer_t *buf_writer); write_cb_t buf_writer_cb; void buf_writer_terminate(tsdn_t *tsdn, buf_writer_t *buf_writer); +typedef ssize_t (read_cb_t)(void *read_cbopaque, void *buf, size_t limit); +void buf_writer_pipe(buf_writer_t *buf_writer, read_cb_t *read_cb, + void *read_cbopaque); + #endif /* JEMALLOC_INTERNAL_BUF_WRITER_H */ diff --git a/src/buf_writer.c b/src/buf_writer.c index fd0226a1..06a2735b 100644 --- a/src/buf_writer.c +++ b/src/buf_writer.c @@ -110,3 +110,36 @@ buf_writer_terminate(tsdn_t *tsdn, buf_writer_t *buf_writer) { buf_writer_free_internal_buf(tsdn, buf_writer->buf); } } + +void +buf_writer_pipe(buf_writer_t *buf_writer, read_cb_t *read_cb, + void *read_cbopaque) { + /* + * A tiny local buffer in case the buffered writer failed to allocate + * at init. + */ + static char backup_buf[16]; + static buf_writer_t backup_buf_writer; + + buf_writer_assert(buf_writer); + assert(read_cb != NULL); + if (buf_writer->buf == NULL) { + buf_writer_init(TSDN_NULL, &backup_buf_writer, + buf_writer->write_cb, buf_writer->cbopaque, backup_buf, + sizeof(backup_buf)); + buf_writer = &backup_buf_writer; + } + assert(buf_writer->buf != NULL); + ssize_t nread = 0; + do { + buf_writer->buf_end += nread; + buf_writer_assert(buf_writer); + if (buf_writer->buf_end == buf_writer->buf_size) { + buf_writer_flush(buf_writer); + } + nread = read_cb(read_cbopaque, + buf_writer->buf + buf_writer->buf_end, + buf_writer->buf_size - buf_writer->buf_end); + } while (nread > 0); + buf_writer_flush(buf_writer); +} diff --git a/test/unit/buf_writer.c b/test/unit/buf_writer.c index 821cf61f..d5e63a0e 100644 --- a/test/unit/buf_writer.c +++ b/test/unit/buf_writer.c @@ -119,10 +119,78 @@ TEST_BEGIN(test_buf_write_oom) { } TEST_END +static int test_read_count; +static size_t test_read_len; +static uint64_t arg_sum; + +ssize_t +test_read_cb(void *cbopaque, void *buf, size_t limit) { + static uint64_t rand = 4; + + arg_sum += *(uint64_t *)cbopaque; + assert_zu_gt(limit, 0, "Limit for read_cb must be positive"); + --test_read_count; + if (test_read_count == 0) { + return -1; + } else { + size_t read_len = limit; + if (limit > 1) { + rand = prng_range_u64(&rand, (uint64_t)limit); + read_len -= (size_t)rand; + } + assert(read_len > 0); + memset(buf, 'a', read_len); + size_t prev_test_read_len = test_read_len; + test_read_len += read_len; + assert_zu_le(prev_test_read_len, test_read_len, + "Test read overflowed"); + return read_len; + } +} + +static void +test_buf_writer_pipe_body(tsdn_t *tsdn, buf_writer_t *buf_writer) { + arg = 4; /* Starting value of random argument. */ + for (int count = 5; count > 0; --count) { + arg = prng_lg_range_u64(&arg, 64); + arg_sum = 0; + test_read_count = count; + test_read_len = 0; + test_write_len = 0; + buf_writer_pipe(buf_writer, test_read_cb, &arg); + assert(test_read_count == 0); + expect_u64_eq(arg_sum, arg * count, ""); + expect_zu_eq(test_write_len, test_read_len, + "Write length should be equal to read length"); + } + buf_writer_terminate(tsdn, buf_writer); +} + +TEST_BEGIN(test_buf_write_pipe) { + buf_writer_t buf_writer; + tsdn_t *tsdn = tsdn_fetch(); + assert_false(buf_writer_init(tsdn, &buf_writer, test_write_cb, &arg, + test_buf, TEST_BUF_SIZE), + "buf_writer_init() should not encounter error on static buffer"); + test_buf_writer_pipe_body(tsdn, &buf_writer); +} +TEST_END + +TEST_BEGIN(test_buf_write_pipe_oom) { + buf_writer_t buf_writer; + tsdn_t *tsdn = tsdn_fetch(); + assert_true(buf_writer_init(tsdn, &buf_writer, test_write_cb, &arg, + NULL, SC_LARGE_MAXCLASS + 1), "buf_writer_init() should OOM"); + test_buf_writer_pipe_body(tsdn, &buf_writer); +} +TEST_END + int main(void) { return test( test_buf_write_static, test_buf_write_dynamic, - test_buf_write_oom); + test_buf_write_oom, + test_buf_write_pipe, + test_buf_write_pipe_oom); }