1//! @file socket.cpp
2//! @author ryftchen
3//! @brief The definitions (socket) in the utility module.
4//! @version 0.1.0
5//! @copyright Copyright (c) 2022-2026 ryftchen. All rights reserved.
6
7#include "socket.hpp"
8
9#include <sys/poll.h>
10#include <netdb.h>
11#include <algorithm>
12#include <array>
13#include <cstring>
14#include <vector>
15
16namespace utility::socket
17{
//! @brief Report the version of the socket utility.
//! @return version string in "major.minor.patch" form
const char* version() noexcept
{
    static constexpr const char* versionNumber = "0.1.0";
    return versionNumber;
}
25
//! @brief Render the IPv4 address held in a transport address as dotted-decimal text.
//! @param addr - transport information whose sin_addr is converted
//! @return dotted-decimal address, or an empty string if inet_ntop leaves the buffer untouched
static std::string ipAddrString(const ::sockaddr_in& addr)
{
    std::array<char, INET_ADDRSTRLEN> text{};
    ::inet_ntop(AF_INET, &addr.sin_addr, text.data(), text.size());
    return text.data();
}
35
//! @brief Describe the current errno value using the reentrant strerror_r.
//! @return description of errno (the XSI variant yields "Unknown error" if strerror_r fails)
static std::string safeStrErrno()
{
    const int code = errno;
    std::array<char, 64> scratch{};
#ifdef _GNU_SOURCE
    // GNU variant: returns a pointer to the message, which may or may not live in `scratch`.
    const char* const text = ::strerror_r(code, scratch.data(), scratch.size());
    return std::string{text};
#else
    // XSI variant: fills `scratch` and returns 0 on success.
    if (::strerror_r(code, scratch.data(), scratch.size()) != 0)
    {
        return "Unknown error";
    }
    return std::string{scratch.data()};
#endif
}
47
48// NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
49Socket::Socket(const Type sockType, const int sockId)
50{
51 ::pthread_spin_init(lock: &sockLock, pshared: ::PTHREAD_PROCESS_PRIVATE);
52 if (sockId != -1)
53 {
54 sock = sockId;
55 return;
56 }
57
58 sock = ::socket(AF_INET, type: static_cast<std::uint8_t>(sockType), protocol: 0);
59 if (sock == -1)
60 {
61 throw std::runtime_error{"Socket creation error, errno: " + safeStrErrno() + '.'};
62 }
63}
64
65Socket::~Socket()
66{
67 close();
68 ::pthread_spin_destroy(lock: &sockLock);
69}
70
71void Socket::close()
72{
73 requestStop();
74
75 const Guard lock(*this);
76 ::shutdown(fd: sock, how: ::SHUT_RDWR);
77 ::close(fd: sock);
78}
79
80void Socket::join()
81{
82 if (ownedTask.valid() && (ownedTask.wait_until(abs: std::chrono::system_clock::now()) != std::future_status::ready))
83 {
84 ownedTask.wait();
85 }
86
87 while (!stopRequested())
88 {
89 std::this_thread::yield();
90 }
91}
92
93void Socket::requestStop()
94{
95 exitReady.store(i: true);
96}
97
98bool Socket::stopRequested() const
99{
100 return exitReady.load();
101}
102
103std::string Socket::transportAddress() const
104{
105 return ipAddrString(addr: sockAddr);
106}
107
108std::uint16_t Socket::transportPort() const
109{
110 return ::ntohs(netshort: sockAddr.sin_port);
111}
112
113void Socket::spinLock() const
114{
115 ::pthread_spin_lock(lock: &sockLock);
116}
117
118void Socket::spinUnlock() const
119{
120 ::pthread_spin_unlock(lock: &sockLock);
121}
122
123template <typename Func, typename... Args>
124void Socket::spawnDetached(Func&& func, Args&&... args)
125{
126 std::thread(std::forward<Func>(func), std::forward<Args>(args)...).detach();
127}
128
129template <typename Func, typename... Args>
130void Socket::spawnJoinable(Func&& func, Args&&... args)
131{
132 ownedTask = std::async(std::launch::async, std::forward<Func>(func), std::forward<Args>(args)...);
133}
134
135::ssize_t TCPSocket::send(const char* const bytes, const std::size_t size)
136{
137 const Guard lock(*this);
138 return ::send(fd: sock, buf: bytes, n: size, flags: 0);
139}
140
141::ssize_t TCPSocket::send(const std::string_view message)
142{
143 return send(bytes: message.data(), size: message.length());
144}
145
146void TCPSocket::connect(const std::string& ip, const std::uint16_t port)
147{
148 ::addrinfo* addrInfo = nullptr;
149 ::addrinfo hints{};
150 hints.ai_family = AF_INET;
151 hints.ai_socktype = ::SOCK_STREAM;
152
153 if (const int status = ::getaddrinfo(name: ip.c_str(), service: nullptr, req: &hints, pai: &addrInfo); status != 0)
154 {
155 throw std::runtime_error{
156 "Invalid address, status: " + std::string{::gai_strerror(ecode: status)} + ", errno: " + safeStrErrno() + '.'};
157 }
158
159 for (const auto* entry = addrInfo; entry != nullptr; entry = entry->ai_next)
160 {
161 if (entry->ai_family == AF_INET)
162 {
163 std::memcpy(dest: static_cast<void*>(&sockAddr), src: static_cast<void*>(entry->ai_addr), n: sizeof(::sockaddr_in));
164 break;
165 }
166 }
167 ::freeaddrinfo(ai: addrInfo);
168
169 sockAddr.sin_family = AF_INET;
170 sockAddr.sin_port = ::htons(hostshort: port);
171 sockAddr.sin_addr.s_addr = static_cast<std::uint32_t>(sockAddr.sin_addr.s_addr);
172 if (::connect(fd: sock, addr: reinterpret_cast<::sockaddr*>(&sockAddr), len: sizeof(::sockaddr_in)) == -1)
173 {
174 throw std::runtime_error{"Failed to connect to the socket, errno: " + safeStrErrno() + '.'};
175 }
176
177 receive();
178}
179
180void TCPSocket::receive(const bool detached)
181{
182 detached ? spawnDetached(func&: doReceive, args: shared_from_this()) : spawnJoinable(func&: doReceive, args: shared_from_this());
183}
184
185void TCPSocket::subscribeMessage(MessageCallback callback)
186{
187 msgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
188}
189
190void TCPSocket::subscribeRawMessage(RawMessageCallback callback)
191{
192 rawMsgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
193}
194
195void TCPSocket::doReceive(const std::shared_ptr<TCPSocket> socket) // NOLINT(performance-unnecessary-value-param)
196{
197 std::array<char, bufferSize> tempBuffer{};
198 std::vector<::pollfd> pollFDs(1);
199 pollFDs.at(n: 0).fd = socket->sock;
200 pollFDs.at(n: 0).events = POLLIN;
201 for (constexpr std::uint8_t timeout = 10; !socket->stopRequested();)
202 {
203 const int status = ::poll(fds: pollFDs.data(), nfds: pollFDs.size(), timeout: timeout);
204 if (status == -1)
205 {
206 throw std::runtime_error{"Not the expected wait result for poll, errno: " + safeStrErrno() + '.'};
207 }
208 if (status == 0)
209 {
210 continue;
211 }
212
213 if (::ssize_t msgLen = 0; pollFDs.at(n: 0).revents & POLLIN)
214 {
215 if (const Guard lock(*socket); true)
216 {
217 msgLen = ::recv(fd: socket->sock, buf: tempBuffer.data(), n: tempBuffer.size(), flags: 0);
218 if (msgLen <= 0)
219 {
220 break;
221 }
222 }
223
224 socket->onMessage(message: std::string_view(tempBuffer.data(), static_cast<std::size_t>(msgLen)));
225 socket->onRawMessage(bytes: tempBuffer.data(), size: static_cast<std::size_t>(msgLen));
226 }
227 }
228
229 socket->close();
230}
231
232void TCPSocket::onMessage(const std::string_view message) const
233{
234 const auto callback = msgCb.load(o: std::memory_order_acquire);
235 if (callback && *callback)
236 {
237 (*callback)(message);
238 }
239}
240
241void TCPSocket::onRawMessage(char* const bytes, const std::size_t size) const
242{
243 const auto callback = rawMsgCb.load(o: std::memory_order_acquire);
244 if (callback && *callback)
245 {
246 (*callback)(bytes, size);
247 }
248}
249
250TCPServer::TCPServer() : Socket(Type::tcp)
251{
252 const Guard lock(*this);
253 int opt1 = 1;
254 int opt2 = 0;
255 ::setsockopt(fd: sock, SOL_SOCKET, SO_REUSEADDR, optval: &opt1, optlen: sizeof(opt1));
256 ::setsockopt(fd: sock, SOL_SOCKET, SO_REUSEPORT, optval: &opt2, optlen: sizeof(opt2));
257}
258
259void TCPServer::bind(const std::string& ip, const std::uint16_t port)
260{
261 if (::inet_pton(AF_INET, cp: ip.c_str(), buf: &sockAddr.sin_addr) == -1)
262 {
263 throw std::runtime_error{"Invalid address, address type is not supported, errno: " + safeStrErrno() + '.'};
264 }
265
266 sockAddr.sin_family = AF_INET;
267 sockAddr.sin_port = ::htons(hostshort: port);
268 if (const Guard lock(*this); ::bind(fd: sock, addr: reinterpret_cast<::sockaddr*>(&sockAddr), len: sizeof(sockAddr)) == -1)
269 {
270 throw std::runtime_error{"Failed to bind the socket, errno: " + safeStrErrno() + '.'};
271 }
272}
273
274void TCPServer::bind(const std::uint16_t port)
275{
276 bind(ip: "0.0.0.0", port);
277}
278
279void TCPServer::listen()
280{
281 constexpr std::uint8_t retryTimes = 10;
282 if (const Guard lock(*this); ::listen(fd: sock, n: retryTimes) == -1)
283 {
284 throw std::runtime_error{"Server could not listen on the socket, errno: " + safeStrErrno() + '.'};
285 }
286}
287
288void TCPServer::accept(const bool detached)
289{
290 const auto task = [weakSelf = std::weak_ptr<TCPServer>(shared_from_this())]
291 {
292 if (auto sharedSelf = weakSelf.lock())
293 {
294 accept(server: sharedSelf);
295 }
296 };
297 detached ? spawnDetached(func: task) : spawnJoinable(func: task);
298}
299
300void TCPServer::subscribeConnection(ConnectionCallback callback)
301{
302 connCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
303}
304
305void TCPServer::accept(const std::shared_ptr<TCPServer> server) // NOLINT(performance-unnecessary-value-param)
306{
307 ::sockaddr_in newSockAddr{};
308 ::socklen_t newSockAddrLen = sizeof(newSockAddr);
309
310 for (std::vector<std::shared_ptr<TCPSocket>> activeSockets{};;)
311 {
312 const int newSock = ::accept(fd: server->sock, addr: reinterpret_cast<::sockaddr*>(&newSockAddr), addr_len: &newSockAddrLen);
313 if (newSock == -1)
314 {
315 std::ranges::for_each(activeSockets, [](const auto& socket) { socket->requestStop(); });
316 if ((errno == EBADF) || (errno == EINVAL))
317 {
318 return;
319 }
320 throw std::runtime_error{"Error while accepting a new connection, errno: " + safeStrErrno() + '.'};
321 }
322
323 auto newSocket = std::make_shared<TCPSocket>(args: newSock);
324 newSocket->sockAddr = newSockAddr;
325 server->onConnection(client: newSocket);
326
327 newSocket->receive(detached: true);
328 activeSockets.emplace_back(args: std::move(newSocket));
329 }
330}
331
332void TCPServer::onConnection(
333 const std::shared_ptr<TCPSocket> client) const // NOLINT(performance-unnecessary-value-param)
334{
335 const auto callback = connCb.load(o: std::memory_order_acquire);
336 if (callback && *callback)
337 {
338 (*callback)(client);
339 }
340}
341
342::ssize_t UDPSocket::sendTo(
343 const char* const bytes, const std::size_t size, const std::string& ip, const std::uint16_t port)
344{
345 ::addrinfo* addrInfo = nullptr;
346 ::addrinfo hints{};
347 hints.ai_family = AF_INET;
348 hints.ai_socktype = ::SOCK_DGRAM;
349
350 if (const int status = ::getaddrinfo(name: ip.c_str(), service: nullptr, req: &hints, pai: &addrInfo); status != 0)
351 {
352 throw std::runtime_error{
353 "Invalid address, status: " + std::string{::gai_strerror(ecode: status)} + ", errno: " + safeStrErrno() + '.'};
354 }
355
356 ::sockaddr_in addr{};
357 for (const auto* entry = addrInfo; entry != nullptr; entry = entry->ai_next)
358 {
359 if (entry->ai_family == AF_INET)
360 {
361 std::memcpy(dest: static_cast<void*>(&addr), src: static_cast<void*>(entry->ai_addr), n: sizeof(::sockaddr_in));
362 break;
363 }
364 }
365 ::freeaddrinfo(ai: addrInfo);
366
367 addr.sin_port = ::htons(hostshort: port);
368 addr.sin_family = AF_INET;
369 ::ssize_t sent = 0;
370 if (const Guard lock(*this); true)
371 {
372 sent = ::sendto(fd: sock, buf: bytes, n: size, flags: 0, addr: reinterpret_cast<::sockaddr*>(&addr), addr_len: sizeof(addr));
373 if (sent == -1)
374 {
375 throw std::runtime_error{"Unable to send message to address, errno: " + safeStrErrno() + '.'};
376 }
377 }
378 return sent;
379}
380
381::ssize_t UDPSocket::sendTo(const std::string_view message, const std::string& ip, const std::uint16_t port)
382{
383 return sendTo(bytes: message.data(), size: message.length(), ip, port);
384}
385
386::ssize_t UDPSocket::send(const char* const bytes, const std::size_t size)
387{
388 const Guard lock(*this);
389 return ::send(fd: sock, buf: bytes, n: size, flags: 0);
390}
391
392::ssize_t UDPSocket::send(const std::string_view message)
393{
394 return send(bytes: message.data(), size: message.length());
395}
396
397void UDPSocket::connect(const std::string& ip, const std::uint16_t port)
398{
399 ::addrinfo* addrInfo = nullptr;
400 ::addrinfo hints{};
401 hints.ai_family = AF_INET;
402 hints.ai_socktype = ::SOCK_DGRAM;
403
404 if (const int status = ::getaddrinfo(name: ip.c_str(), service: nullptr, req: &hints, pai: &addrInfo); status != 0)
405 {
406 throw std::runtime_error{
407 "Invalid address, status: " + std::string{::gai_strerror(ecode: status)} + ", errno: " + safeStrErrno() + '.'};
408 }
409
410 for (const auto* entry = addrInfo; entry != nullptr; entry = entry->ai_next)
411 {
412 if (entry->ai_family == AF_INET)
413 {
414 std::memcpy(dest: static_cast<void*>(&sockAddr), src: static_cast<void*>(entry->ai_addr), n: sizeof(::sockaddr_in));
415 break;
416 }
417 }
418 ::freeaddrinfo(ai: addrInfo);
419
420 sockAddr.sin_family = AF_INET;
421 sockAddr.sin_port = ::htons(hostshort: port);
422 sockAddr.sin_addr.s_addr = static_cast<std::uint32_t>(sockAddr.sin_addr.s_addr);
423 if (::connect(fd: sock, addr: reinterpret_cast<::sockaddr*>(&sockAddr), len: sizeof(::sockaddr_in)) == -1)
424 {
425 throw std::runtime_error{"Failed to connect to the socket, errno: " + safeStrErrno() + '.'};
426 }
427}
428
429void UDPSocket::receive(const bool detached)
430{
431 detached ? spawnDetached(func&: doReceive, args: shared_from_this()) : spawnJoinable(func&: doReceive, args: shared_from_this());
432}
433
434void UDPSocket::receiveFrom(const bool detached)
435{
436 detached ? spawnDetached(func&: doReceiveFrom, args: shared_from_this()) : spawnJoinable(func&: doReceiveFrom, args: shared_from_this());
437}
438
439void UDPSocket::subscribeMessage(MessageCallback callback)
440{
441 msgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
442}
443
444void UDPSocket::subscribeRawMessage(RawMessageCallback callback)
445{
446 rawMsgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
447}
448
449void UDPSocket::doReceive(const std::shared_ptr<UDPSocket> socket) // NOLINT(performance-unnecessary-value-param)
450{
451 std::array<char, bufferSize> tempBuffer{};
452 std::vector<::pollfd> pollFDs(1);
453 pollFDs.at(n: 0).fd = socket->sock;
454 pollFDs.at(n: 0).events = POLLIN;
455 for (constexpr std::uint8_t timeout = 10; !socket->stopRequested();)
456 {
457 const int status = ::poll(fds: pollFDs.data(), nfds: pollFDs.size(), timeout: timeout);
458 if (status == -1)
459 {
460 throw std::runtime_error{"Not the expected wait result for poll, errno: " + safeStrErrno() + '.'};
461 }
462 if (status == 0)
463 {
464 continue;
465 }
466
467 if (::ssize_t msgLen = 0; pollFDs.at(n: 0).revents & POLLIN)
468 {
469 if (const Guard lock(*socket); true)
470 {
471 msgLen = ::recv(fd: socket->sock, buf: tempBuffer.data(), n: tempBuffer.size(), flags: 0);
472 if (msgLen == -1)
473 {
474 break;
475 }
476 }
477
478 socket->onMessage(
479 message: std::string_view(tempBuffer.data(), static_cast<std::size_t>(msgLen)),
480 ip: socket->transportAddress(),
481 port: socket->transportPort());
482 socket->onRawMessage(
483 bytes: tempBuffer.data(),
484 size: static_cast<std::size_t>(msgLen),
485 ip: socket->transportAddress(),
486 port: socket->transportPort());
487 }
488 }
489}
490
491void UDPSocket::doReceiveFrom(const std::shared_ptr<UDPSocket> socket) // NOLINT(performance-unnecessary-value-param)
492{
493 ::sockaddr_in addr{};
494 ::socklen_t hostAddrSize = sizeof(addr);
495
496 std::array<char, bufferSize> tempBuffer{};
497 std::vector<::pollfd> pollFDs(1);
498 pollFDs.at(n: 0).fd = socket->sock;
499 pollFDs.at(n: 0).events = POLLIN;
500 for (constexpr std::uint8_t timeout = 10; !socket->stopRequested();)
501 {
502 const int status = ::poll(fds: pollFDs.data(), nfds: pollFDs.size(), timeout: timeout);
503 if (status == -1)
504 {
505 throw std::runtime_error{"Not the expected wait result for poll, errno: " + safeStrErrno() + '.'};
506 }
507 if (status == 0)
508 {
509 continue;
510 }
511
512 if (::ssize_t msgLen = 0; pollFDs.at(n: 0).revents & POLLIN)
513 {
514 if (const Guard lock(*socket); true)
515 {
516 msgLen = ::recvfrom(
517 fd: socket->sock,
518 buf: tempBuffer.data(),
519 n: tempBuffer.size(),
520 flags: 0,
521 addr: reinterpret_cast<::sockaddr*>(&addr),
522 addr_len: &hostAddrSize);
523 if (msgLen == -1)
524 {
525 break;
526 }
527 }
528
529 socket->onMessage(
530 message: std::string_view(tempBuffer.data(), static_cast<std::size_t>(msgLen)),
531 ip: ipAddrString(addr),
532 port: ::ntohs(netshort: addr.sin_port));
533 socket->onRawMessage(
534 bytes: tempBuffer.data(), size: static_cast<std::size_t>(msgLen), ip: ipAddrString(addr), port: ::ntohs(netshort: addr.sin_port));
535 }
536 }
537}
538
539void UDPSocket::onMessage(const std::string_view message, const std::string& ip, const std::uint16_t port) const
540{
541 const auto callback = msgCb.load(o: std::memory_order_acquire);
542 if (callback && *callback)
543 {
544 (*callback)(message, ip, port);
545 }
546}
547
548void UDPSocket::onRawMessage(
549 char* const bytes, const std::size_t size, const std::string& ip, const std::uint16_t port) const
550{
551 const auto callback = rawMsgCb.load(o: std::memory_order_acquire);
552 if (callback && *callback)
553 {
554 (*callback)(bytes, size, ip, port);
555 }
556}
557
558void UDPServer::bind(const std::string& ip, const std::uint16_t port)
559{
560 if (::inet_pton(AF_INET, cp: ip.c_str(), buf: &sockAddr.sin_addr) == -1)
561 {
562 throw std::runtime_error{"Invalid address, address type is not supported, errno: " + safeStrErrno() + '.'};
563 }
564
565 sockAddr.sin_family = AF_INET;
566 sockAddr.sin_port = ::htons(hostshort: port);
567 if (const Guard lock(*this); ::bind(fd: sock, addr: reinterpret_cast<::sockaddr*>(&sockAddr), len: sizeof(sockAddr)) == -1)
568 {
569 throw std::runtime_error{"Failed to bind the socket, errno: " + safeStrErrno() + '.'};
570 }
571}
572
573void UDPServer::bind(const std::uint16_t port)
574{
575 bind(ip: "0.0.0.0", port);
576}
577// NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
578} // namespace utility::socket
579