Rewrite TCP socket reading using bytes::vector.

I hope this fixes a strange assertion violation.
This commit is contained in:
John Preston 2018-07-11 17:00:06 +03:00
parent 951634a717
commit 556f75ef6c
2 changed files with 105 additions and 95 deletions

View file

@ -19,8 +19,9 @@ namespace MTP {
namespace internal { namespace internal {
namespace { namespace {
constexpr auto kPacketSizeMax = 0x01000000 * sizeof(mtpPrime); constexpr auto kPacketSizeMax = int(0x01000000 * sizeof(mtpPrime));
constexpr auto kFullConnectionTimeout = 8 * TimeMs(1000); constexpr auto kFullConnectionTimeout = 8 * TimeMs(1000);
constexpr auto kSmallBufferSize = 256 * 1024;
using ErrorSignal = void(QTcpSocket::*)(QAbstractSocket::SocketError); using ErrorSignal = void(QTcpSocket::*)(QAbstractSocket::SocketError);
const auto QTcpSocket_error = ErrorSignal(&QAbstractSocket::error); const auto QTcpSocket_error = ErrorSignal(&QAbstractSocket::error);
@ -38,9 +39,9 @@ public:
virtual void prepareKey(bytes::span key, bytes::const_span source) = 0; virtual void prepareKey(bytes::span key, bytes::const_span source) = 0;
virtual bytes::span finalizePacket(mtpBuffer &buffer) = 0; virtual bytes::span finalizePacket(mtpBuffer &buffer) = 0;
static constexpr auto kUnknownSize = uint32(-1); static constexpr auto kUnknownSize = -1;
static constexpr auto kInvalidSize = uint32(-2); static constexpr auto kInvalidSize = -2;
virtual uint32 readPacketLength(bytes::const_span bytes) const = 0; virtual int readPacketLength(bytes::const_span bytes) const = 0;
virtual bytes::const_span readPacket(bytes::const_span bytes) const = 0; virtual bytes::const_span readPacket(bytes::const_span bytes) const = 0;
virtual ~Protocol() = default; virtual ~Protocol() = default;
@ -61,7 +62,7 @@ public:
void prepareKey(bytes::span key, bytes::const_span source) override; void prepareKey(bytes::span key, bytes::const_span source) override;
bytes::span finalizePacket(mtpBuffer &buffer) override; bytes::span finalizePacket(mtpBuffer &buffer) override;
uint32 readPacketLength(bytes::const_span bytes) const override; int readPacketLength(bytes::const_span bytes) const override;
bytes::const_span readPacket(bytes::const_span bytes) const override; bytes::const_span readPacket(bytes::const_span bytes) const override;
}; };
@ -105,7 +106,7 @@ bytes::span TcpConnection::Protocol::Version0::finalizePacket(
return bytes::make_span(buffer).subspan(8 - added, added + bytesSize); return bytes::make_span(buffer).subspan(8 - added, added + bytesSize);
} }
uint32 TcpConnection::Protocol::Version0::readPacketLength( int TcpConnection::Protocol::Version0::readPacketLength(
bytes::const_span bytes) const { bytes::const_span bytes) const {
if (bytes.empty()) { if (bytes.empty()) {
return kUnknownSize; return kUnknownSize;
@ -118,10 +119,10 @@ uint32 TcpConnection::Protocol::Version0::readPacketLength(
const auto ints = static_cast<uint32>(bytes[1]) const auto ints = static_cast<uint32>(bytes[1])
| (static_cast<uint32>(bytes[2]) << 8) | (static_cast<uint32>(bytes[2]) << 8)
| (static_cast<uint32>(bytes[3]) << 16); | (static_cast<uint32>(bytes[3]) << 16);
return (ints >= 0x7F) ? ((ints << 2) + 4) : kInvalidSize; return (ints >= 0x7F) ? (int(ints << 2) + 4) : kInvalidSize;
} else if (first > 0 && first < 0x7F) { } else if (first > 0 && first < 0x7F) {
const auto ints = uint32(first); const auto ints = uint32(first);
return (ints << 2) + 1; return int(ints << 2) + 1;
} }
return kInvalidSize; return kInvalidSize;
} }
@ -172,7 +173,7 @@ public:
bytes::span finalizePacket(mtpBuffer &buffer) override; bytes::span finalizePacket(mtpBuffer &buffer) override;
uint32 readPacketLength(bytes::const_span bytes) const override; int readPacketLength(bytes::const_span bytes) const override;
bytes::const_span readPacket(bytes::const_span bytes) const override; bytes::const_span readPacket(bytes::const_span bytes) const override;
}; };
@ -200,13 +201,15 @@ bytes::span TcpConnection::Protocol::VersionD::finalizePacket(
return bytes::make_span(buffer).subspan(4, 4 + bytesSize); return bytes::make_span(buffer).subspan(4, 4 + bytesSize);
} }
uint32 TcpConnection::Protocol::VersionD::readPacketLength( int TcpConnection::Protocol::VersionD::readPacketLength(
bytes::const_span bytes) const { bytes::const_span bytes) const {
if (bytes.size() < 4) { if (bytes.size() < 4) {
return kUnknownSize; return kUnknownSize;
} }
const auto value = *reinterpret_cast<const uint32*>(bytes.data()) + 4; const auto value = *reinterpret_cast<const uint32*>(bytes.data()) + 4;
return (value >= 8 && value < kPacketSizeMax) ? value : kInvalidSize; return (value >= 8 && value < kPacketSizeMax)
? int(value)
: kInvalidSize;
} }
bytes::const_span TcpConnection::Protocol::VersionD::readPacket( bytes::const_span TcpConnection::Protocol::VersionD::readPacket(
@ -234,7 +237,6 @@ auto TcpConnection::Protocol::Create(bytes::vector &&secret)
TcpConnection::TcpConnection(QThread *thread, const ProxyData &proxy) TcpConnection::TcpConnection(QThread *thread, const ProxyData &proxy)
: AbstractConnection(thread, proxy) : AbstractConnection(thread, proxy)
, _currentPosition(reinterpret_cast<char*>(_shortBuffer))
, _checkNonce(rand_value<MTPint128>()) { , _checkNonce(rand_value<MTPint128>()) {
_socket.moveToThread(thread); _socket.moveToThread(thread);
_socket.setProxy(ToNetworkProxy(proxy)); _socket.setProxy(ToNetworkProxy(proxy));
@ -265,6 +267,8 @@ ConnectionPointer TcpConnection::clone(const ProxyData &proxy) {
} }
void TcpConnection::socketRead() { void TcpConnection::socketRead() {
Expects(_leftBytes > 0 || !_usingLargeBuffer);
if (_socket.state() != QAbstractSocket::ConnectedState) { if (_socket.state() != QAbstractSocket::ConnectedState) {
LOG(("MTP error: " LOG(("MTP error: "
"socket not connected in socketRead(), state: %1" "socket not connected in socketRead(), state: %1"
@ -273,93 +277,101 @@ void TcpConnection::socketRead() {
return; return;
} }
if (_smallBuffer.empty()) {
_smallBuffer.resize(kSmallBufferSize);
}
do { do {
uint32 toRead = _packetLeft const auto readLimit = (_leftBytes > 0)
? _packetLeft ? _leftBytes
: (_readingToShort : (kSmallBufferSize - _offsetBytes - _readBytes);
? (kShortBufferSize * sizeof(mtpPrime) - _packetRead) Assert(readLimit > 0);
: 4);
if (_readingToShort) {
if (_currentPosition + toRead > ((char*)_shortBuffer) + kShortBufferSize * sizeof(mtpPrime)) {
_longBuffer.resize(((_packetRead + toRead) >> 2) + 1);
memcpy(&_longBuffer[0], _shortBuffer, _packetRead);
_currentPosition = ((char*)&_longBuffer[0]) + _packetRead;
_readingToShort = false;
}
} else {
if (_longBuffer.size() * sizeof(mtpPrime) < _packetRead + toRead) {
_longBuffer.resize(((_packetRead + toRead) >> 2) + 1);
_currentPosition = ((char*)&_longBuffer[0]) + _packetRead;
}
}
int32 bytes = (int32)_socket.read(_currentPosition, toRead);
if (bytes > 0) {
aesCtrEncrypt(
bytes::make_span(_currentPosition, bytes),
_receiveKey,
&_receiveState);
TCP_LOG(("TCP Info: read %1 bytes").arg(bytes));
_packetRead += bytes; auto &buffer = _usingLargeBuffer ? _largeBuffer : _smallBuffer;
_currentPosition += bytes; const auto full = bytes::make_span(buffer).subspan(_offsetBytes);
if (_packetLeft) { const auto free = full.subspan(_readBytes);
_packetLeft -= bytes; Assert(free.size() >= readLimit);
if (!_packetLeft) {
socketPacket(bytes::make_span( const auto readCount = _socket.read(
_currentPosition - _packetRead, reinterpret_cast<char*>(free.data()),
_packetRead)); readLimit);
_currentPosition = (char*)_shortBuffer; if (readCount > 0) {
_packetRead = _packetLeft = 0; const auto read = free.subspan(0, readCount);
_readingToShort = true; aesCtrEncrypt(read, _receiveKey, &_receiveState);
_longBuffer.clear(); TCP_LOG(("TCP Info: read %1 bytes").arg(readCount));
_readBytes += readCount;
if (_leftBytes > 0) {
Assert(readCount <= _leftBytes);
_leftBytes -= readCount;
if (!_leftBytes) {
socketPacket(full.subspan(0, _readBytes));
_usingLargeBuffer = false;
_largeBuffer.clear();
_offsetBytes = _readBytes = 0;
} else { } else {
TCP_LOG(("TCP Info: not enough %1 for packet! read %2" TCP_LOG(("TCP Info: not enough %1 for packet! read %2"
).arg(_packetLeft ).arg(_leftBytes
).arg(_packetRead)); ).arg(_readBytes));
emit receivedSome(); emit receivedSome();
} }
} else { } else {
bool move = false; auto available = full.subspan(0, _readBytes);
while (_packetRead >= 4) { while (_readBytes > 0) {
const auto packetSize = _protocol->readPacketLength( const auto packetSize = _protocol->readPacketLength(
bytes::make_span( available);
_currentPosition - _packetRead, if (packetSize == Protocol::kUnknownSize) {
_packetRead)); // Not enough bytes yet.
if (packetSize == Protocol::kUnknownSize break;
|| packetSize == Protocol::kInvalidSize) { } else if (packetSize <= 0) {
LOG(("TCP Error: packet size = %1").arg(packetSize)); LOG(("TCP Error: bad packet size in 4 bytes: %1"
).arg(packetSize));
emit error(kErrorCodeOther); emit error(kErrorCodeOther);
return; return;
} } else if (available.size() >= packetSize) {
if (_packetRead >= packetSize) { socketPacket(available.subspan(0, packetSize));
socketPacket(bytes::make_span( available = available.subspan(packetSize);
_currentPosition - _packetRead, _offsetBytes += packetSize;
packetSize)); _readBytes -= packetSize;
_packetRead -= packetSize;
_packetLeft = 0;
move = true;
} else { } else {
_packetLeft = packetSize - _packetRead; _leftBytes = packetSize - available.size();
TCP_LOG(("TCP Info: not enough %1 for packet! size %2 read %3").arg(_packetLeft).arg(packetSize).arg(_packetRead));
// If the next packet won't fit in the buffer.
const auto full = bytes::make_span(buffer).subspan(
_offsetBytes);
if (full.size() < packetSize) {
const auto read = full.subspan(0, _readBytes);
if (packetSize <= _smallBuffer.size()) {
if (_usingLargeBuffer) {
bytes::copy(_smallBuffer, read);
_usingLargeBuffer = false;
_largeBuffer.clear();
} else {
bytes::move(_smallBuffer, read);
}
} else if (packetSize <= _largeBuffer.size()) {
Assert(_usingLargeBuffer);
bytes::move(_largeBuffer, read);
} else {
auto enough = bytes::vector(packetSize);
bytes::copy(enough, read);
_largeBuffer = std::move(enough);
_usingLargeBuffer = true;
}
_offsetBytes = 0;
}
TCP_LOG(("TCP Info: not enough %1 for packet! "
"full size %2 read %3"
).arg(_leftBytes
).arg(packetSize
).arg(available.size()));
emit receivedSome(); emit receivedSome();
break; break;
} }
} }
if (move) {
if (!_packetRead) {
_currentPosition = (char*)_shortBuffer;
_readingToShort = true;
_longBuffer.clear();
} else if (!_readingToShort && _packetRead < kShortBufferSize * sizeof(mtpPrime)) {
memcpy(_shortBuffer, _currentPosition - _packetRead, _packetRead);
_currentPosition = (char*)_shortBuffer + _packetRead;
_readingToShort = true;
_longBuffer.clear();
}
}
} }
} else if (bytes < 0) { } else if (readCount < 0) {
LOG(("TCP Error: socket read return -1")); LOG(("TCP Error: socket read return %1").arg(readCount));
emit error(kErrorCodeOther); emit error(kErrorCodeOther);
return; return;
} else { } else {
@ -527,15 +539,14 @@ void TcpConnection::writeConnectionStart() {
} }
void TcpConnection::sendBuffer(mtpBuffer &&buffer) { void TcpConnection::sendBuffer(mtpBuffer &&buffer) {
if (!_packetIndex++) { if (!_connectionStarted) {
writeConnectionStart(); writeConnectionStart();
_connectionStarted = true;
} }
// buffer: 2 available int-s + data + available int. // buffer: 2 available int-s + data + available int.
const auto bytes = _protocol->finalizePacket(buffer); const auto bytes = _protocol->finalizePacket(buffer);
TCP_LOG(("TCP Info: write %1 packet %2" TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
).arg(_packetIndex
).arg(bytes.size()));
aesCtrEncrypt(bytes, _sendKey, &_sendState); aesCtrEncrypt(bytes, _sendKey, &_sendState);
_socket.write( _socket.write(
reinterpret_cast<const char*>(bytes.data()), reinterpret_cast<const char*>(bytes.data()),

View file

@ -47,7 +47,6 @@ private:
Ready, Ready,
Finished, Finished,
}; };
static constexpr auto kShortBufferSize = 65535; // Of ints, 256 kb.
void socketRead(); void socketRead();
void writeConnectionStart(); void writeConnectionStart();
@ -68,14 +67,14 @@ private:
void sendBuffer(mtpBuffer &&buffer); void sendBuffer(mtpBuffer &&buffer);
QTcpSocket _socket; QTcpSocket _socket;
uint32 _packetIndex = 0; // sent packet number bool _connectionStarted = false;
uint32 _packetRead = 0; int _offsetBytes = 0;
uint32 _packetLeft = 0; // reading from socket int _readBytes = 0;
bool _readingToShort = true; int _leftBytes = 0;
mtpBuffer _longBuffer; bytes::vector _smallBuffer;
mtpPrime _shortBuffer[kShortBufferSize]; bytes::vector _largeBuffer;
char *_currentPosition = nullptr; bool _usingLargeBuffer = false;
uchar _sendKey[CTRState::KeySize]; uchar _sendKey[CTRState::KeySize];
CTRState _sendState; CTRState _sendState;