dolphin/Source/Core/Common/TraversalClient.cpp
Lioncash cb4ca7837a TraversalClient: Prevent uninitialized values from occurring in MakeENetAddress
Previously, eaddr would only be partially initialized in the ipv6 case.
Even if there's no support for it, we may as well ensure that the
variable always has deterministic initialization.

While we're at it, we can make the parameter a const reference, given no
members are modified.
2021-01-20 12:24:05 -05:00

344 lines
8.5 KiB
C++

// This file is public domain, in case it's useful to anyone. -comex
#include "Common/TraversalClient.h"
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include "Common/CommonTypes.h"
#include "Common/Logging/Log.h"
#include "Common/MsgHandler.h"
#include "Common/Random.h"
#include "Core/NetPlayProto.h"
// Binds the client to an existing ENet host and immediately starts talking to the traversal
// server at server:port.
//
// The host's intercept hook is pointed at InterceptCallback, so datagrams arriving on the host
// are offered to the traversal client first. Reset() runs before the first connection attempt,
// so no client callback is registered while the initial Hello is sent.
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
    : m_NetHost(netHost), m_Server(server), m_port(port)
{
  m_NetHost->intercept = &TraversalClient::InterceptCallback;

  Reset();
  ReconnectToServer();
}
TraversalClient::~TraversalClient() = default;

// Host ID most recently handed to us by the server in a HelloFromServer reply.
// Only meaningful once the client has reached State::Connected.
TraversalHostId TraversalClient::GetHostID() const
{
  const TraversalHostId& id = m_HostId;
  return id;
}

// Current connection state (Connecting, Connected, Failure, ...).
TraversalClient::State TraversalClient::GetState() const
{
  const State state = m_State;
  return state;
}

