diff --git a/MODULE.bazel b/MODULE.bazel index 24d82bf..03ef393 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -12,13 +12,16 @@ bazel_dep(name = "googletest", version = "1.15.2", repo_name = "com_google_googl # Coroutines http_archive( name = "coroutines", - integrity = "sha256-PhOYq1eE8Q8UOhzDHq2+rafTU4VTt9fZ0ZZyNE+hWb4=", - strip_prefix = "co-2.1.0", - urls = ["https://github.com/dallison/co/archive/refs/tags/2.1.0.tar.gz"], + # integrity = "sha256-tRZ5YigZhxVJDBENpQwqL1Bch1VPuV+tgKSK+/72aTU=", + sha256 = "d38c2a2480f016ed44c8460b9b9c6436445211bf14b6b0947e7577ccb29ca9d4", + strip_prefix = "co-562a5a335570c8ab70d8d36f3e6d7e85ccccfd10", + # urls = ["https://github.com/dallison/co/archive/refs/tags/2.1.8.tar.gz"], + urls = ["https://github.com/dallison/co/archive/562a5a335570c8ab70d8d36f3e6d7e85ccccfd10.tar.gz"], ) # For local debugging of co coroutine library. # bazel_dep(name = "coroutines") # local_path_override( -# module_name = "coroutines", -# path = "../co", +# module_name = "coroutines", +# path = "../co", # ) + diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 7723de7..fc62159 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -13,13 +13,15 @@ "https://bcr.bazel.build/modules/abseil-cpp/20240722.0.bcr.1/MODULE.bazel": "c0aa5eaefff1121b40208397f229604c717bd2fdf214ff67586d627118e17720", "https://bcr.bazel.build/modules/abseil-cpp/20240722.0.bcr.1/source.json": "e067fdd217bacbe74c88a975434be5df0b44315a247be180f0e20f891715210c", "https://bcr.bazel.build/modules/apple_support/1.15.1/MODULE.bazel": "a0556fefca0b1bb2de8567b8827518f94db6a6e7e7d632b4c48dc5f865bc7c85", - "https://bcr.bazel.build/modules/apple_support/1.15.1/source.json": "517f2b77430084c541bc9be2db63fdcbb7102938c5f64c17ee60ffda2e5cf07b", + "https://bcr.bazel.build/modules/apple_support/1.23.1/MODULE.bazel": "53763fed456a968cf919b3240427cf3a9d5481ec5466abc9d5dc51bc70087442", + "https://bcr.bazel.build/modules/apple_support/1.23.1/source.json": "d888b44312eb0ad2c21a91d026753f330caa48a25c9b2102fae75eb2b0dcfdd2", "https://bcr.bazel.build/modules/bazel_features/1.1.1/MODULE.bazel": "27b8c79ef57efe08efccbd9dd6ef70d61b4798320b8d3c134fd571f78963dbcd", "https://bcr.bazel.build/modules/bazel_features/1.11.0/MODULE.bazel": "f9382337dd5a474c3b7d334c2f83e50b6eaedc284253334cf823044a26de03e8", "https://bcr.bazel.build/modules/bazel_features/1.15.0/MODULE.bazel": "d38ff6e517149dc509406aca0db3ad1efdd890a85e049585b7234d04238e2a4d", "https://bcr.bazel.build/modules/bazel_features/1.17.0/MODULE.bazel": "039de32d21b816b47bd42c778e0454217e9c9caac4a3cf8e15c7231ee3ddee4d", "https://bcr.bazel.build/modules/bazel_features/1.18.0/MODULE.bazel": "1be0ae2557ab3a72a57aeb31b29be347bcdc5d2b1eb1e70f39e3851a7e97041a", "https://bcr.bazel.build/modules/bazel_features/1.19.0/MODULE.bazel": "59adcdf28230d220f0067b1f435b8537dd033bfff8db21335ef9217919c7fb58", + "https://bcr.bazel.build/modules/bazel_features/1.27.0/MODULE.bazel": "621eeee06c4458a9121d1f104efb80f39d34deff4984e778359c60eaf1a8cb65", "https://bcr.bazel.build/modules/bazel_features/1.30.0/MODULE.bazel": "a14b62d05969a293b80257e72e597c2da7f717e1e69fa8b339703ed6731bec87", "https://bcr.bazel.build/modules/bazel_features/1.30.0/source.json": "b07e17f067fe4f69f90b03b36ef1e08fe0d1f3cac254c1241a1818773e3423bc", "https://bcr.bazel.build/modules/bazel_features/1.4.1/MODULE.bazel": "e45b6bb2350aff3e442ae1111c555e27eac1d915e77775f6fdc4b351b758b5d7", @@ -146,37 +148,6 @@ }, "selectedYankedVersions": {}, "moduleExtensions": { - "@@apple_support+//crosstool:setup.bzl%apple_cc_configure_extension": { - "general": { - "bzlTransitiveDigest": "E970FlMbwpgJPdPUQzatKh6BMfeE0ZpWABvwshh7Tmg=", - "usagesDigest": "aYRVMk+1OupIp+5hdBlpzT36qgd6ntgSxYTzMLW5K4U=", - "recordedFileInputs": {}, - "recordedDirentsInputs": {}, - "envVariables": {}, - "generatedRepoSpecs": { - "local_config_apple_cc_toolchains": { - "repoRuleId": "@@apple_support+//crosstool:setup.bzl%_apple_cc_autoconf_toolchains", - "attributes": {} - }, - "local_config_apple_cc": { - "repoRuleId": "@@apple_support+//crosstool:setup.bzl%_apple_cc_autoconf", - "attributes": {} - } - }, - "recordedRepoMappingEntries": [ - [ - "apple_support+", - "bazel_tools", - "bazel_tools" - ], - [ - "bazel_tools", - "rules_cc", - "rules_cc+" - ] - ] - } - }, "@@rules_kotlin+//src/main/starlark/core/repositories:bzlmod_setup.bzl%rules_kotlin_extensions": { "general": { "bzlTransitiveDigest": "OlvsB0HsvxbR8ZN+J9Vf00X/+WVz/Y/5Xrq2LgcVfdo=", diff --git a/toolbelt/BUILD.bazel b/toolbelt/BUILD.bazel index 081ede1..0448880 100644 --- a/toolbelt/BUILD.bazel +++ b/toolbelt/BUILD.bazel @@ -12,6 +12,7 @@ cc_library( "sockets.cc", "table.cc", "triggerfd.cc", + "stacktrace.cc", ], hdrs = [ "bitset.h", @@ -26,6 +27,7 @@ cc_library( "sockets.h", "table.h", "triggerfd.h", + "stacktrace.h", ], deps = [ "@com_google_absl//absl/status", @@ -33,7 +35,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@coroutines//:co", + "@coroutines//co", ], ) @@ -63,7 +65,19 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) - +cc_test( + name = "stacktrace_test", + size = "small", + srcs = ["stacktrace_test.cc"], + deps = [ + ":toolbelt", + "@com_google_absl//absl/hash:hash_testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) cc_test( name = "pipe_test", size = "small", @@ -86,6 +100,7 @@ cc_test( ":toolbelt", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", ], diff --git a/toolbelt/fd.cc b/toolbelt/fd.cc index e4a8440..a160646 100644 --- a/toolbelt/fd.cc +++ b/toolbelt/fd.cc @@ -16,7 +16,7 @@ void CloseAllFds(std::function predicate) { } absl::StatusOr FileDescriptor::Read(void *buffer, size_t length, - co::Coroutine *c) { + const co::Coroutine *c) { char *buf = reinterpret_cast(buffer); size_t total = 0; while (total < length) { @@ -58,7 +58,7 @@ absl::StatusOr FileDescriptor::Read(void *buffer, size_t length, } absl::StatusOr FileDescriptor::Write(const void *buffer, size_t length, - co::Coroutine *c) { + const co::Coroutine *c) { const char *buf = reinterpret_cast(buffer); size_t total = 0; diff --git a/toolbelt/fd.h b/toolbelt/fd.h index d89dc68..19619d8 100644 --- a/toolbelt/fd.h +++ b/toolbelt/fd.h @@ -18,7 +18,7 @@ #include #include #include -#include "coroutine.h" +#include "co/coroutine.h" namespace toolbelt { @@ -56,7 +56,8 @@ class FileDescriptor { FileDescriptor() = default; // FileDescriptor initialize with an OS fd. Takes ownership // of the fd and will close it when all references go away. - explicit FileDescriptor(int fd) : data_(std::make_shared(fd)) {} + // If owned is false, the fd will not be closed when all references go away. + explicit FileDescriptor(int fd, bool owned = true) : data_(std::make_shared(fd, owned)) {} // Copy constructor, increments reference on shared data. Very cheap. FileDescriptor(const FileDescriptor &f) : data_(f.data_) {} @@ -119,13 +120,13 @@ class FileDescriptor { // Sets the OS fd. If it's the same as the underlying OS fd, there is // no effect (that's not another reference to it). Allocates new // shared data for the fd. - void SetFd(int fd) { + void SetFd(int fd, bool owned = true) { if (Fd() == fd) { // SetFd with same fd. This isn't another reference to the // fd. return; } - data_ = std::make_shared(fd); + data_ = std::make_shared(fd, owned); } void Reset() { Close(); } @@ -187,23 +188,27 @@ class FileDescriptor { return absl::OkStatus(); } - absl::StatusOr Read(void* buffer, size_t length, co::Coroutine* c = nullptr); + absl::StatusOr Read(void* buffer, size_t length, const co::Coroutine* c = nullptr); absl::StatusOr Write(const void* buffer, size_t length, - co::Coroutine* c = nullptr); + const co::Coroutine* c = nullptr); private: // Reference counted OS fd, shared among all FileDescriptors with the // same OS fd, provided you don't create two FileDescriptors with the // same OS fd (that would be a mistake but there's no way to stop it). struct SharedData { SharedData() = default; - SharedData(int f) : fd(f) {} + SharedData(int f, bool o) : fd(f), owned(o) {} ~SharedData() { if (fd != -1) { + if (!owned) { + return; + } ::close(fd); } } int fd = -1; // OS file descriptor. bool nonblocking = false; + bool owned = true; }; // The actual shared data. If nullptr the FileDescriptor is invalid. diff --git a/toolbelt/fd_test.cc b/toolbelt/fd_test.cc index 79a9e01..be1cd26 100644 --- a/toolbelt/fd_test.cc +++ b/toolbelt/fd_test.cc @@ -173,3 +173,26 @@ TEST(FdTest, Reset) { int e = fstat(f, &st); ASSERT_EQ(-1, e); } + +TEST(FdTest, CreateUnowned) { + int f = dup(1); + { + FileDescriptor fd(f, false); + ASSERT_TRUE(fd.Valid()); + ASSERT_EQ(f, fd.Fd()); + ASSERT_EQ(1, fd.RefCount()); + } + // Fd should still be open. + ASSERT_EQ(0, fcntl(f, F_GETFD)); + + // Now take ownership of f. + { + FileDescriptor fd(f, true); + ASSERT_TRUE(fd.Valid()); + ASSERT_EQ(f, fd.Fd()); + ASSERT_EQ(1, fd.RefCount()); + } + // Will be closed. + ASSERT_EQ(-1, fcntl(f, F_GETFD)); + ASSERT_EQ(EBADF, errno); +} diff --git a/toolbelt/logging.cc b/toolbelt/logging.cc index 933d76e..00fcc3e 100644 --- a/toolbelt/logging.cc +++ b/toolbelt/logging.cc @@ -7,6 +7,7 @@ #include "clock.h" #include #include +#include namespace toolbelt { @@ -148,10 +149,19 @@ void Logger::VLog(LogLevel level, const char *fmt, va_list ap) { if (level < min_level_) { return; } +#if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wformat-nonliteral" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wformat-nonliteral" +#endif size_t n = vsnprintf(buffer_, sizeof(buffer_), fmt, ap); +#if defined(__clang__) #pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif // Strip final \n if present. Refactoring from printf can leave // this in place. diff --git a/toolbelt/logging.h b/toolbelt/logging.h index ca8f177..50b95a2 100644 --- a/toolbelt/logging.h +++ b/toolbelt/logging.h @@ -10,7 +10,6 @@ #include #include #include -#include #include #include diff --git a/toolbelt/payload_buffer.cc b/toolbelt/payload_buffer.cc index a45b116..21fe3f7 100644 --- a/toolbelt/payload_buffer.cc +++ b/toolbelt/payload_buffer.cc @@ -5,7 +5,7 @@ namespace toolbelt { static constexpr struct BitmapRunInfo { int num; - int size; + uint32_t size; } bitmp_run_infos[kNumBitmapRuns] = { {kRunSize1, kBitmapRunSize1}, {kRunSize2, kBitmapRunSize2}, @@ -162,7 +162,7 @@ void PayloadBuffer::Dump(std::ostream &os) { os << " free_list: " << free_list << " " << ToAddress(free_list) << std::endl; os << " message: " << message << " " << ToAddress(message) << std::endl; - for (int i = 0; i < kNumBitmapRuns; i++) { + for (size_t i = 0; i < kNumBitmapRuns; i++) { os << " bitmaps[" << i << "]: " << bitmaps[i] << " " << ToAddress(bitmaps[i]) << std::endl; } @@ -576,9 +576,9 @@ void *PayloadBuffer::Realloc(PayloadBuffer **buffer, void *p, uint32_t n, (n & kBitmapRunSizeMask); *len_ptr = encoded_size; - if (clear && n > decoded_length) { + if (clear && n > static_cast(decoded_length)) { memset(reinterpret_cast(p) + decoded_length, 0, - n - decoded_length); + n - static_cast(decoded_length)); } return p; } @@ -589,7 +589,7 @@ void *PayloadBuffer::Realloc(PayloadBuffer **buffer, void *p, uint32_t n, return NULL; } memcpy(newp, p, decoded_length); - if (clear && n > decoded_length) { + if (clear && n > static_cast(decoded_length)) { memset(reinterpret_cast(newp) + decoded_length, 0, n - decoded_length); } @@ -629,8 +629,8 @@ void *PayloadBuffer::Realloc(PayloadBuffer **buffer, void *p, uint32_t n, int diff = n - orig_length; if (alloc_addr + orig_length == free_addr) { // There is a free block above. See if has enough space. - if (free_block->length > diff) { - ssize_t freelen = free_block->length - diff; + if (free_block->length > static_cast(diff)) { + uint32_t freelen = free_block->length - static_cast(diff); if (freelen > sizeof(FreeBlockHeader)) { (*buffer)->ExpandIntoFreeBlockAbove(free_block, n, diff, freelen, len_ptr, next_ptr, clear); @@ -642,7 +642,7 @@ void *PayloadBuffer::Realloc(PayloadBuffer **buffer, void *p, uint32_t n, if (prev != NULL) { uintptr_t prev_addr = (uintptr_t)prev; if (prev_addr + prev->length == (uintptr_t)alloc_block && - prev->length >= diff) { + prev->length >= static_cast(diff)) { // Previous free block is adjacent and has enough space in it. // Use start of new block as new address and place FreeBlockHeader // at newly free part. diff --git a/toolbelt/payload_buffer.h b/toolbelt/payload_buffer.h index 5c56cfa..ba8c9f0 100644 --- a/toolbelt/payload_buffer.h +++ b/toolbelt/payload_buffer.h @@ -144,7 +144,7 @@ struct PayloadBuffer { PayloadBuffer(uint32_t size, bool bitmap_allocator = true) : magic(kFixedBufferMagic | (bitmap_allocator ? kBitMapFlag : 0)), message(0), hwm(0), full_size(size), metadata(0) { - for (int i = 0; i < kNumBitmapRuns; i++) { + for (size_t i = 0; i < kNumBitmapRuns; i++) { bitmaps[i] = 0; } InitFreeList(); @@ -163,7 +163,7 @@ struct PayloadBuffer { PayloadBuffer(uint32_t initial_size, Resizer r, bool bitmap_allocator = true) : magic(kMovableBufferMagic | (bitmap_allocator ? kBitMapFlag : 0)), message(0), hwm(0), full_size(initial_size), metadata(0) { - for (int i = 0; i < kNumBitmapRuns; i++) { + for (size_t i = 0; i < kNumBitmapRuns; i++) { bitmaps[i] = 0; } InitFreeList(); diff --git a/toolbelt/payload_buffer_test.cc b/toolbelt/payload_buffer_test.cc index e33e52e..b699966 100644 --- a/toolbelt/payload_buffer_test.cc +++ b/toolbelt/payload_buffer_test.cc @@ -1,6 +1,7 @@ #include "toolbelt/clock.h" #include "toolbelt/hexdump.h" #include "toolbelt/payload_buffer.h" +#include #include #include @@ -152,13 +153,13 @@ TEST(BufferTest, SmallBlockAllocFree) { blocks.push_back(addr); } // Free every 5th block. - for (int i = 0; i < blocks.size(); i++) { + for (size_t i = 0; i < blocks.size(); i++) { if (i % 5 == 0) { pb->Free(blocks[i]); } } // Now allocate every 5th block again. - for (int i = 0; i < blocks.size(); i++) { + for (size_t i = 0; i < blocks.size(); i++) { if (i % 5 == 0) { size_t size = sizes[i % sizes.size()]; void *addr = PayloadBuffer::Allocate(&pb, size); @@ -314,7 +315,7 @@ TEST(BufferTest, TypicalPerformance) { small_blocks.push_back(addr); } // Free some of the blocks. - for (int i = prev_size; i < small_blocks.size(); i++) { + for (size_t i = prev_size; i < small_blocks.size(); i++) { if (i % 8 == 0) { continue; } @@ -341,7 +342,7 @@ TEST(BufferTest, TypicalPerformance) { large_blocks.push_back(addr); } // Free some of the blocks. - for (int i = prev_size; i < large_blocks.size(); i++) { + for (size_t i = prev_size; i < large_blocks.size(); i++) { if (i % 8 == 0) { continue; } diff --git a/toolbelt/pipe.cc b/toolbelt/pipe.cc index 82c89cb..0cb8312 100644 --- a/toolbelt/pipe.cc +++ b/toolbelt/pipe.cc @@ -83,7 +83,7 @@ absl::Status Pipe::SetPipeSize(size_t size) { } absl::StatusOr Pipe::Read(char *buffer, size_t length, - co::Coroutine *c) { + const co::Coroutine *c) { size_t total = 0; ScopedRead sc(*this, c); @@ -123,7 +123,7 @@ absl::StatusOr Pipe::Read(char *buffer, size_t length, } absl::StatusOr Pipe::Write(const char *buffer, size_t length, - co::Coroutine *c) { + const co::Coroutine *c) { size_t total = 0; ScopedWrite sc(*this, c); diff --git a/toolbelt/pipe.h b/toolbelt/pipe.h index 25996db..f05fb20 100644 --- a/toolbelt/pipe.h +++ b/toolbelt/pipe.h @@ -3,7 +3,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "coroutine.h" +#include "co/coroutine.h" #include "toolbelt/fd.h" #include @@ -66,9 +66,9 @@ class Pipe { absl::Status SetPipeSize(size_t size); virtual absl::StatusOr Read(char *buffer, size_t length, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); virtual absl::StatusOr Write(const char *buffer, size_t length, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); protected: // RAII classes for keeping coroutines from interleaving reads or writes on a @@ -78,7 +78,7 @@ class Pipe { // // Same applies to non-coroutine use except we block with a sleep. struct ScopedRead { - ScopedRead(Pipe &p, co::Coroutine *c) : pipe(p) { + ScopedRead(Pipe &p, const co::Coroutine *c) : pipe(p) { while (pipe.read_in_progress_) { if (c) { c->Yield(); @@ -97,7 +97,7 @@ class Pipe { }; struct ScopedWrite { - ScopedWrite(Pipe &p, co::Coroutine *c) : pipe(p) { + ScopedWrite(Pipe &p, const co::Coroutine *c) : pipe(p) { while (pipe.write_in_progress_) { if (c) { c->Yield(); @@ -151,15 +151,15 @@ template class SharedPtrPipe : public Pipe { // You can't use raw buffers with shared ptr pipes. absl::StatusOr Read(char *buffer, size_t length, - co::Coroutine *c = nullptr) override { + const co::Coroutine *c = nullptr) override { return absl::InternalError("Not supported on SharedPtrPipe"); } absl::StatusOr Write(const char *buffer, size_t length, - co::Coroutine *c = nullptr) override { + const co::Coroutine *c = nullptr) override { return absl::InternalError("Not supported on SharedPtrPipe"); } - absl::StatusOr> Read(co::Coroutine *c = nullptr) { + absl::StatusOr> Read(const co::Coroutine *c = nullptr) { char buffer[sizeof(std::shared_ptr)]; size_t length = sizeof(buffer); size_t total = 0; @@ -205,7 +205,7 @@ template class SharedPtrPipe : public Pipe { } // This makes the pipe an owner of the pointer. - absl::Status Write(std::shared_ptr p, co::Coroutine *c = nullptr) { + absl::Status Write(std::shared_ptr p, const co::Coroutine *c = nullptr) { // On entry, ref count for p = N char buffer[sizeof(std::shared_ptr)]; @@ -272,4 +272,4 @@ template class SharedPtrPipe : public Pipe { } }; -} // namespace toolbelt \ No newline at end of file +} // namespace toolbelt diff --git a/toolbelt/pipe_test.cc b/toolbelt/pipe_test.cc index 097160c..d0d5811 100644 --- a/toolbelt/pipe_test.cc +++ b/toolbelt/pipe_test.cc @@ -3,7 +3,7 @@ // See LICENSE file for licensing information. #include "absl/status/status_matchers.h" -#include "coroutine.h" +#include "co/coroutine.h" #include "pipe.h" #include @@ -387,4 +387,4 @@ TEST(PipeTest, CoroutineOverFullPipeReadAndWriteMultiwriterNonblocking) { } }); scheduler.Run(); -} \ No newline at end of file +} diff --git a/toolbelt/sockets.cc b/toolbelt/sockets.cc index 27905c8..2010e4a 100644 --- a/toolbelt/sockets.cc +++ b/toolbelt/sockets.cc @@ -33,19 +33,19 @@ InetAddress InetAddress::AnyAddress(int port) { return InetAddress(port); } InetAddress::InetAddress(const in_addr &ip, int port) { valid_ = true; addr_ = { -#if defined(_APPLE__) - .sin_len = sizeof(int), +#if defined(__APPLE__) + .sin_len = sizeof(struct sockaddr_in), #endif .sin_family = AF_INET, .sin_port = htons(port), - .sin_addr = {.s_addr = htonl(ip.s_addr)}}; + .sin_addr = {.s_addr = ip.s_addr}}; } InetAddress::InetAddress(int port) { valid_ = true; addr_ = { -#if defined(_APPLE__) - .sin_len = sizeof(int), +#if defined(__APPLE__) + .sin_len = sizeof(struct sockaddr_in), #endif .sin_family = AF_INET, .sin_port = htons(port), @@ -68,8 +68,8 @@ InetAddress::InetAddress(const std::string &hostname, int port) { } valid_ = true; addr_ = { -#if defined(_APPLE__) - .sin_len = sizeof(int), +#if defined(__APPLE__) + .sin_len = sizeof(struct sockaddr_in), #endif .sin_family = AF_INET, .sin_port = htons(port), @@ -88,7 +88,7 @@ VirtualAddress::VirtualAddress(uint32_t cid, uint32_t port) { valid_ = true; memset(&addr_, 0, sizeof(addr_)); addr_ = { -#if defined(_APPLE__) +#if defined(__APPLE__) .svm_len = sizeof(struct sockaddr_vm), #endif .svm_family = AF_VSOCK, @@ -100,7 +100,7 @@ VirtualAddress::VirtualAddress(uint32_t port) { valid_ = true; memset(&addr_, 0, sizeof(addr_)); addr_ = { -#if defined(_APPLE__) +#if defined(__APPLE__) .svm_len = sizeof(struct sockaddr_vm), #endif .svm_family = AF_VSOCK, @@ -131,14 +131,17 @@ std::string VirtualAddress::ToString() const { return absl::StrFormat("%d:%d", addr_.svm_cid, addr_.svm_port); } -static ssize_t ReceiveFully(co::Coroutine *c, int fd, size_t length, +static ssize_t ReceiveFully(const co::Coroutine *c, int fd, size_t length, char *buffer, size_t buflen) { int offset = 0; size_t remaining = length; while (remaining > 0) { size_t readlen = std::min(remaining, buflen); if (c != nullptr) { - c->Wait(fd, POLLIN); + int f = c->Wait(fd, POLLIN); + if (f != fd) { + return -1; + } } ssize_t n = ::recv(fd, buffer + offset, readlen, 0); if (n == -1) { @@ -164,7 +167,7 @@ static ssize_t ReceiveFully(co::Coroutine *c, int fd, size_t length, return length; } -static ssize_t SendFully(co::Coroutine *c, int fd, const char *buffer, +static ssize_t SendFully(const co::Coroutine *c, int fd, const char *buffer, size_t length, bool blocking) { size_t remaining = length; size_t offset = 0; @@ -176,7 +179,10 @@ static ssize_t SendFully(co::Coroutine *c, int fd, const char *buffer, // Yielding before sending to a nonblocking socket will // cause a context switch between coroutines and we want // the write to the network to be as fast as possible. - c->Wait(fd, POLLOUT); + int f = c->Wait(fd, POLLOUT); + if (f != fd) { + return -1; + } } ssize_t n = ::send(fd, buffer + offset, remaining, 0); if (n == -1) { @@ -190,7 +196,10 @@ static ssize_t SendFully(co::Coroutine *c, int fd, const char *buffer, // If we are nonblocking yield the coroutine now. When we // are resumed we can write to the socket again. if (!blocking) { - c->Wait(fd, POLLOUT); + int f = c->Wait(fd, POLLOUT); + if (f != fd) { + return -1; + } } continue; } @@ -207,7 +216,7 @@ static ssize_t SendFully(co::Coroutine *c, int fd, const char *buffer, } absl::StatusOr Socket::Receive(char *buffer, size_t buflen, - co::Coroutine *c) { + const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -221,7 +230,7 @@ absl::StatusOr Socket::Receive(char *buffer, size_t buflen, } absl::StatusOr Socket::Send(const char *buffer, size_t length, - co::Coroutine *c) { + const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -235,7 +244,7 @@ absl::StatusOr Socket::Send(const char *buffer, size_t length, } absl::StatusOr Socket::ReceiveMessage(char *buffer, size_t buflen, - co::Coroutine *c) { + const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -268,7 +277,7 @@ absl::StatusOr Socket::ReceiveMessage(char *buffer, size_t buflen, } absl::StatusOr> -Socket::ReceiveVariableLengthMessage(co::Coroutine *c) { +Socket::ReceiveVariableLengthMessage(const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -303,7 +312,7 @@ Socket::ReceiveVariableLengthMessage(co::Coroutine *c) { } absl::StatusOr Socket::SendMessage(char *buffer, size_t length, - co::Coroutine *c) { + const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -371,12 +380,15 @@ absl::Status UnixSocket::Bind(const std::string &pathname, bool listen) { return absl::OkStatus(); } -absl::StatusOr UnixSocket::Accept(co::Coroutine *c) const { +absl::StatusOr UnixSocket::Accept(const co::Coroutine *c) const { if (!fd_.Valid()) { return absl::InternalError("UnixSocket is not valid"); } if (c != nullptr) { - c->Wait(fd_.Fd(), POLLIN); + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } struct sockaddr_un sender; socklen_t sock_len = sizeof(sender); @@ -419,7 +431,7 @@ absl::Status UnixSocket::Connect(const std::string &pathname) { } absl::Status UnixSocket::SendFds(const std::vector &fds, - co::Coroutine *c) { + const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -459,7 +471,10 @@ absl::Status UnixSocket::SendFds(const std::vector &fds, } if (c != nullptr) { - c->Wait(fd_.Fd(), POLLOUT); + int fd = c->Wait(fd_.Fd(), POLLOUT); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } int e = ::sendmsg(fd_.Fd(), &msg, 0); if (e == -1) { @@ -473,7 +488,7 @@ absl::Status UnixSocket::SendFds(const std::vector &fds, } absl::Status UnixSocket::ReceiveFds(std::vector &fds, - co::Coroutine *c) { + const co::Coroutine *c) { if (!Connected()) { return absl::InternalError("Socket is not connected"); } @@ -498,7 +513,10 @@ absl::Status UnixSocket::ReceiveFds(std::vector &fds, .msg_controllen = sizeof(u.buf)}; if (c != nullptr) { - c->Wait(fd_.Fd(), POLLIN); + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0); if (n == -1) { @@ -634,12 +652,15 @@ absl::Status TCPSocket::Bind(const InetAddress &addr, bool listen) { return absl::OkStatus(); } -absl::StatusOr TCPSocket::Accept(co::Coroutine *c) const { +absl::StatusOr TCPSocket::Accept(const co::Coroutine *c) const { if (!fd_.Valid()) { return absl::InternalError("Socket is not valid"); } if (c != nullptr) { - c->Wait(fd_.Fd(), POLLIN); + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } struct sockaddr_in sender; socklen_t sock_len = sizeof(sender); @@ -776,9 +797,12 @@ absl::Status UDPSocket::SetMulticastLoop() { } absl::Status UDPSocket::SendTo(const InetAddress &addr, const void *buffer, - size_t length, co::Coroutine *c) { + size_t length, const co::Coroutine *c) { if (c != nullptr) { - c->Wait(fd_.Fd(), POLLOUT); + int fd = c->Wait(fd_.Fd(), POLLOUT); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } ssize_t n = ::sendto(fd_.Fd(), buffer, length, 0, reinterpret_cast(&addr.GetAddress()), @@ -792,9 +816,12 @@ absl::Status UDPSocket::SendTo(const InetAddress &addr, const void *buffer, } absl::StatusOr UDPSocket::Receive(void *buffer, size_t buflen, - co::Coroutine *c) { + const co::Coroutine *c) { if (c != nullptr) { - c->Wait(fd_.Fd(), POLLIN); + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } ssize_t n = recv(fd_.Fd(), buffer, buflen, 0); if (n == -1) { @@ -805,9 +832,12 @@ absl::StatusOr UDPSocket::Receive(void *buffer, size_t buflen, } absl::StatusOr UDPSocket::ReceiveFrom(InetAddress &sender, void *buffer, size_t buflen, - co::Coroutine *c) { + const co::Coroutine *c) { if (c != nullptr) { - c->Wait(fd_.Fd(), POLLIN); + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } struct sockaddr_in sender_addr; socklen_t sender_addr_length = sizeof(sender_addr); @@ -819,6 +849,9 @@ absl::StatusOr UDPSocket::ReceiveFrom(InetAddress &sender, return absl::InternalError( absl::StrFormat("Unable to receive UDP datagram: %s", strerror(errno))); } +#if defined(__APPLE__) + sender_addr.sin_len = sender_addr_length; +#endif sender = {sender_addr}; return n; } @@ -859,12 +892,15 @@ absl::Status VirtualStreamSocket::Bind(const VirtualAddress &addr, } absl::StatusOr -VirtualStreamSocket::Accept(co::Coroutine *c) const { +VirtualStreamSocket::Accept(const co::Coroutine *c) const { if (!fd_.Valid()) { return absl::InternalError("Socket is not valid"); } if (c != nullptr) { - c->Wait(fd_.Fd(), POLLIN); + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } struct sockaddr_vm sender; socklen_t sock_len = sizeof(sender); diff --git a/toolbelt/sockets.h b/toolbelt/sockets.h index 795c1e3..1127e42 100644 --- a/toolbelt/sockets.h +++ b/toolbelt/sockets.h @@ -6,7 +6,7 @@ #define __TOOLBELT_SOCKETS_H #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "coroutine.h" +#include "co/coroutine.h" #include "fd.h" #include #include @@ -53,6 +53,10 @@ struct sockaddr_vm { #define AF_VSOCK 40 #endif +#if !defined(VMADDR_CID_LOCAL) +#define VMADDR_CID_LOCAL 3 +#endif + #include #include #include @@ -110,7 +114,7 @@ class InetAddress { static InetAddress AnyAddress(int port); private: - struct sockaddr_in addr_; // In network byte order. + struct sockaddr_in addr_ = {}; // In network byte order. bool valid_ = false; }; @@ -348,17 +352,17 @@ class Socket { // Send and receive raw buffers. absl::StatusOr Receive(char *buffer, size_t buflen, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); absl::StatusOr Send(const char *buffer, size_t length, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); // Send and receive length-delimited message. The length is a 4-byte // network byte order (big endian) int as the first 4 bytes and // contains the length of the message. absl::StatusOr ReceiveMessage(char *buffer, size_t buflen, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); absl::StatusOr> - ReceiveVariableLengthMessage(co::Coroutine *c = nullptr); + ReceiveVariableLengthMessage(const co::Coroutine *c = nullptr); // For SendMessage, the buffer pointer must be 4 bytes beyond // the actual buffer start, which must be length+4 bytes @@ -366,7 +370,7 @@ class Socket { // at buffer-4. This is to allow us to do a single send // to the socket rather than splitting it into 2. absl::StatusOr SendMessage(char *buffer, size_t length, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); absl::Status SetNonBlocking() { if (absl::Status s = fd_.SetNonBlocking(); !s.ok()) { @@ -399,12 +403,12 @@ class UnixSocket : public Socket { absl::Status Bind(const std::string &pathname, bool listen); absl::Status Connect(const std::string &pathname); - absl::StatusOr Accept(co::Coroutine *c = nullptr) const; + absl::StatusOr Accept(const co::Coroutine *c = nullptr) const; absl::Status SendFds(const std::vector &fds, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); absl::Status ReceiveFds(std::vector &fds, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); std::string BoundAddress() const { return bound_address_; } absl::StatusOr GetPeerName() const; @@ -449,12 +453,12 @@ class UDPSocket : public NetworkSocket { // NOTE: Read and Write may or may not work on UDP sockets. Use SendTo and // Receive for datagrams. absl::Status SendTo(const InetAddress &addr, const void *buffer, - size_t length, co::Coroutine *c = nullptr); + size_t length, const co::Coroutine *c = nullptr); absl::StatusOr Receive(void *buffer, size_t buflen, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); absl::StatusOr ReceiveFrom(InetAddress &sender, void *buffer, size_t buflen, - co::Coroutine *c = nullptr); + const co::Coroutine *c = nullptr); absl::Status SetBroadcast(); absl::Status SetMulticastLoop(); }; @@ -468,7 +472,7 @@ class TCPSocket : public NetworkSocket { absl::Status Bind(const InetAddress &addr, bool listen); - absl::StatusOr Accept(co::Coroutine *c = nullptr) const; + absl::StatusOr Accept(const co::Coroutine *c = nullptr) const; absl::StatusOr LocalAddress(int port) const; @@ -485,7 +489,7 @@ class VirtualStreamSocket : public Socket { absl::Status Bind(const VirtualAddress &addr, bool listen); - absl::StatusOr Accept(co::Coroutine *c = nullptr) const; + absl::StatusOr Accept(const co::Coroutine *c = nullptr) const; absl::StatusOr LocalAddress(uint32_t port) const; const VirtualAddress &BoundAddress() const { return bound_address_; } @@ -546,7 +550,7 @@ class StreamSocket { return absl::Status(absl::StatusCode::kInternal, "Invalid socket address"); } - absl::StatusOr Accept(co::Coroutine *c = nullptr) const { + absl::StatusOr Accept(const co::Coroutine *c = nullptr) const { return std::visit( EyeOfNewt{ [&](const TCPSocket &s) mutable -> absl::StatusOr { @@ -612,7 +616,7 @@ class StreamSocket { // Send and receive raw buffers. absl::StatusOr Receive(char *buffer, size_t buflen, - co::Coroutine *c = nullptr) { + const co::Coroutine *c = nullptr) { return std::visit( EyeOfNewt{[&](TCPSocket &s) { return s.Receive(buffer, buflen, c); }, [&](VirtualStreamSocket &s) { @@ -623,7 +627,7 @@ class StreamSocket { } absl::StatusOr Send(const char *buffer, size_t length, - co::Coroutine *c = nullptr) { + const co::Coroutine *c = nullptr) { return std::visit( EyeOfNewt{ [&](TCPSocket &s) { return s.Send(buffer, length, c); }, @@ -636,7 +640,7 @@ class StreamSocket { // network byte order (big endian) int as the first 4 bytes and // contains the length of the message. absl::StatusOr ReceiveMessage(char *buffer, size_t buflen, - co::Coroutine *c = nullptr) { + const co::Coroutine *c = nullptr) { return std::visit( EyeOfNewt{ [&](TCPSocket &s) { return s.ReceiveMessage(buffer, buflen, c); }, @@ -648,7 +652,7 @@ class StreamSocket { } absl::StatusOr> - ReceiveVariableLengthMessage(co::Coroutine *c = nullptr) { + ReceiveVariableLengthMessage(const co::Coroutine *c = nullptr) { return std::visit( EyeOfNewt{ [&](TCPSocket &s) { return s.ReceiveVariableLengthMessage(c); }, @@ -665,7 +669,7 @@ class StreamSocket { // at buffer-4. This is to allow us to do a single send // to the socket rather than splitting it into 2. absl::StatusOr SendMessage(char *buffer, size_t length, - co::Coroutine *c = nullptr) { + const co::Coroutine *c = nullptr) { return std::visit( EyeOfNewt{ [&](TCPSocket &s) { return s.SendMessage(buffer, length, c); }, diff --git a/toolbelt/sockets_test.cc b/toolbelt/sockets_test.cc index 476c89a..6618700 100644 --- a/toolbelt/sockets_test.cc +++ b/toolbelt/sockets_test.cc @@ -4,13 +4,31 @@ #include #include #include +#include +#include "absl/status/status_matchers.h" +#include "toolbelt/hexdump.h" + +#define VAR(a) a##__COUNTER__ +#define EVAL_AND_ASSERT_OK(expr) EVAL_AND_ASSERT_OK2(VAR(r_), expr) + +#define EVAL_AND_ASSERT_OK2(result, expr) \ + ({ \ + auto result = (expr); \ + if (!result.ok()) { \ + std::cerr << result.status() << std::endl; \ + } \ + ASSERT_OK(result); \ + std::move(*result); \ + }) + +#define ASSERT_OK(e) ASSERT_THAT(e, ::absl_testing::IsOk()) namespace { constexpr std::string_view TEST_DATA = "The quick brown fox jumped over the lazy dog."; const static absl::Duration LOOPBACK_TIMEOUT = absl::Milliseconds(10); // Test class to hold on to a randomly assigned unused port until destruction -// Any tests binding to this unused port will probably need to call NetworkSocket::SetReusePort +// Any tests Binding to this unused port will probably need to call NetworkSocket::SetReusePort class UnusedPort { public: UnusedPort() { @@ -31,42 +49,491 @@ class UnusedPort { }; } +TEST(SocketsTest, InetAddresses) { + toolbelt::InetAddress addr1; + ASSERT_FALSE(addr1.Valid()); + + toolbelt::InetAddress addr2 = toolbelt::InetAddress::BroadcastAddress(1234); + ASSERT_TRUE(addr2.Valid()); + ASSERT_EQ(1234, addr2.Port()); + + toolbelt::InetAddress addr3 = toolbelt::InetAddress::AnyAddress(4321); + ASSERT_EQ(4321, addr3.Port()); + + toolbelt::InetAddress local_ip = toolbelt::InetAddress("127.0.0.1", 1111); + ASSERT_EQ(1111, local_ip.Port()); + ASSERT_EQ("127.0.0.1:1111", local_ip.ToString()); + + toolbelt::InetAddress local_host = toolbelt::InetAddress("localhost", 2222); + ASSERT_EQ(2222, local_host.Port()); + ASSERT_EQ("127.0.0.1:2222", local_host.ToString()); + + toolbelt::InetAddress bad = toolbelt::InetAddress("foobardoesntexist", 2222); + ASSERT_FALSE(bad.Valid()); + + in_addr ipaddr; + ASSERT_EQ(1, inet_pton(AF_INET, "127.0.0.1", &ipaddr.s_addr)); + toolbelt::InetAddress local_in = toolbelt::InetAddress(ipaddr, 3333); + ASSERT_EQ(3333, local_in.Port()); + ASSERT_EQ("127.0.0.1:3333", local_in.ToString()); +} + +TEST(SocketsTest, UnixSocket) { + char tmp[] = "/tmp/socketsXXXXXX"; + int fd = mkstemp(tmp); + ASSERT_NE(-1, fd); + std::string socket_name = tmp; + close(fd); + + unlink(socket_name.c_str()); + co::CoroutineScheduler scheduler; + + toolbelt::UnixSocket listener; + absl::Status status = listener.Bind(socket_name, true); + std::cerr << status << std::endl; + ASSERT_TRUE(status.ok()); + + co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + char buffer[256]; + // ReceiveMessage uses the 4 bytes below the buffer for the length. + absl::StatusOr nbytes = socket.ReceiveMessage(buffer + 4, sizeof(buffer) - 4, c); + ASSERT_TRUE(nbytes.ok()); + auto n = nbytes.value(); + ASSERT_EQ(12, n); // "hello world\0" + ASSERT_EQ("hello world", std::string(buffer + 4, n - 1)); + std::vector fds; + + absl::Status s2 = socket.ReceiveFds(fds, c); + ASSERT_TRUE(s2.ok()); + ASSERT_EQ(3, fds.size()); + }); + + co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) { + toolbelt::UnixSocket socket; + absl::Status s = socket.Connect(socket_name); + ASSERT_TRUE(s.ok()); + char buffer[256]; + // SendMessage uses the 4 bytes below the buffer for the length of the message. + ssize_t n = snprintf(buffer + 4, sizeof(buffer) - 4, "hello world"); + n += 1; // Include NUL at end. + absl::StatusOr nsent = socket.SendMessage(buffer + 4, n, c); + ASSERT_TRUE(nsent.ok()); + ASSERT_EQ(n + 4, nsent.value()); + + std::vector fds; + for (int i = 0; i < 3; i++) { + // We dup the file descriptors to avoid closing stdout. + fds.push_back(toolbelt::FileDescriptor(dup(i))); + } + absl::Status s2 = socket.SendFds(fds, c); + ASSERT_TRUE(s2.ok()); + }); + + scheduler.Run(); + remove(socket_name.c_str()); +} + +TEST(SocketsTest, UnixSocketErrors) { + toolbelt::UnixSocket socket; + // Socket is inValid, all will fail. + ASSERT_FALSE(socket.Accept().ok()); + ASSERT_FALSE(socket.Connect("foobar").ok()); + std::vector fds; + ASSERT_FALSE(socket.SendFds(fds).ok()); + ASSERT_FALSE(socket.ReceiveFds(fds).ok()); +} + +TEST(SocketsTest, TCPSocket) { + toolbelt::InetAddress addr("127.0.0.1", 6502); + + co::CoroutineScheduler scheduler; + + toolbelt::TCPSocket listener; + ASSERT_TRUE(listener.SetReuseAddr().ok()); + absl::Status status = listener.Bind(addr, true); + ASSERT_TRUE(status.ok()); + + co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + absl::StatusOr> b = socket.ReceiveVariableLengthMessage(c); + ASSERT_TRUE(b.ok()); + auto buf = b.value(); + ASSERT_EQ(12, buf.size()); // "hello world\0" + ASSERT_EQ("hello world", std::string(buf.data(), 11)); + }); + + co::Coroutine outgoing(scheduler, [&addr](co::Coroutine* c) { + toolbelt::TCPSocket socket; + absl::Status s = socket.Connect(addr); + ASSERT_TRUE(s.ok()); + char buffer[256]; + // SendMessage uses the 4 bytes below the buffer for the length of the message. + ssize_t n = snprintf(buffer + 4, sizeof(buffer) - 4, "hello world"); + n += 1; // Include NUL at end. + absl::StatusOr nsent = socket.SendMessage(buffer + 4, n, c); + ASSERT_TRUE(nsent.ok()); + ASSERT_EQ(n + 4, nsent.value()); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, BigTCPSocketNonblocking) { + toolbelt::InetAddress addr("127.0.0.1", 6502); + + co::CoroutineScheduler scheduler; + + toolbelt::TCPSocket listener; + ASSERT_TRUE(listener.SetReuseAddr().ok()); + absl::Status status = listener.Bind(addr, true); + ASSERT_TRUE(status.ok()); + + constexpr size_t kBufferSize = 10 * 1024 * 1024; + co::Coroutine incoming(scheduler, [&listener, kBufferSize](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + ASSERT_OK(socket.SetNonBlocking()); + + absl::StatusOr> b = socket.ReceiveVariableLengthMessage(c); + ASSERT_TRUE(b.ok()); + auto buf = b.value(); + ASSERT_EQ(kBufferSize, buf.size()); + for (size_t i = 0; i < kBufferSize; i++) { + if (buf[i] != 'a' + ((i + 4) % 26)) { + std::cerr << "Mismatch at " << i << ": " << buf[i] << " != " << 'a' + (i % 26) + << "\n"; + } + ASSERT_EQ('a' + ((i + 4) % 26), buf[i]); + } + }); + + co::Coroutine outgoing(scheduler, [&addr](co::Coroutine* c) { + toolbelt::TCPSocket socket; + absl::Status s = socket.Connect(addr); + ASSERT_TRUE(s.ok()); + ASSERT_OK(socket.SetNonBlocking()); + std::vector buffer(kBufferSize + 4); + for (size_t i = 4; i < buffer.size(); i++) { + buffer[i] = 'a' + (i % 26); + } + absl::StatusOr nsent = + socket.SendMessage(buffer.data() + 4, buffer.size() - 4, c); + ASSERT_TRUE(nsent.ok()); + ASSERT_EQ(buffer.size(), nsent.value()); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, BigTCPSocketBlocking) { + toolbelt::InetAddress addr("127.0.0.1", 6502); + + co::CoroutineScheduler sendScheduler, ReceiveScheduler; + + toolbelt::TCPSocket listener; + ASSERT_TRUE(listener.SetReuseAddr().ok()); + absl::Status status = listener.Bind(addr, true); + ASSERT_TRUE(status.ok()); + + constexpr size_t kBufferSize = 10 * 1024 * 1024; + co::Coroutine incoming( + sendScheduler, [&listener, kBufferSize](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + absl::StatusOr> b = socket.ReceiveVariableLengthMessage(c); + ASSERT_TRUE(b.ok()); + auto buf = b.value(); + ASSERT_EQ(kBufferSize, buf.size()); + for (size_t i = 0; i < kBufferSize; i++) { + if (buf[i] != 'a' + ((i + 4) % 26)) { + std::cerr << "Mismatch at " << i << ": " << buf[i] + << " != " << 'a' + (i % 26) << "\n"; + } + ASSERT_EQ('a' + ((i + 4) % 26), buf[i]); + } + }); + + co::Coroutine outgoing(ReceiveScheduler, [&addr](co::Coroutine* c) { + toolbelt::TCPSocket socket; + absl::Status s = socket.Connect(addr); + ASSERT_TRUE(s.ok()); + std::vector buffer(kBufferSize + 4); + for (size_t i = 4; i < buffer.size(); i++) { + buffer[i] = 'a' + (i % 26); + } + absl::StatusOr nsent = + socket.SendMessage(buffer.data() + 4, buffer.size() - 4, c); + ASSERT_TRUE(nsent.ok()); + ASSERT_EQ(buffer.size(), nsent.value()); + }); + + std::thread sender([&sendScheduler]() { sendScheduler.Run(); }); + std::thread Receiver([&ReceiveScheduler]() { ReceiveScheduler.Run(); }); + sender.join(); + Receiver.join(); +} + +TEST(SocketsTest, TCPSocketInterrupt) { + // TODO(dave.allison): is there a way to pick an unused port? + toolbelt::InetAddress addr("127.0.0.1", 6502); + + co::CoroutineScheduler scheduler; + + toolbelt::TCPSocket listener; + ASSERT_TRUE(listener.SetReuseAddr().ok()); + absl::Status status = listener.Bind(addr, true); + ASSERT_TRUE(status.ok()); + + co::Coroutine incoming( + scheduler, + [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_FALSE(s.ok()); + }, + co::CoroutineOptions{.name = "foo", .interrupt_fd = scheduler.GetInterruptFd()}); + + co::Coroutine interrupt(scheduler, [](co::Coroutine* c) { + c->Yield(); + c->Scheduler().TriggerInterrupt(); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, TCPSocket2) { + // TODO(dave.allison): is there a way to pick an unused port? + toolbelt::InetAddress addr("127.0.0.1", 6502); + + co::CoroutineScheduler scheduler; + + toolbelt::TCPSocket listener; + ASSERT_TRUE(listener.SetReuseAddr().ok()); + absl::Status status = listener.Bind(addr, true); + ASSERT_TRUE(status.ok()); + + co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + char buffer[256]; + absl::StatusOr nbytes = socket.Receive(buffer, 12, c); + ASSERT_TRUE(nbytes.ok()); + auto n = nbytes.value(); + ASSERT_EQ(12, n); // "hello world\0" + ASSERT_EQ("hello world", std::string(buffer, n - 1)); + std::vector fds; + }); + + co::Coroutine outgoing(scheduler, [&addr](co::Coroutine* c) { + toolbelt::TCPSocket socket; + absl::Status s = socket.Connect(addr); + ASSERT_TRUE(s.ok()); + char buffer[256]; + ssize_t n = snprintf(buffer, sizeof(buffer), "hello world"); + n += 1; // Include NUL at end. + absl::StatusOr nsent = socket.Send(buffer, n, c); + ASSERT_TRUE(nsent.ok()); + ASSERT_EQ(n, nsent.value()); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, TCPSocket3) { + toolbelt::InetAddress addr("127.0.0.1", 0); + + co::CoroutineScheduler scheduler; + + toolbelt::TCPSocket listener; + ASSERT_TRUE(listener.SetReuseAddr().ok()); + ASSERT_TRUE(listener.SetReusePort().ok()); + absl::Status status = listener.Bind(addr, true); + ASSERT_TRUE(status.ok()); + toolbelt::InetAddress baddr = listener.BoundAddress(); + + co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + char buffer[256]; + absl::StatusOr nbytes = socket.Receive(buffer, 12, c); + ASSERT_TRUE(nbytes.ok()); + auto n = nbytes.value(); + ASSERT_EQ(12, n); // "hello world\0" + ASSERT_EQ("hello world", std::string(buffer, n - 1)); + std::vector fds; + }); + + co::Coroutine outgoing(scheduler, [&baddr](co::Coroutine* c) { + toolbelt::TCPSocket socket; + absl::Status s = socket.Connect(baddr); + ASSERT_TRUE(s.ok()); + char buffer[256]; + ssize_t n = snprintf(buffer, sizeof(buffer), "hello world"); + n += 1; // Include NUL at end. + absl::StatusOr nsent = socket.Send(buffer, n, c); + ASSERT_TRUE(nsent.ok()); + ASSERT_EQ(n, nsent.value()); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, TCPSocketErrors) { + toolbelt::TCPSocket socket; + char buffer[256]; + + // Socket is not Connected. These will fail. + toolbelt::InetAddress goodAddr("localhost", 2222); + ASSERT_FALSE(socket.Connect(goodAddr).ok()); // Valid fd, but nothing is on this port. + ASSERT_FALSE(socket.Send(buffer, 1).ok()); + ASSERT_FALSE(socket.Receive(buffer, 1).ok()); + ASSERT_FALSE(socket.SendMessage(buffer, 1).ok()); + ASSERT_FALSE(socket.ReceiveMessage(buffer, 1).ok()); + + toolbelt::InetAddress badAddr = + toolbelt::InetAddress("foobardoesntexist", 2222); + ASSERT_FALSE(badAddr.Valid()); + ASSERT_FALSE(socket.Connect(badAddr).ok()); + + socket.Close(); + ASSERT_FALSE(socket.Connect(goodAddr).ok()); // InValid fd. +} + +TEST(SocketsTest, UDPSocket) { + // TODO(dave.allison): is there a way to pick an unused port? + toolbelt::InetAddress sender("127.0.0.1", 6502); + toolbelt::InetAddress Receiver("127.0.0.1", 6503); + + co::CoroutineScheduler scheduler; + + co::Coroutine incoming(scheduler, [&Receiver](co::Coroutine* c) { + toolbelt::UDPSocket socket; + absl::Status s1 = socket.Bind(Receiver); + ASSERT_TRUE(s1.ok()); + + char buffer[256]; + absl::StatusOr nbytes = socket.Receive(buffer, sizeof(buffer), c); + ASSERT_TRUE(nbytes.ok()); + auto n = nbytes.value(); + ASSERT_EQ(12, n); // "hello world\0" + ASSERT_EQ("hello world", std::string(buffer, n - 1)); + }); + + co::Coroutine outgoing(scheduler, [&sender, &Receiver](co::Coroutine* c) { + toolbelt::UDPSocket socket; + absl::Status s1 = socket.Bind(sender); + ASSERT_TRUE(s1.ok()); + + char buffer[256]; + ssize_t n = snprintf(buffer, sizeof(buffer), "hello world"); + n += 1; // Include NUL at end. + + absl::Status s2 = socket.SendTo(Receiver, buffer, n, c); + ASSERT_TRUE(s2.ok()); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, UDPSocket2) { + // TODO(dave.allison): is there a way to pick an unused port? + toolbelt::InetAddress sender("127.0.0.1", 6502); + toolbelt::InetAddress receiver("127.0.0.1", 6503); + + co::CoroutineScheduler scheduler; + + co::Coroutine incoming(scheduler, [&receiver, &sender](co::Coroutine* c) { + toolbelt::UDPSocket socket; + absl::Status s1 = socket.Bind(receiver); + ASSERT_TRUE(s1.ok()); + + char buffer[256]; + toolbelt::InetAddress from; + absl::StatusOr nbytes = socket.ReceiveFrom(from, buffer, sizeof(buffer), c); + ASSERT_TRUE(nbytes.ok()); + auto n = nbytes.value(); + ASSERT_EQ(12, n); // "hello world\0" + ASSERT_EQ("hello world", std::string(buffer, n - 1)); + ASSERT_EQ(sender, from); + }); + + co::Coroutine outgoing(scheduler, [&sender, &receiver](co::Coroutine* c) { + toolbelt::UDPSocket socket; + absl::Status s1 = socket.Bind(sender); + ASSERT_TRUE(s1.ok()); + + char buffer[256]; + ssize_t n = snprintf(buffer, sizeof(buffer), "hello world"); + n += 1; // Include NUL at end. + + absl::Status s2 = socket.SendTo(receiver, buffer, n, c); + ASSERT_TRUE(s2.ok()); + }); + + scheduler.Run(); +} + +TEST(SocketsTest, UDPSocketBroadcast) { + toolbelt::UDPSocket socket; + ASSERT_TRUE(socket.SetBroadcast().ok()); +} + +TEST(SocketsTest, InValidAsString) { + toolbelt::InetAddress addr; + ASSERT_FALSE(addr.Valid()); + EXPECT_EQ(addr.ToString(), "0.0.0.0:0"); +} + + TEST(SocketsTest, UDPSocket_SendAndReceiveUnicast) { UnusedPort port; auto sender = toolbelt::UDPSocket(); - auto receiver = toolbelt::UDPSocket(); + auto Receiver = toolbelt::UDPSocket(); - ASSERT_TRUE(receiver.SetReusePort().ok()); - ASSERT_TRUE(receiver.Bind(toolbelt::InetAddress("localhost", port)).ok()); + ASSERT_TRUE(Receiver.SetReusePort().ok()); + ASSERT_TRUE(Receiver.Bind(toolbelt::InetAddress("localhost", port)).ok()); toolbelt::InetAddress sendto_address("localhost", port); ASSERT_TRUE(sender.SendTo(sendto_address, TEST_DATA.data(), TEST_DATA.size()).ok()); - std::vector receive_buffer(TEST_DATA.size()); - ASSERT_EQ(*receiver.Receive(receive_buffer.data(), receive_buffer.size()), TEST_DATA.size()); - ASSERT_EQ(std::string_view(receive_buffer.data(), receive_buffer.size()), TEST_DATA); + std::vector Receive_buffer(TEST_DATA.size()); + ASSERT_EQ(*Receiver.Receive(Receive_buffer.data(), Receive_buffer.size()), TEST_DATA.size()); + ASSERT_EQ(std::string_view(Receive_buffer.data(), Receive_buffer.size()), TEST_DATA); - ASSERT_EQ(0, std::strcmp(receive_buffer.data(), TEST_DATA.data())); + ASSERT_EQ(0, std::strcmp(Receive_buffer.data(), TEST_DATA.data())); } TEST(SocketsTest, UDPSocket_SendAndReceiveBroadcast) { UnusedPort port; auto sender = toolbelt::UDPSocket(); - auto receiver = toolbelt::UDPSocket(); + auto Receiver = toolbelt::UDPSocket(); ASSERT_TRUE(sender.SetBroadcast().ok()); - ASSERT_TRUE(receiver.SetReusePort().ok()); - ASSERT_TRUE(receiver.Bind(toolbelt::InetAddress(toolbelt::InetAddress::AnyAddress(port))).ok()); + ASSERT_TRUE(Receiver.SetReusePort().ok()); + ASSERT_TRUE(Receiver.Bind(toolbelt::InetAddress(toolbelt::InetAddress::AnyAddress(port))).ok()); toolbelt::InetAddress sendto_address(toolbelt::InetAddress::BroadcastAddress(port)); ASSERT_TRUE(sender.SendTo(sendto_address, TEST_DATA.data(), TEST_DATA.size()).ok()); - std::vector receive_buffer(TEST_DATA.size()); - ASSERT_EQ(*receiver.Receive(receive_buffer.data(), receive_buffer.size()), TEST_DATA.size()); - ASSERT_EQ(std::string_view(receive_buffer.data(), receive_buffer.size()), TEST_DATA); + std::vector Receive_buffer(TEST_DATA.size()); + ASSERT_EQ(*Receiver.Receive(Receive_buffer.data(), Receive_buffer.size()), TEST_DATA.size()); + ASSERT_EQ(std::string_view(Receive_buffer.data(), Receive_buffer.size()), TEST_DATA); - ASSERT_EQ(0, std::strcmp(receive_buffer.data(), TEST_DATA.data())); + ASSERT_EQ(0, std::strcmp(Receive_buffer.data(), TEST_DATA.data())); } TEST(SocketsTest, UDPSocket_SendAndReceiveMulticast) { @@ -75,33 +542,33 @@ TEST(SocketsTest, UDPSocket_SendAndReceiveMulticast) { toolbelt::InetAddress multicast_address(multicast_ip, port); auto sender = toolbelt::UDPSocket(); - auto receiver = toolbelt::UDPSocket(); + auto Receiver = toolbelt::UDPSocket(); ASSERT_TRUE(sender.SetMulticastLoop().ok()); - std::vector receive_buffer(TEST_DATA.size()); - ASSERT_TRUE(receiver.SetReusePort().ok()); - ASSERT_TRUE(receiver.SetNonBlocking().ok()); - ASSERT_TRUE(receiver.Bind(toolbelt::InetAddress::AnyAddress(port)).ok()); + std::vector Receive_buffer(TEST_DATA.size()); + ASSERT_TRUE(Receiver.SetReusePort().ok()); + ASSERT_TRUE(Receiver.SetNonBlocking().ok()); + ASSERT_TRUE(Receiver.Bind(toolbelt::InetAddress::AnyAddress(port)).ok()); - ASSERT_TRUE(receiver.JoinMulticastGroup(multicast_address).ok()); + ASSERT_TRUE(Receiver.JoinMulticastGroup(multicast_address).ok()); ASSERT_TRUE(sender.SendTo(multicast_address, TEST_DATA.data(), TEST_DATA.size()).ok()); absl::Time timeout = absl::Now() + LOOPBACK_TIMEOUT; while (absl::Now() < timeout) { - auto status_or_len = receiver.Receive(receive_buffer.data(), receive_buffer.size()); + auto status_or_len = Receiver.Receive(Receive_buffer.data(), Receive_buffer.size()); if (status_or_len.ok()) { ASSERT_EQ(*status_or_len, TEST_DATA.size()); - ASSERT_EQ(std::string_view(receive_buffer.data(), receive_buffer.size()), TEST_DATA); + ASSERT_EQ(std::string_view(Receive_buffer.data(), Receive_buffer.size()), TEST_DATA); break; } } - ASSERT_TRUE(receiver.LeaveMulticastGroup(multicast_address).ok()); + ASSERT_TRUE(Receiver.LeaveMulticastGroup(multicast_address).ok()); ASSERT_TRUE(sender.SendTo(multicast_address, TEST_DATA.data(), TEST_DATA.size()).ok()); timeout = absl::Now() + LOOPBACK_TIMEOUT; while (absl::Now() < timeout) { - auto status_or_len = receiver.Receive(receive_buffer.data(), receive_buffer.size()); + auto status_or_len = Receiver.Receive(Receive_buffer.data(), Receive_buffer.size()); if (status_or_len.ok()) { FAIL() << "Received " << *status_or_len << " bytes but expected nothing"; } diff --git a/toolbelt/stacktrace.cc b/toolbelt/stacktrace.cc new file mode 100644 index 0000000..4557e21 --- /dev/null +++ b/toolbelt/stacktrace.cc @@ -0,0 +1,42 @@ + +#include "absl/debugging/stacktrace.h" +#include "absl/base/optimization.h" +#include "absl/debugging/symbolize.h" +#include +#include + +namespace toolbelt { + +void PrintCurrentStack(std::ostream &os) { + os << "--- Stack Trace Capture (Deepest Function) ---\n"; + + constexpr int kMaxFrames = 50; + + // 1. Capture the raw stack addresses + void *stack[kMaxFrames]; + int depth = absl::GetStackTrace(stack, kMaxFrames, 0); + + // 2. Resolve addresses to human-readable symbol names + os << "Captured " << depth << " stack frames:\n"; + + // Buffer to hold the symbolized name + char symbolized_name[1024]; + + for (int i = 0; i < depth; ++i) { + // Attempt to symbolize the address + if (absl::Symbolize(stack[i], symbolized_name, sizeof(symbolized_name))) { + // Success: Print the frame index, address, and resolved symbol name + os << "#" << std::setw(2) << std::left << i << " [0x" << std::hex + << std::setw(16) << stack[i] << std::dec << "] " << symbolized_name + << "\n"; + } else { + // Failure: Symbolization failed (e.g., address not in a symbol table) + os << "#" << std::setw(2) << std::left << i << " [0x" << std::hex + << std::setw(16) << stack[i] << std::dec << "] " + << "\n"; + } + } + os << "----------------------------------------------\n"; +} + +} // namespace toolbelt diff --git a/toolbelt/stacktrace.h b/toolbelt/stacktrace.h new file mode 100644 index 0000000..be6c332 --- /dev/null +++ b/toolbelt/stacktrace.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace toolbelt { +void PrintCurrentStack(std::ostream &os); +} // namespace toolbelt diff --git a/toolbelt/stacktrace_test.cc b/toolbelt/stacktrace_test.cc new file mode 100644 index 0000000..9ed8394 --- /dev/null +++ b/toolbelt/stacktrace_test.cc @@ -0,0 +1,27 @@ +// Copyright 2025 David Allison +// All Rights Reserved +// See LICENSE file for licensing information. + +#include "absl/strings/str_format.h" +#include "toolbelt/stacktrace.h" +#include + +TEST(StacktraceTest, PrintCurrentStack) { + toolbelt::PrintCurrentStack(std::cout); +} + +void foo() { + toolbelt::PrintCurrentStack(std::cout); +} + +void bar() { + foo(); +} + +void baz() { + bar(); +} + +TEST(StacktraceTest, PrintCurrentStackWithFunction) { + baz(); +} \ No newline at end of file diff --git a/toolbelt/table.cc b/toolbelt/table.cc index 7517b4f..3840688 100644 --- a/toolbelt/table.cc +++ b/toolbelt/table.cc @@ -73,7 +73,7 @@ void Table::Print(int width, std::ostream &os) { // Print titles. for (auto &col : cols_) { std::string title = col.title; - if (title.size() > col.width) { + if (title.size() > static_cast(col.width)) { title = title.substr(0, col.width - 1); } os << std::left << std::setw(col.width) << std::setfill(' ') << title; @@ -83,10 +83,10 @@ void Table::Print(int width, std::ostream &os) { os << std::setw(width) << std::setfill('-') << "" << std::endl; // Print each row. - for (size_t i = 0; i < num_rows_; i++) { + for (int i = 0; i < num_rows_; i++) { for (auto &col : cols_) { std::string data = col.cells[i].data; - if (data.size() > col.width) { + if (data.size() > static_cast(col.width)) { // Truncate if too wide. data = data.substr(0, col.width - 1); } @@ -107,7 +107,7 @@ void Table::Clear() { void Table::Render(int width) { std::vector max_widths(cols_.size()); - for (size_t i = 0; i < num_rows_; i++) { + for (int i = 0; i < num_rows_; i++) { int col_index = 0; for (auto &col : cols_) { if (col.cells[i].data.size() > max_widths[col_index]) { @@ -132,7 +132,7 @@ void Table::Render(int width) { } void Table::Sort() { - if (sort_column_ == -1 || sort_column_ >= cols_.size()) { + if (sort_column_ == -1ULL || sort_column_ >= cols_.size()) { return; } struct Index { @@ -140,8 +140,8 @@ void Table::Sort() { std::string data; }; std::vector index(num_rows_); - for (size_t i = 0; i < num_rows_; i++) { - index[i] = {.row = i, .data = cols_[sort_column_].cells[i].data}; + for (int i = 0; i < num_rows_; i++) { + index[i] = {.row = static_cast(i), .data = cols_[sort_column_].cells[i].data}; } std::sort(index.begin(), index.end(), [this](const Index &a, const Index &b) { return sorter_(a.data, b.data);