using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Util;
using static System.Threading.Timeout;
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
namespace Npgsql.Internal;
///
/// A buffer used by Npgsql to write data to the socket efficiently.
/// Provides methods which encode different values types and tracks the current position.
///
public sealed partial class NpgsqlWriteBuffer : IDisposable
{
#region Fields and Properties
internal readonly NpgsqlConnector Connector;
internal Stream Underlying { private get; set; }
readonly Socket? _underlyingSocket;
readonly ResettableCancellationTokenSource _timeoutCts;
///
/// Timeout for sync and async writes
///
internal TimeSpan Timeout
{
get => _timeoutCts.Timeout;
set
{
if (_timeoutCts.Timeout != value)
{
Debug.Assert(_underlyingSocket != null);
if (value > TimeSpan.Zero)
{
_underlyingSocket.SendTimeout = (int)value.TotalMilliseconds;
_timeoutCts.Timeout = value;
}
else
{
_underlyingSocket.SendTimeout = -1;
_timeoutCts.Timeout = InfiniteTimeSpan;
}
}
}
}
///
/// The total byte length of the buffer.
///
internal int Size { get; private set; }
bool _copyMode;
internal Encoding TextEncoding { get; }
public int WriteSpaceLeft => Size - WritePosition;
internal readonly byte[] Buffer;
readonly Encoder _textEncoder;
internal int WritePosition;
ParameterStream? _parameterStream;
bool _disposed;
///
/// The minimum buffer size possible.
///
internal const int MinimumSize = 4096;
internal const int DefaultSize = 8192;
#endregion
#region Constructors
internal NpgsqlWriteBuffer(
NpgsqlConnector connector,
Stream stream,
Socket? socket,
int size,
Encoding textEncoding)
{
if (size < MinimumSize)
throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize);
Connector = connector;
Underlying = stream;
_underlyingSocket = socket;
_timeoutCts = new ResettableCancellationTokenSource();
Buffer = new byte[size];
Size = size;
TextEncoding = textEncoding;
_textEncoder = TextEncoding.GetEncoder();
}
#endregion
#region I/O
public async Task Flush(bool async, CancellationToken cancellationToken = default)
{
if (_copyMode)
{
// In copy mode, we write CopyData messages. The message code has already been
// written to the beginning of the buffer, but we need to go back and write the
// length.
if (WritePosition == 1)
return;
var pos = WritePosition;
WritePosition = 1;
WriteInt32(pos - 1);
WritePosition = pos;
} else if (WritePosition == 0)
return;
var finalCt = cancellationToken;
if (async && Timeout > TimeSpan.Zero)
finalCt = _timeoutCts.Start(cancellationToken);
try
{
if (async)
{
await Underlying.WriteAsync(Buffer, 0, WritePosition, finalCt);
await Underlying.FlushAsync(finalCt);
_timeoutCts.Stop();
}
else
{
Underlying.Write(Buffer, 0, WritePosition);
Underlying.Flush();
}
}
catch (Exception e)
{
// Stopping twice (in case the previous Stop() call succeeded) doesn't hurt.
// Not stopping will cause an assertion failure in debug mode when we call Start() the next time.
// We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained
// _timeoutCts
_timeoutCts.Stop();
switch (e)
{
// User requested the cancellation
case OperationCanceledException _ when (cancellationToken.IsCancellationRequested):
throw Connector.Break(e);
// Read timeout
case OperationCanceledException _:
// Note that mono throws SocketException with the wrong error (see #1330)
case IOException _ when (e.InnerException as SocketException)?.SocketErrorCode ==
(Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock):
Debug.Assert(e is OperationCanceledException ? async : !async);
throw Connector.Break(new NpgsqlException("Exception while writing to stream", new TimeoutException("Timeout during writing attempt")));
}
throw Connector.Break(new NpgsqlException("Exception while writing to stream", e));
}
NpgsqlEventSource.Log.BytesWritten(WritePosition);
//NpgsqlEventSource.Log.RequestFailed();
WritePosition = 0;
if (_copyMode)
WriteCopyDataHeader();
}
internal void Flush() => Flush(false).GetAwaiter().GetResult();
#endregion
#region Direct write
internal void DirectWrite(ReadOnlySpan buffer)
{
Flush();
if (_copyMode)
{
// Flush has already written the CopyData header for us, but write the CopyData
// header to the socket with the write length before we can start writing the data directly.
Debug.Assert(WritePosition == 5);
WritePosition = 1;
WriteInt32(buffer.Length + 4);
WritePosition = 5;
_copyMode = false;
Flush();
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
Debug.Assert(WritePosition == 0);
try
{
Underlying.Write(buffer);
}
catch (Exception e)
{
throw Connector.Break(new NpgsqlException("Exception while writing to stream", e));
}
}
internal async Task DirectWrite(ReadOnlyMemory memory, bool async, CancellationToken cancellationToken = default)
{
await Flush(async, cancellationToken);
if (_copyMode)
{
// Flush has already written the CopyData header for us, but write the CopyData
// header to the socket with the write length before we can start writing the data directly.
Debug.Assert(WritePosition == 5);
WritePosition = 1;
WriteInt32(memory.Length + 4);
WritePosition = 5;
_copyMode = false;
await Flush(async, cancellationToken);
_copyMode = true;
WriteCopyDataHeader(); // And ready the buffer after the direct write completes
}
else
Debug.Assert(WritePosition == 0);
try
{
if (async)
await Underlying.WriteAsync(memory, cancellationToken);
else
Underlying.Write(memory.Span);
}
catch (Exception e)
{
throw Connector.Break(new NpgsqlException("Exception while writing to stream", e));
}
}
#endregion Direct write
#region Write Simple
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteSByte(sbyte value) => Write(value);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteByte(byte value) => Write(value);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void WriteInt16(int value)
=> WriteInt16((short)value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteInt16(short value)
=> WriteInt16(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteInt16(short value, bool littleEndian)
=> Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteUInt16(ushort value)
=> WriteUInt16(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteUInt16(ushort value, bool littleEndian)
=> Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteInt32(int value)
=> WriteInt32(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteInt32(int value, bool littleEndian)
=> Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteUInt32(uint value)
=> WriteUInt32(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteUInt32(uint value, bool littleEndian)
=> Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteInt64(long value)
=> WriteInt64(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteInt64(long value, bool littleEndian)
=> Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteUInt64(ulong value)
=> WriteUInt64(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteUInt64(ulong value, bool littleEndian)
=> Write(littleEndian == BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteSingle(float value)
=> WriteSingle(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteSingle(float value, bool littleEndian)
=> WriteInt32(Unsafe.As(ref value), littleEndian);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteDouble(double value)
=> WriteDouble(value, false);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteDouble(double value, bool littleEndian)
=> WriteInt64(Unsafe.As(ref value), littleEndian);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
void Write(T value)
{
if (Unsafe.SizeOf() > WriteSpaceLeft)
ThrowNotSpaceLeft();
Unsafe.WriteUnaligned(ref Buffer[WritePosition], value);
WritePosition += Unsafe.SizeOf();
}
[MethodImpl(MethodImplOptions.NoInlining)]
static void ThrowNotSpaceLeft()
=> throw new InvalidOperationException("There is not enough space left in the buffer.");
public Task WriteString(string s, int byteLen, bool async, CancellationToken cancellationToken = default)
=> WriteString(s, s.Length, byteLen, async, cancellationToken);
public Task WriteString(string s, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default)
{
if (byteLen <= WriteSpaceLeft)
{
WriteString(s, charLen);
return Task.CompletedTask;
}
return WriteStringLong(this, async, s, charLen, byteLen, cancellationToken);
static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s, int charLen, int byteLen, CancellationToken cancellationToken)
{
Debug.Assert(byteLen > buffer.WriteSpaceLeft);
if (byteLen <= buffer.Size)
{
// String can fit entirely in an empty buffer. Flush and retry rather than
// going into the partial writing flow below (which requires ToCharArray())
await buffer.Flush(async, cancellationToken);
buffer.WriteString(s, charLen);
}
else
{
var charPos = 0;
while (true)
{
buffer.WriteStringChunked(s, charPos, charLen - charPos, true, out var charsUsed, out var completed);
if (completed)
break;
await buffer.Flush(async, cancellationToken);
charPos += charsUsed;
}
}
}
}
internal Task WriteChars(char[] chars, int offset, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default)
{
if (byteLen <= WriteSpaceLeft)
{
WriteChars(chars, offset, charLen);
return Task.CompletedTask;
}
return WriteCharsLong(this, async, chars, offset, charLen, byteLen, cancellationToken);
static async Task WriteCharsLong(NpgsqlWriteBuffer buffer, bool async, char[] chars, int offset, int charLen, int byteLen, CancellationToken cancellationToken)
{
Debug.Assert(byteLen > buffer.WriteSpaceLeft);
if (byteLen <= buffer.Size)
{
// String can fit entirely in an empty buffer. Flush and retry rather than
// going into the partial writing flow below (which requires ToCharArray())
await buffer.Flush(async, cancellationToken);
buffer.WriteChars(chars, offset, charLen);
}
else
{
var charPos = 0;
while (true)
{
buffer.WriteStringChunked(chars, charPos + offset, charLen - charPos, true, out var charsUsed, out var completed);
if (completed)
break;
await buffer.Flush(async, cancellationToken);
charPos += charsUsed;
}
}
}
}
public void WriteString(string s, int len = 0)
{
Debug.Assert(TextEncoding.GetByteCount(s) <= WriteSpaceLeft);
WritePosition += TextEncoding.GetBytes(s, 0, len == 0 ? s.Length : len, Buffer, WritePosition);
}
internal void WriteChars(char[] chars, int offset, int len)
{
var charCount = len == 0 ? chars.Length : len;
Debug.Assert(TextEncoding.GetByteCount(chars, 0, charCount) <= WriteSpaceLeft);
WritePosition += TextEncoding.GetBytes(chars, offset, charCount, Buffer, WritePosition);
}
public void WriteBytes(ReadOnlySpan buf)
{
Debug.Assert(buf.Length <= WriteSpaceLeft);
buf.CopyTo(new Span(Buffer, WritePosition, Buffer.Length - WritePosition));
WritePosition += buf.Length;
}
public void WriteBytes(byte[] buf, int offset, int count)
=> WriteBytes(new ReadOnlySpan(buf, offset, count));
public Task WriteBytesRaw(byte[] bytes, bool async, CancellationToken cancellationToken = default)
{
if (bytes.Length <= WriteSpaceLeft)
{
WriteBytes(bytes);
return Task.CompletedTask;
}
return WriteBytesLong(this, async, bytes, cancellationToken);
static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, byte[] bytes, CancellationToken cancellationToken)
{
if (bytes.Length <= buffer.Size)
{
// value can fit entirely in an empty buffer. Flush and retry rather than
// going into the partial writing flow below
await buffer.Flush(async, cancellationToken);
buffer.WriteBytes(bytes);
}
else
{
var remaining = bytes.Length;
do
{
if (buffer.WriteSpaceLeft == 0)
await buffer.Flush(async, cancellationToken);
var writeLen = Math.Min(remaining, buffer.WriteSpaceLeft);
var offset = bytes.Length - remaining;
buffer.WriteBytes(bytes, offset, writeLen);
remaining -= writeLen;
}
while (remaining > 0);
}
}
}
public void WriteNullTerminatedString(string s)
{
Debug.Assert(s.All(c => c < 128), "Method only supports ASCII strings");
Debug.Assert(WriteSpaceLeft >= s.Length + 1);
WritePosition += Encoding.ASCII.GetBytes(s, 0, s.Length, Buffer, WritePosition);
WriteByte(0);
}
#endregion
#region Write Complex
public Stream GetStream()
{
if (_parameterStream == null)
_parameterStream = new ParameterStream(this);
_parameterStream.Init();
return _parameterStream;
}
internal void WriteStringChunked(char[] chars, int charIndex, int charCount,
bool flush, out int charsUsed, out bool completed)
{
if (WriteSpaceLeft < _textEncoder.GetByteCount(chars, charIndex, char.IsHighSurrogate(chars[charIndex]) ? 2 : 1, flush: false))
{
charsUsed = 0;
completed = false;
return;
}
_textEncoder.Convert(chars, charIndex, charCount, Buffer, WritePosition, WriteSpaceLeft,
flush, out charsUsed, out var bytesUsed, out completed);
WritePosition += bytesUsed;
}
internal unsafe void WriteStringChunked(string s, int charIndex, int charCount,
bool flush, out int charsUsed, out bool completed)
{
int bytesUsed;
fixed (char* sPtr = s)
fixed (byte* bufPtr = Buffer)
{
if (WriteSpaceLeft < _textEncoder.GetByteCount(sPtr + charIndex, char.IsHighSurrogate(*(sPtr + charIndex)) ? 2 : 1, flush: false))
{
charsUsed = 0;
completed = false;
return;
}
_textEncoder.Convert(sPtr + charIndex, charCount, bufPtr + WritePosition, WriteSpaceLeft,
flush, out charsUsed, out bytesUsed, out completed);
}
WritePosition += bytesUsed;
}
#endregion
#region Copy
internal void StartCopyMode()
{
_copyMode = true;
Size -= 5;
WriteCopyDataHeader();
}
internal void EndCopyMode()
{
// EndCopyMode is usually called after a Flush which ended the last CopyData message.
// That Flush also wrote the header for another CopyData which we clear here.
_copyMode = false;
Size += 5;
Clear();
}
void WriteCopyDataHeader()
{
Debug.Assert(_copyMode);
Debug.Assert(WritePosition == 0);
WriteByte(FrontendMessageCode.CopyData);
// Leave space for the message length
WriteInt32(0);
}
#endregion
#region Dispose
public void Dispose()
{
if (_disposed)
return;
_timeoutCts.Dispose();
_disposed = true;
}
#endregion
#region Misc
internal void Clear()
{
WritePosition = 0;
}
///
/// Returns all contents currently written to the buffer (but not flushed).
/// Useful for pre-generating messages.
///
internal byte[] GetContents()
{
var buf = new byte[WritePosition];
Array.Copy(Buffer, buf, WritePosition);
return buf;
}
#endregion
}