// Reason recorded by the most recent OnFailure() call.
TraversalClient::FailureReason TraversalClient::GetFailureReason() const
{
  const FailureReason reason = m_FailureReason;
  return reason;
}
// Resolves the configured server name into m_ServerAddress, moves to State::Connecting and
// sends a HelloFromClient request carrying our protocol version.
// If name resolution fails, the client enters State::Failure with FailureReason::BadHost and
// nothing is sent.
void TraversalClient::ReconnectToServer()
{
  // enet_address_set_host reports failure with a non-zero return value.
  const bool resolve_failed = enet_address_set_host(&m_ServerAddress, m_Server.c_str()) != 0;
  if (resolve_failed)
  {
    OnFailure(FailureReason::BadHost);
    return;
  }

  m_ServerAddress.port = m_port;
  m_State = State::Connecting;

  TraversalPacket hello{};
  hello.type = TraversalPacketType::HelloFromClient;
  hello.helloFromClient.protoVersion = TraversalProtoVersion;
  SendTraversalPacket(hello);

  if (m_Client != nullptr)
    m_Client->OnTraversalStateChanged();
}
// Converts an address received from the traversal server into an ENetAddress.
//
// The result starts fully zeroed, so every field is deterministic. IPv6 is not supported yet:
// such addresses come back with port 0, which HandleServerPacket treats as "invalid".
// For IPv4 the first address word is copied as-is and the port is converted with ntohs.
static ENetAddress MakeENetAddress(const TraversalInetAddress& address)
{
  ENetAddress result{};
  if (!address.isIPV6)
  {
    result.host = address.address[0];
    result.port = ntohs(address.port);
  }
  return result;
}
void TraversalClient::ConnectToClient(const std::string& host)
{
if (host.size() > sizeof(TraversalHostId))
{
PanicAlertFmt("Host too long");
return;
}
TraversalPacket packet = {};
packet.type = TraversalPacketType::ConnectPlease;
memcpy(packet.connectPlease.hostId.data(), host.c_str(), host.size());
m_ConnectRequestId = SendTraversalPacket(packet);
m_PendingConnect = true;
}
// Checks whether a received datagram came from the traversal server and, if so, handles it.
//
// Returns true only when the datagram came from m_ServerAddress and is large enough to hold a
// TraversalPacket; the packet has then been processed by HandleServerPacket. Datagrams from
// other senders, and too-short server datagrams (which are logged), return false.
bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
{
  const bool from_server =
      from->host == m_ServerAddress.host && from->port == m_ServerAddress.port;
  if (!from_server)
    return false;

  if (size < sizeof(TraversalPacket))
  {
    ERROR_LOG_FMT(NETPLAY, "Received too-short traversal packet.");
    return false;
  }

  // Copy the bytes into a real TraversalPacket object instead of reinterpreting the raw byte
  // buffer as a TraversalPacket*, which would break strict-aliasing rules.
  TraversalPacket packet{};
  std::memcpy(&packet, data, sizeof(packet));
  HandleServerPacket(&packet);
  return true;
}
//--Temporary until more of the old netplay branch is moved over
void TraversalClient::Update()
{
ENetEvent netEvent;
if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
{
switch (netEvent.type)
{
case ENET_EVENT_TYPE_RECEIVE:
TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
enet_packet_destroy(netEvent.packet);
break;
default:
break;
}
}
HandleResends();
}
// Processes one packet received from the traversal server.
//
// Every packet type except Ack is answered with an Ack carrying the same requestId. Its `ok`
// field is 0 only when a PleaseSendPacket request could not be honoured (IPv6 target).
// If sending that Ack fails, the client enters State::Failure (SocketSendError).
// NOTE(review): assumes `packet` points at a complete TraversalPacket; TestPacket checks the
// size before calling this.
void TraversalClient::HandleServerPacket(TraversalPacket* packet)
{
  // Status reported back to the server in our Ack.
  u8 ok = 1;
  switch (packet->type)
  {
  case TraversalPacketType::Ack:
    // The server acknowledged (or rejected) one of our outstanding requests.
    if (!packet->ack.ok)
    {
      OnFailure(FailureReason::ServerForgotAboutUs);
      break;
    }
    // Drop the acknowledged request so HandleResends stops retransmitting it.
    for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
    {
      if (it->packet.requestId == packet->requestId)
      {
        m_OutgoingTraversalPackets.erase(it);
        break;
      }
    }
    break;
  case TraversalPacketType::HelloFromServer:
    // Only meaningful while a Hello of ours is in flight; stale replies are ignored (but still acked).
    if (!IsConnecting())
      break;
    if (!packet->helloFromServer.ok)
    {
      OnFailure(FailureReason::VersionTooOld);
      break;
    }
    m_HostId = packet->helloFromServer.yourHostId;
    m_State = State::Connected;
    if (m_Client)
      m_Client->OnTraversalStateChanged();
    break;
  case TraversalPacketType::PleaseSendPacket:
  {
    // The server asks us to send a datagram to the given address, presumably so that a NAT
    // mapping toward that peer is opened (TODO confirm). The payload is not otherwise used here.
    // security is overrated.
    ENetAddress addr = MakeENetAddress(packet->pleaseSendPacket.address);
    if (addr.port != 0)
    {
      char message[] = "Hello from Dolphin Netplay...";
      ENetBuffer buf;
      buf.data = message;
      // Exclude the trailing NUL terminator from the datagram.
      buf.dataLength = sizeof(message) - 1;
      enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
    }
    else
    {
      // invalid IPV6
      // MakeENetAddress yields port 0 for IPv6; report failure in our Ack.
      ok = 0;
    }
    break;
  }
  case TraversalPacketType::ConnectReady:
  case TraversalPacketType::ConnectFailed:
  {
    // Reply to our ConnectPlease. Ignore it unless it matches the request we are waiting on.
    // NOTE(review): requestId is read through connectReady for both types, which relies on
    // connectReady and connectFailed starting with the same requestId member — confirm in
    // TraversalProto.h.
    if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
      break;
    m_PendingConnect = false;
    if (!m_Client)
      break;
    if (packet->type == TraversalPacketType::ConnectReady)
      m_Client->OnConnectReady(MakeENetAddress(packet->connectReady.address));
    else
      m_Client->OnConnectFailed(packet->connectFailed.reason);
    break;
  }
  default:
    WARN_LOG_FMT(NETPLAY, "Received unknown packet with type {}", packet->type);
    break;
  }
  // Acknowledge everything except Acks themselves.
  if (packet->type != TraversalPacketType::Ack)
  {
    TraversalPacket ack = {};
    ack.type = TraversalPacketType::Ack;
    ack.requestId = packet->requestId;
    ack.ack.ok = ok;
    ENetBuffer buf;
    buf.data = &ack;
    buf.dataLength = sizeof(ack);
    if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
      OnFailure(FailureReason::SocketSendError);
  }
}
// Moves the client into State::Failure, records why, and notifies the registered client
// (if any) of the state change.
void TraversalClient::OnFailure(FailureReason reason)
{
  m_FailureReason = reason;
  m_State = State::Failure;

  if (m_Client != nullptr)
    m_Client->OnTraversalStateChanged();
}
// Transmits (or retransmits) a queued request to the traversal server.
// Stamps the send time and bumps the attempt counter before sending. A socket send error puts
// the client into State::Failure (SocketSendError).
void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
{
  info->sendTime = enet_time_get();
  ++info->tries;

  ENetBuffer buffer;
  buffer.data = &info->packet;
  buffer.dataLength = sizeof(info->packet);

  const int send_result = enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buffer, 1);
  if (send_result == -1)
    OnFailure(FailureReason::SocketSendError);
}
void TraversalClient::HandleResends()
{
const u32 now = enet_time_get();
for (auto& tpi : m_OutgoingTraversalPackets)
{
if (now - tpi.sendTime >= (u32)(300 * tpi.tries))
{
if (tpi.tries >= 5)
{
OnFailure(FailureReason::ResendTimeout);
m_OutgoingTraversalPackets.clear();
break;
}
else
{
ResendPacket(&tpi);
}
}
}
HandlePing();
}
// While connected, sends a Ping carrying our host ID to the server at most once every 500 ms.
void TraversalClient::HandlePing()
{
  const u32 now = enet_time_get();
  // Keep the IsConnected() check first so m_PingTime is only consulted when connected.
  if (!IsConnected() || now - m_PingTime < 500)
    return;

  TraversalPacket ping{};
  ping.type = TraversalPacketType::Ping;
  ping.ping.hostId = m_HostId;
  SendTraversalPacket(ping);

  m_PingTime = now;
}
// Queues `packet` for delivery to the traversal server and sends its first copy right away.
//
// The queued copy gets a freshly generated random request ID; the caller's packet is not
// modified. HandleResends retransmits the request until an Ack with that ID removes it from
// the queue, or until the retry limit is hit.
// Returns the request ID so callers can match the server's reply.
TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
{
  // Value-initialize so no member (e.g. sendTime, which is only set later by ResendPacket) is
  // copied into the queue while still indeterminate.
  OutgoingTraversalPacketInfo info{};
  info.packet = packet;
  info.packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
  info.tries = 0;

  m_OutgoingTraversalPackets.push_back(info);
  ResendPacket(&m_OutgoingTraversalPackets.back());
  return info.packet.requestId;
}
// Detaches any registered client and forgets a pending ConnectToClient request.
void TraversalClient::Reset()
{
  m_Client = nullptr;
  m_PendingConnect = false;
}
// ENet intercept hook installed on the host by the TraversalClient constructor.
//
// Returns 1 (datagram consumed) when either:
// - the global traversal client recognizes the datagram as coming from the traversal server
//   (see TestPacket), or
// - the datagram is a single zero byte; presumably a probe/keep-alive from a peer that ENet
//   itself should not see (TODO confirm).
// Otherwise returns 0 and ENet processes the datagram normally.
int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
{
  // Guard against the global client having been released while the host is still serviced.
  TraversalClient* const traversal_client = g_TraversalClient.get();
  const bool handled_by_traversal =
      traversal_client != nullptr &&
      traversal_client->TestPacket(host->receivedData, host->receivedDataLength,
                                   &host->receivedAddress);
  const bool is_single_zero_byte = host->receivedDataLength == 1 && host->receivedData[0] == 0;

  if (handled_by_traversal || is_single_zero_byte)
  {
    // 42 is not a real ENetEventType; presumably chosen so callers' event switches ignore it.
    event->type = static_cast<ENetEventType>(42);
    return 1;
  }
  return 0;
}
std::unique_ptr<TraversalClient> g_TraversalClient;
std::unique_ptr<ENetHost> g_MainNetHost;

// Parameters used the last time EnsureTraversalClient (re)created the client and host.
// They are compared on every call to decide whether the existing pair can be reused.
// The listen port is stored exactly as requested (not the port the host actually bound), so a
// request for an explicit port can be told apart from one that let the system pick.
static std::string g_OldServer;
static u16 g_OldServerPort = 0;
static u16 g_OldListenPort = 0;
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 listen_port)
{
if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
server_port != g_OldServerPort || listen_port != g_OldListenPort)
{
g_OldServer = server;
g_OldServerPort = server_port;
g_OldListenPort = listen_port;
ENetAddress addr = {ENET_HOST_ANY, listen_port};
ENetHost* host = enet_host_create(&addr, // address
50, // peerCount
NetPlay::CHANNEL_COUNT, // channelLimit
0, // incomingBandwidth
0); // outgoingBandwidth
if (!host)
{
g_MainNetHost.reset();
return false;
}
g_MainNetHost.reset(host);
g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, server_port));
}
return true;
}
// Destroys the global traversal client and then its ENet host.
// Does nothing when no traversal client exists.
void ReleaseTraversalClient()
{
  if (g_TraversalClient)
  {
    g_TraversalClient.reset();
    g_MainNetHost.reset();
  }
}