diff --git a/README.md b/README.md index 54275b9..04281dd 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ The UdpSocket class has the following interface for managing the UDP socket: UdpSocket(CallbackImpl &callback, SocketOpt *options = nullptr); // Start a multicast socket -SocketRet startMcast(const char *mcastAddr, uint16_t port); +SocketRet startMcast(const char *mcastAddr, uint16_t port, const char *localIpAddr = nullptr); // Start a unicast client/server socket SocketRet startUnicast(const char *remoteAddr, uint16_t localPort, uint16_t port) diff --git a/examples/mcastApp.cpp b/examples/mcastApp.cpp index 6c36946..abf1c7c 100644 --- a/examples/mcastApp.cpp +++ b/examples/mcastApp.cpp @@ -9,7 +9,7 @@ class McastApp { public: // UDP Multicast - McastApp(const char *multicastAddr, uint16_t port); + McastApp(const char *multicastAddr, uint16_t port, const char *localIface); virtual ~McastApp() = default; @@ -21,8 +21,8 @@ class McastApp { sockets::UdpSocket m_mcast; }; -McastApp::McastApp(const char *multicastAddr, uint16_t port) : m_mcast(*this) { - sockets::SocketRet ret = m_mcast.startMcast(multicastAddr, port); +McastApp::McastApp(const char *multicastAddr, uint16_t port, const char *localIface) : m_mcast(*this) { + sockets::SocketRet ret = m_mcast.startMcast(multicastAddr, port, localIface); if (ret.m_success) { std::cout << "Connected to mcast group " << multicastAddr << ":" << port << "\n"; } else { @@ -45,14 +45,15 @@ void McastApp::onReceiveData(const char *data, size_t size) { } void usage() { - std::cout << "McastApp -m -p \n"; + std::cout << "McastApp -m -p -[-l ]\n"; } int main(int argc, char **argv) { int arg = 0; const char *addr = nullptr; + const char *local = nullptr; uint16_t port = 0; - while ((arg = getopt(argc, argv, "m:p:?")) != EOF) { // NOLINT + while ((arg = getopt(argc, argv, "m:p:l:?")) != EOF) { // NOLINT switch (arg) { case 'm': addr = optarg; @@ -60,13 +61,16 @@ int main(int argc, char **argv) { case 'p': port = static_cast(std::stoul(optarg)); break; + case 'l': + local = optarg; + break; case '?': usage(); exit(1); // NOLINT } } - auto *app = new McastApp(addr, port); + auto *app = new McastApp(addr, port, local); while (true) { std::string data; diff --git a/include/sockets-cpp/UdpSocket.h b/include/sockets-cpp/UdpSocket.h index da00067..9a95c9a 100644 --- a/include/sockets-cpp/UdpSocket.h +++ b/include/sockets-cpp/UdpSocket.h @@ -68,7 +68,7 @@ class UdpSocket { * @param port - port number to listen/connect to * @return SocketRet - indication that multicast setup was successful */ - SocketRet startMcast(const char *mcastAddr, uint16_t port) { + SocketRet startMcast(const char *mcastAddr, uint16_t port, const char *localIpAddr = nullptr) { SocketRet ret; int result = m_socketCore.Initialize(); @@ -164,7 +164,11 @@ class UdpSocket { // struct ip_mreq mreq { }; inet_pton(AF_INET,mcastAddr,&mreq.imr_multiaddr); - mreq.imr_interface.s_addr = htonl(INADDR_ANY); + if (localIpAddr) { + inet_pton(AF_INET, localIpAddr, &mreq.imr_interface.s_addr); + } else { + mreq.imr_interface.s_addr = htonl(INADDR_ANY); + } if (m_socketCore.SetSockOpt(m_fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, reinterpret_cast(&mreq), sizeof(mreq)) < 0) { ret.m_success = false; #if defined(FMT_SUPPORT) @@ -177,6 +181,25 @@ class UdpSocket { return ret; } + struct in_addr addr; + if (localIpAddr) { + inet_pton(AF_INET, localIpAddr, &addr.s_addr); + } else { + mreq.imr_interface.s_addr = htonl(INADDR_ANY); + } + if (m_socketCore.SetSockOpt(m_fd, IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr)) < 0) + { + ret.m_success = false; +#if defined(FMT_SUPPORT) + ret.m_msg = fmt::format("Error: setsockopt(IP_MULTICAST_IF) failed: errno {}", errno); +#else + std::array msg; + (void)snprintf(msg.data(),msg.size(),"Error: setsockopt(IP_MULTICAST_IF) failed: %d", errno); + ret.m_msg = msg.data(); +#endif + return ret; + } + m_thread = std::thread(&UdpSocket::ReceiveTask, this); ret.m_success = true; return ret; diff --git a/test/test_UdpSocket.cpp b/test/test_UdpSocket.cpp index 26c77b8..350de85 100644 --- a/test/test_UdpSocket.cpp +++ b/test/test_UdpSocket.cpp @@ -196,7 +196,7 @@ TEST(UdpSocket,mcast_start_stop) MockSocketCore &core = app.m_socket.getCore(); EXPECT_CALL(core, Socket(_,_,_)).WillOnce(Return(4)); - EXPECT_CALL(core, SetSockOpt(_,_,_,_,_)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)); + EXPECT_CALL(core, SetSockOpt(_,_,_,_,_)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)); EXPECT_CALL(core,Bind(_,_,_)).WillOnce(Return(0)); EXPECT_CALL(core, Close(_)).WillOnce(Return(0)); EXPECT_CALL(core, Select(_,_,_,_,_)).WillRepeatedly(Return(0)); @@ -207,6 +207,23 @@ TEST(UdpSocket,mcast_start_stop) app.m_socket.finish(); } +TEST(UdpSocket,mcast_start_stop_local) +{ + UdpTestApp app; + MockSocketCore &core = app.m_socket.getCore(); + + EXPECT_CALL(core, Socket(_,_,_)).WillOnce(Return(4)); + EXPECT_CALL(core, SetSockOpt(_,_,_,_,_)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)).WillOnce(Return(0)); + EXPECT_CALL(core,Bind(_,_,_)).WillOnce(Return(0)); + EXPECT_CALL(core, Close(_)).WillOnce(Return(0)); + EXPECT_CALL(core, Select(_,_,_,_,_)).WillRepeatedly(Return(0)); + auto ret = app.m_socket.startMcast("224.0.0.1",5000,"127.0.0.1"); + EXPECT_EQ(true,ret.m_success); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + app.m_socket.finish(); +} + TEST(UdpSocket,finish_close_failure) { UdpTestApp app;