1//! @file socket.cpp
2//! @author ryftchen
3//! @brief The definitions (socket) in the utility module.
4//! @version 0.1.0
5//! @copyright Copyright (c) 2022-2026 ryftchen. All rights reserved.
6
7#include "socket.hpp"
8
#include <sys/poll.h>
#include <netdb.h>
#include <algorithm>
#include <array>
#include <cerrno>
#include <cstring>
#include <memory>
#include <vector>
15
16namespace utility::socket
17{
//! @brief Report the version of this socket utility.
//! @return version string in major.minor.patch form
const char* version() noexcept
{
    // A string literal has static storage duration, so handing out its address is always safe.
    constexpr const char* semanticVersion = "0.1.0";
    return semanticVersion;
}
25
//! @brief Render the IPv4 address held in a transport address as dotted-decimal text.
//! @param addr - transport information whose sin_addr is converted
//! @return ip address string (empty if the conversion leaves the buffer untouched)
static std::string ipAddrString(const ::sockaddr_in& addr)
{
    // Pre-size a zero-filled string large enough for "255.255.255.255" plus its terminator.
    std::string text(INET_ADDRSTRLEN, '\0');
    ::inet_ntop(AF_INET, &addr.sin_addr, text.data(), static_cast<::socklen_t>(text.size()));
    // Trim everything from the first NUL so only the printable address remains.
    text.resize(std::strlen(text.c_str()));
    return text;
}
35
//! @brief Normalize the result of the XSI-compliant strerror_r, which returns a status code.
//! @param status - value returned by strerror_r (0 on success)
//! @param buffer - buffer that strerror_r was asked to fill
//! @return the message written into buffer, or "Unknown error" on failure
static std::string strerrorResult(const int status, const char* const buffer)
{
    return (status == 0) ? std::string{buffer} : std::string{"Unknown error"};
}

//! @brief Normalize the result of the GNU-specific strerror_r, which returns the message pointer.
//! @param message - value returned by strerror_r (may or may not point into the caller's buffer)
//! @return the message text, or "Unknown error" if no message was produced
static std::string strerrorResult(const char* const message, const char* const /*buffer*/)
{
    return (message != nullptr) ? std::string{message} : std::string{"Unknown error"};
}

//! @brief Get the errno string safely.
//! @return errno string
static std::string safeStrErrno()
{
    // Capture errno before anything else can overwrite it.
    const int savedErrno = errno;
    // Large enough that the XSI variant does not fail with ERANGE for any standard message.
    std::array<char, 256> buffer{};
    // Overload resolution picks the right normalizer for whichever strerror_r variant the
    // C library exposes. Testing _GNU_SOURCE is not a reliable way to tell them apart;
    // e.g. musl always provides the int-returning form.
    return strerrorResult(::strerror_r(savedErrno, buffer.data(), buffer.size()), buffer.data());
}
47
48// NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
49Socket::Socket(const Type sockType, const int sockId)
50{
51 ::pthread_spin_init(lock: &sockLock, pshared: ::PTHREAD_PROCESS_PRIVATE);
52 if (sockId != -1)
53 {
54 sock = sockId;
55 return;
56 }
57
58 sock = ::socket(AF_INET, type: static_cast<std::uint8_t>(sockType), protocol: 0);
59 if (sock == -1)
60 {
61 throw std::runtime_error{"Socket creation error, errno: " + safeStrErrno() + '.'};
62 }
63}
64
65Socket::~Socket()
66{
67 close();
68 ::pthread_spin_destroy(lock: &sockLock);
69}
70
71void Socket::close()
72{
73 requestStop();
74
75 const Guard lock(*this);
76 ::shutdown(fd: sock, how: ::SHUT_RDWR);
77 ::close(fd: sock);
78}
79
80void Socket::join()
81{
82 if (ownedTask.valid() && (ownedTask.wait_until(abs: std::chrono::system_clock::now()) != std::future_status::ready))
83 {
84 ownedTask.wait();
85 }
86
87 while (!stopRequested())
88 {
89 std::this_thread::yield();
90 }
91}
92
93void Socket::requestStop()
94{
95 exitReady.store(i: true);
96}
97
98bool Socket::stopRequested() const
99{
100 return exitReady.load();
101}
102
103std::string Socket::transportAddress() const
104{
105 return ipAddrString(addr: sockAddr);
106}
107
108std::uint16_t Socket::transportPort() const
109{
110 return ::ntohs(netshort: sockAddr.sin_port);
111}
112
113void Socket::spinLock() const
114{
115 ::pthread_spin_lock(lock: &sockLock);
116}
117
118void Socket::spinUnlock() const
119{
120 ::pthread_spin_unlock(lock: &sockLock);
121}
122
123template <typename Func, typename... Args>
124void Socket::spawnDetached(Func&& func, Args&&... args)
125{
126 std::thread(std::forward<Func>(func), std::forward<Args>(args)...).detach();
127}
128
129template <typename Func, typename... Args>
130void Socket::spawnJoinable(Func&& func, Args&&... args)
131{
132 ownedTask = std::async(std::launch::async, std::forward<Func>(func), std::forward<Args>(args)...);
133}
134
135::ssize_t TCPSocket::send(const char* const bytes, const std::size_t size)
136{
137 const Guard lock(*this);
138 return ::send(fd: sock, buf: bytes, n: size, flags: 0);
139}
140
141::ssize_t TCPSocket::send(const std::string_view message)
142{
143 return send(bytes: message.data(), size: message.length());
144}
145
146void TCPSocket::connect(const std::string& ip, const std::uint16_t port)
147{
148 ::addrinfo* addrInfo = nullptr;
149 ::addrinfo hints{};
150 hints.ai_family = AF_INET;
151 hints.ai_socktype = ::SOCK_STREAM;
152
153 if (const int status = ::getaddrinfo(name: ip.c_str(), service: nullptr, req: &hints, pai: &addrInfo); status != 0)
154 {
155 throw std::runtime_error{
156 "Invalid address, status: " + std::string{::gai_strerror(ecode: status)} + ", errno: " + safeStrErrno() + '.'};
157 }
158
159 for (const auto* entry = addrInfo; entry != nullptr; entry = entry->ai_next)
160 {
161 if (entry->ai_family == AF_INET)
162 {
163 std::memcpy(dest: static_cast<void*>(&sockAddr), src: static_cast<void*>(entry->ai_addr), n: sizeof(::sockaddr_in));
164 break;
165 }
166 }
167 ::freeaddrinfo(ai: addrInfo);
168
169 sockAddr.sin_family = AF_INET;
170 sockAddr.sin_port = ::htons(hostshort: port);
171 sockAddr.sin_addr.s_addr = static_cast<std::uint32_t>(sockAddr.sin_addr.s_addr);
172 if (::connect(fd: sock, addr: reinterpret_cast<const ::sockaddr*>(&sockAddr), len: sizeof(::sockaddr_in)) == -1)
173 {
174 throw std::runtime_error{"Failed to connect to the socket, errno: " + safeStrErrno() + '.'};
175 }
176
177 receive();
178}
179
180void TCPSocket::receive(const bool detached)
181{
182 detached ? spawnDetached(func&: doReceive, args: shared_from_this()) : spawnJoinable(func&: doReceive, args: shared_from_this());
183}
184
185void TCPSocket::subscribeMessage(MessageCallback callback)
186{
187 msgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
188}
189
190void TCPSocket::subscribeRawMessage(RawMessageCallback callback)
191{
192 rawMsgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
193}
194
195void TCPSocket::doReceive(const std::shared_ptr<TCPSocket> socket) // NOLINT(performance-unnecessary-value-param)
196{
197 std::array<char, bufferSize> tempBuffer{};
198 std::vector<::pollfd> pollFDs(1);
199 pollFDs.at(n: 0).fd = socket->sock;
200 pollFDs.at(n: 0).events = POLLIN;
201 for (constexpr std::uint8_t timeout = 10; !socket->stopRequested();)
202 {
203 const int status = ::poll(fds: pollFDs.data(), nfds: pollFDs.size(), timeout: timeout);
204 if (status == -1)
205 {
206 throw std::runtime_error{"Not the expected wait result for poll, errno: " + safeStrErrno() + '.'};
207 }
208 if (status == 0)
209 {
210 continue;
211 }
212
213 if (::ssize_t msgLen = 0; pollFDs.at(n: 0).revents & POLLIN)
214 {
215 if (const Guard lock(*socket); true)
216 {
217 msgLen = ::recv(fd: socket->sock, buf: tempBuffer.data(), n: tempBuffer.size(), flags: 0);
218 if (msgLen <= 0)
219 {
220 break;
221 }
222 }
223
224 tempBuffer[msgLen] = '\0';
225 socket->onMessage(message: std::string(tempBuffer.data(), msgLen));
226 socket->onRawMessage(bytes: tempBuffer.data(), size: msgLen);
227 }
228 }
229
230 socket->close();
231}
232
233void TCPSocket::onMessage(const std::string_view message) const
234{
235 const auto callback = msgCb.load(o: std::memory_order_acquire);
236 if (callback && *callback)
237 {
238 (*callback)(message);
239 }
240}
241
242void TCPSocket::onRawMessage(char* const bytes, const std::size_t size) const
243{
244 const auto callback = rawMsgCb.load(o: std::memory_order_acquire);
245 if (callback && *callback)
246 {
247 (*callback)(bytes, size);
248 }
249}
250
251TCPServer::TCPServer() : Socket(Type::tcp)
252{
253 const Guard lock(*this);
254 int opt1 = 1;
255 int opt2 = 0;
256 ::setsockopt(fd: sock, SOL_SOCKET, SO_REUSEADDR, optval: &opt1, optlen: sizeof(opt1));
257 ::setsockopt(fd: sock, SOL_SOCKET, SO_REUSEPORT, optval: &opt2, optlen: sizeof(opt2));
258}
259
260void TCPServer::bind(const std::string& ip, const std::uint16_t port)
261{
262 if (::inet_pton(AF_INET, cp: ip.c_str(), buf: &sockAddr.sin_addr) == -1)
263 {
264 throw std::runtime_error{"Invalid address, address type is not supported, errno: " + safeStrErrno() + '.'};
265 }
266
267 sockAddr.sin_family = AF_INET;
268 sockAddr.sin_port = ::htons(hostshort: port);
269 if (const Guard lock(*this); ::bind(fd: sock, addr: reinterpret_cast<const ::sockaddr*>(&sockAddr), len: sizeof(sockAddr)) == -1)
270 {
271 throw std::runtime_error{"Failed to bind the socket, errno: " + safeStrErrno() + '.'};
272 }
273}
274
275void TCPServer::bind(const std::uint16_t port)
276{
277 bind(ip: "0.0.0.0", port);
278}
279
280void TCPServer::listen()
281{
282 constexpr std::uint8_t retryTimes = 10;
283 if (const Guard lock(*this); ::listen(fd: sock, n: retryTimes) == -1)
284 {
285 throw std::runtime_error{"Server could not listen on the socket, errno: " + safeStrErrno() + '.'};
286 }
287}
288
289void TCPServer::accept(const bool detached)
290{
291 const auto task = [weakSelf = std::weak_ptr<TCPServer>(shared_from_this())]
292 {
293 if (auto sharedSelf = weakSelf.lock())
294 {
295 accept(server: sharedSelf);
296 }
297 };
298 detached ? spawnDetached(func: task) : spawnJoinable(func: task);
299}
300
301void TCPServer::subscribeConnection(ConnectionCallback callback)
302{
303 connCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
304}
305
306void TCPServer::accept(const std::shared_ptr<TCPServer> server) // NOLINT(performance-unnecessary-value-param)
307{
308 ::sockaddr_in newSockAddr{};
309 ::socklen_t newSockAddrLen = sizeof(newSockAddr);
310
311 for (std::vector<std::shared_ptr<TCPSocket>> activeSockets{};;)
312 {
313 const int newSock = ::accept(fd: server->sock, addr: reinterpret_cast<::sockaddr*>(&newSockAddr), addr_len: &newSockAddrLen);
314 if (newSock == -1)
315 {
316 std::ranges::for_each(activeSockets, [](const auto& socket) { socket->requestStop(); });
317 if ((errno == EBADF) || (errno == EINVAL))
318 {
319 return;
320 }
321 throw std::runtime_error{"Error while accepting a new connection, errno: " + safeStrErrno() + '.'};
322 }
323
324 auto newSocket = std::make_shared<TCPSocket>(args: newSock);
325 newSocket->sockAddr = newSockAddr;
326 server->onConnection(client: newSocket);
327
328 newSocket->receive(detached: true);
329 activeSockets.emplace_back(args: std::move(newSocket));
330 }
331}
332
333void TCPServer::onConnection(
334 const std::shared_ptr<TCPSocket> client) const // NOLINT(performance-unnecessary-value-param)
335{
336 const auto callback = connCb.load(o: std::memory_order_acquire);
337 if (callback && *callback)
338 {
339 (*callback)(client);
340 }
341}
342
343::ssize_t UDPSocket::sendTo(
344 const char* const bytes, const std::size_t size, const std::string& ip, const std::uint16_t port)
345{
346 ::addrinfo* addrInfo = nullptr;
347 ::addrinfo hints{};
348 hints.ai_family = AF_INET;
349 hints.ai_socktype = ::SOCK_DGRAM;
350
351 if (const int status = ::getaddrinfo(name: ip.c_str(), service: nullptr, req: &hints, pai: &addrInfo); status != 0)
352 {
353 throw std::runtime_error{
354 "Invalid address, status: " + std::string{::gai_strerror(ecode: status)} + ", errno: " + safeStrErrno() + '.'};
355 }
356
357 ::sockaddr_in addr{};
358 for (const auto* entry = addrInfo; entry != nullptr; entry = entry->ai_next)
359 {
360 if (entry->ai_family == AF_INET)
361 {
362 std::memcpy(dest: static_cast<void*>(&addr), src: static_cast<void*>(entry->ai_addr), n: sizeof(::sockaddr_in));
363 break;
364 }
365 }
366 ::freeaddrinfo(ai: addrInfo);
367
368 addr.sin_port = ::htons(hostshort: port);
369 addr.sin_family = AF_INET;
370 ::ssize_t sent = 0;
371 if (const Guard lock(*this); true)
372 {
373 sent = ::sendto(fd: sock, buf: bytes, n: size, flags: 0, addr: reinterpret_cast<const ::sockaddr*>(&addr), addr_len: sizeof(addr));
374 if (sent == -1)
375 {
376 throw std::runtime_error{"Unable to send message to address, errno: " + safeStrErrno() + '.'};
377 }
378 }
379 return sent;
380}
381
382::ssize_t UDPSocket::sendTo(const std::string_view message, const std::string& ip, const std::uint16_t port)
383{
384 return sendTo(bytes: message.data(), size: message.length(), ip, port);
385}
386
387::ssize_t UDPSocket::send(const char* const bytes, const std::size_t size)
388{
389 const Guard lock(*this);
390 return ::send(fd: sock, buf: bytes, n: size, flags: 0);
391}
392
393::ssize_t UDPSocket::send(const std::string_view message)
394{
395 return send(bytes: message.data(), size: message.length());
396}
397
398void UDPSocket::connect(const std::string& ip, const std::uint16_t port)
399{
400 ::addrinfo* addrInfo = nullptr;
401 ::addrinfo hints{};
402 hints.ai_family = AF_INET;
403 hints.ai_socktype = ::SOCK_DGRAM;
404
405 if (const int status = ::getaddrinfo(name: ip.c_str(), service: nullptr, req: &hints, pai: &addrInfo); status != 0)
406 {
407 throw std::runtime_error{
408 "Invalid address, status: " + std::string{::gai_strerror(ecode: status)} + ", errno: " + safeStrErrno() + '.'};
409 }
410
411 for (const auto* entry = addrInfo; entry != nullptr; entry = entry->ai_next)
412 {
413 if (entry->ai_family == AF_INET)
414 {
415 std::memcpy(dest: static_cast<void*>(&sockAddr), src: static_cast<void*>(entry->ai_addr), n: sizeof(::sockaddr_in));
416 break;
417 }
418 }
419 ::freeaddrinfo(ai: addrInfo);
420
421 sockAddr.sin_family = AF_INET;
422 sockAddr.sin_port = ::htons(hostshort: port);
423 sockAddr.sin_addr.s_addr = static_cast<std::uint32_t>(sockAddr.sin_addr.s_addr);
424 if (::connect(fd: sock, addr: reinterpret_cast<const ::sockaddr*>(&sockAddr), len: sizeof(::sockaddr_in)) == -1)
425 {
426 throw std::runtime_error{"Failed to connect to the socket, errno: " + safeStrErrno() + '.'};
427 }
428}
429
430void UDPSocket::receive(const bool detached)
431{
432 detached ? spawnDetached(func&: doReceive, args: shared_from_this()) : spawnJoinable(func&: doReceive, args: shared_from_this());
433}
434
435void UDPSocket::receiveFrom(const bool detached)
436{
437 detached ? spawnDetached(func&: doReceiveFrom, args: shared_from_this()) : spawnJoinable(func&: doReceiveFrom, args: shared_from_this());
438}
439
440void UDPSocket::subscribeMessage(MessageCallback callback)
441{
442 msgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
443}
444
445void UDPSocket::subscribeRawMessage(RawMessageCallback callback)
446{
447 rawMsgCb.store(desired: std::make_shared<decltype(callback)>(args: std::move(callback)), o: std::memory_order_release);
448}
449
450void UDPSocket::doReceive(const std::shared_ptr<UDPSocket> socket) // NOLINT(performance-unnecessary-value-param)
451{
452 std::array<char, bufferSize> tempBuffer{};
453 std::vector<::pollfd> pollFDs(1);
454 pollFDs.at(n: 0).fd = socket->sock;
455 pollFDs.at(n: 0).events = POLLIN;
456 for (constexpr std::uint8_t timeout = 10; !socket->stopRequested();)
457 {
458 const int status = ::poll(fds: pollFDs.data(), nfds: pollFDs.size(), timeout: timeout);
459 if (status == -1)
460 {
461 throw std::runtime_error{"Not the expected wait result for poll, errno: " + safeStrErrno() + '.'};
462 }
463 if (status == 0)
464 {
465 continue;
466 }
467
468 if (::ssize_t msgLen = 0; pollFDs.at(n: 0).revents & POLLIN)
469 {
470 if (const Guard lock(*socket); true)
471 {
472 msgLen = ::recv(fd: socket->sock, buf: tempBuffer.data(), n: tempBuffer.size(), flags: 0);
473 if (msgLen == -1)
474 {
475 break;
476 }
477 }
478
479 tempBuffer[msgLen] = '\0';
480 socket->onMessage(
481 message: std::string_view(tempBuffer.data(), msgLen), ip: socket->transportAddress(), port: socket->transportPort());
482 socket->onRawMessage(bytes: tempBuffer.data(), size: msgLen, ip: socket->transportAddress(), port: socket->transportPort());
483 }
484 }
485}
486
487void UDPSocket::doReceiveFrom(const std::shared_ptr<UDPSocket> socket) // NOLINT(performance-unnecessary-value-param)
488{
489 ::sockaddr_in addr{};
490 ::socklen_t hostAddrSize = sizeof(addr);
491
492 std::array<char, bufferSize> tempBuffer{};
493 std::vector<::pollfd> pollFDs(1);
494 pollFDs.at(n: 0).fd = socket->sock;
495 pollFDs.at(n: 0).events = POLLIN;
496 for (constexpr std::uint8_t timeout = 10; !socket->stopRequested();)
497 {
498 const int status = ::poll(fds: pollFDs.data(), nfds: pollFDs.size(), timeout: timeout);
499 if (status == -1)
500 {
501 throw std::runtime_error{"Not the expected wait result for poll, errno: " + safeStrErrno() + '.'};
502 }
503 if (status == 0)
504 {
505 continue;
506 }
507
508 if (::ssize_t msgLen = 0; pollFDs.at(n: 0).revents & POLLIN)
509 {
510 if (const Guard lock(*socket); true)
511 {
512 msgLen = ::recvfrom(
513 fd: socket->sock,
514 buf: tempBuffer.data(),
515 n: tempBuffer.size(),
516 flags: 0,
517 addr: reinterpret_cast<::sockaddr*>(&addr),
518 addr_len: &hostAddrSize);
519 if (msgLen == -1)
520 {
521 break;
522 }
523 }
524
525 tempBuffer[msgLen] = '\0';
526 socket->onMessage(message: std::string_view(tempBuffer.data(), msgLen), ip: ipAddrString(addr), port: ::ntohs(netshort: addr.sin_port));
527 socket->onRawMessage(bytes: tempBuffer.data(), size: msgLen, ip: ipAddrString(addr), port: ::ntohs(netshort: addr.sin_port));
528 }
529 }
530}
531
532void UDPSocket::onMessage(const std::string_view message, const std::string& ip, const std::uint16_t port) const
533{
534 const auto callback = msgCb.load(o: std::memory_order_acquire);
535 if (callback && *callback)
536 {
537 (*callback)(message, ip, port);
538 }
539}
540
541void UDPSocket::onRawMessage(
542 char* const bytes, const std::size_t size, const std::string& ip, const std::uint16_t port) const
543{
544 const auto callback = rawMsgCb.load(o: std::memory_order_acquire);
545 if (callback && *callback)
546 {
547 (*callback)(bytes, size, ip, port);
548 }
549}
550
551void UDPServer::bind(const std::string& ip, const std::uint16_t port)
552{
553 if (::inet_pton(AF_INET, cp: ip.c_str(), buf: &sockAddr.sin_addr) == -1)
554 {
555 throw std::runtime_error{"Invalid address, address type is not supported, errno: " + safeStrErrno() + '.'};
556 }
557
558 sockAddr.sin_family = AF_INET;
559 sockAddr.sin_port = ::htons(hostshort: port);
560 if (const Guard lock(*this); ::bind(fd: sock, addr: reinterpret_cast<const ::sockaddr*>(&sockAddr), len: sizeof(sockAddr)) == -1)
561 {
562 throw std::runtime_error{"Failed to bind the socket, errno: " + safeStrErrno() + '.'};
563 }
564}
565
566void UDPServer::bind(const std::uint16_t port)
567{
568 bind(ip: "0.0.0.0", port);
569}
570// NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
571} // namespace utility::socket
572