using System;
using System.Buffers.Binary;
using System.Diagnostics;
using System.IO;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Util;
using static System.Threading.Timeout;
namespace Npgsql.Internal;
/// <summary>
/// A buffer used by Npgsql to write data to the socket efficiently.
/// Provides methods which encode different value types and tracks the current position.
/// </summary>
sealed class NpgsqlWriteBuffer : IDisposable
{
    #region Fields and Properties

    /// <summary>UTF-8 without a BOM that throws on invalid input.</summary>
    internal static readonly UTF8Encoding UTF8Encoding = new(false, true);

    /// <summary>UTF-8 without a BOM that substitutes invalid input instead of throwing.</summary>
    internal static readonly UTF8Encoding RelaxedUTF8Encoding = new(false, false);

    internal readonly NpgsqlConnector Connector;

    /// <summary>The stream that buffered data is written to on flush and direct writes.</summary>
    internal Stream Underlying { private get; set; }

    // Used only to configure Socket.SendTimeout from the Timeout setter; may be null (see constructor).
    readonly Socket? _underlyingSocket;

    /// <summary>
    /// When <see langword="true"/> (the default), <see cref="StartMessage"/> and flushes verify that the number of
    /// bytes written for a message matches the length that was declared for it.
    /// </summary>
    internal bool MessageLengthValidation { get; set; } = true;

    readonly ResettableCancellationTokenSource _timeoutCts;
    readonly MetricsReporter? _metricsReporter;

    /// <summary>
    /// Timeout for sync and async writes.
    /// </summary>
    /// <remarks>
    /// Synchronous writes are bounded by <see cref="Socket.SendTimeout"/> on the underlying socket; asynchronous
    /// <see cref="Flush(bool, CancellationToken)"/> calls are bounded by a token from the internal timeout source.
    /// Zero or negative values mean "no timeout".
    /// </remarks>
    internal TimeSpan Timeout
    {
        get => _timeoutCts.Timeout;
        set
        {
            if (_timeoutCts.Timeout == value)
                return;

            Debug.Assert(_underlyingSocket != null);
            if (value > TimeSpan.Zero)
            {
                _underlyingSocket.SendTimeout = ToSocketTimeout(value);
                _timeoutCts.Timeout = value;
            }
            else
            {
                _underlyingSocket.SendTimeout = -1;
                _timeoutCts.Timeout = InfiniteTimeSpan;
            }
        }
    }

    // Converts a strictly positive timeout into Socket.SendTimeout milliseconds.
    // The socket treats 0 as "infinite", so a sub-millisecond timeout is rounded up to 1ms.
    // Values beyond int.MaxValue ms are clamped. Otherwise an out-of-range double->int cast would decide the
    // result, and that cast is runtime-dependent (int.MinValue on .NET 8 x64, which makes the setter throw).
    static int ToSocketTimeout(TimeSpan value)
    {
        var milliseconds = value.TotalMilliseconds;
        return milliseconds >= int.MaxValue ? int.MaxValue : Math.Max(1, (int)milliseconds);
    }

    /// <summary>
    /// The usable byte length of the buffer. It equals the allocated length,
    /// except while in copy mode, where <see cref="StartCopyMode"/> reduces it by 5.
    /// </summary>
    internal int Size { get; private set; }

    bool _copyMode;

    /// <summary>The encoding used for string writes.</summary>
    internal Encoding TextEncoding { get; }

    /// <summary>Number of bytes that can still be written before the buffer must be flushed.</summary>
    public int WriteSpaceLeft => Size - WritePosition;

    // (Re)init to make sure we'll refetch from the write buffer.
    internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode = FlushMode.None)
        => _pgWriter.Init(typeCatalog, flushMode);

    internal readonly byte[] Buffer;
    readonly Encoder _textEncoder;
    internal int WritePosition;

    // Bytes of the current message already sent; starts negative to account for earlier, unflushed messages.
    int _messageBytesFlushed;

    // Declared length of the current message, or null when no message has been started.
    int? _messageLength;

    bool _disposed;
    readonly PgWriter _pgWriter;

    // The writable region of the buffer: from WritePosition up to Size.
    Span<byte> Span => Buffer.AsSpan(WritePosition, WriteSpaceLeft);

    /// <summary>
    /// The minimum buffer size possible.
    /// </summary>
    internal const int MinimumSize = 4096;

    internal const int DefaultSize = 8192;

    #endregion

    #region Constructors

    /// <summary>
    /// Creates a write buffer of <paramref name="size"/> bytes over <paramref name="stream"/>.
    /// </summary>
    /// <param name="connector">Owning connector. May be null only for pre-generated messages.</param>
    /// <param name="stream">Stream that flushed data is written to.</param>
    /// <param name="socket">Socket whose SendTimeout is driven by <see cref="Timeout"/>.</param>
    /// <param name="size">Buffer size in bytes; must be at least <see cref="MinimumSize"/>.</param>
    /// <param name="textEncoding">Encoding used for string writes.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is below <see cref="MinimumSize"/>.</exception>
    internal NpgsqlWriteBuffer(
        NpgsqlConnector? connector,
        Stream stream,
        Socket? socket,
        int size,
        Encoding textEncoding)
    {
        ArgumentOutOfRangeException.ThrowIfLessThan(size, MinimumSize);

        Connector = connector!; // TODO: Clean this up; only null when used from PregeneratedMessages, where we don't care.
        Underlying = stream;
        _underlyingSocket = socket;
        _metricsReporter = connector?.DataSource.MetricsReporter;
        _timeoutCts = new ResettableCancellationTokenSource();
        Buffer = new byte[size];
        Size = size;
        TextEncoding = textEncoding;
        _textEncoder = TextEncoding.GetEncoder();
        _pgWriter = new PgWriter(new NpgsqlBufferWriter(this));
    }

    #endregion

    #region I/O

    /// <summary>
    /// Writes all buffered bytes to <see cref="Underlying"/> and flushes the stream.
    /// </summary>
    /// <remarks>
    /// <para>
    /// In copy mode, the length field of the pending CopyData message (offset 1) is patched before writing.
    /// After a successful write, a fresh CopyData header is written into the emptied buffer.
    /// </para>
    /// <para>
    /// Outside copy mode, nothing is written when the buffer is empty.
    /// Any failure breaks the connector via <c>Connector.Break</c>, and the result of that call is thrown.
    /// </para>
    /// </remarks>
    /// <param name="async">Whether to use the asynchronous stream APIs.</param>
    /// <param name="cancellationToken">User cancellation token; combined with the write timeout for async writes.</param>
    public async Task Flush(bool async, CancellationToken cancellationToken = default)
    {
        if (_copyMode)
        {
            // In copy mode, we write CopyData messages. The message code has already been
            // written to the beginning of the buffer, but we need to go back and write the
            // length.
            // NOTE(review): WriteCopyDataHeader leaves WritePosition at 5 (1 code byte + 4 length bytes).
            // This "== 1" guard therefore looks unreachable, and an empty CopyData message may be sent.
            // Confirm whether this was meant to be "== 5".
            if (WritePosition == 1)
                return;
            var pos = WritePosition;
            WritePosition = 1;
            WriteInt32(pos - 1);
            WritePosition = pos;
        }
        else if (WritePosition == 0)
        {
            return;
        }
        else
        {
            AdvanceMessageBytesFlushed(WritePosition);
        }

        var finalCt = async && Timeout > TimeSpan.Zero
            ? _timeoutCts.Start(cancellationToken)
            : cancellationToken;

        try
        {
            if (async)
            {
                await Underlying.WriteAsync(Buffer, 0, WritePosition, finalCt).ConfigureAwait(false);
                await Underlying.FlushAsync(finalCt).ConfigureAwait(false);
                if (Timeout > TimeSpan.Zero)
                    _timeoutCts.Stop();
            }
            else
            {
                Underlying.Write(Buffer, 0, WritePosition);
                Underlying.Flush();
            }
        }
        catch (Exception ex)
        {
            // Stopping twice (in case the previous Stop() call succeeded) doesn't hurt.
            // Not stopping will cause an assertion failure in debug mode when we call Start() the next time.
            // We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained
            // _timeoutCts
            _timeoutCts.Stop();
            switch (ex)
            {
                // User requested the cancellation
                case OperationCanceledException when cancellationToken.IsCancellationRequested:
                    throw Connector.Break(ex);
                // Write timeout: an async write cancelled by the timeout token source, or a sync write that hit
                // the socket's SendTimeout.
                case OperationCanceledException:
                case IOException { InnerException: SocketException { SocketErrorCode: SocketError.TimedOut } }:
                    Debug.Assert(ex is OperationCanceledException ? async : !async);
                    throw Connector.Break(new NpgsqlException("Exception while writing to stream", new TimeoutException("Timeout during writing attempt")));
            }
            throw Connector.Break(new NpgsqlException("Exception while writing to stream", ex));
        }

        NpgsqlEventSource.Log.BytesWritten(WritePosition);
        _metricsReporter?.ReportBytesWritten(WritePosition);

        WritePosition = 0;
        if (_copyMode)
            WriteCopyDataHeader();
    }

    /// <summary>Synchronous <see cref="Flush(bool, CancellationToken)"/>; completes synchronously since <c>async</c> is false.</summary>
    internal void Flush() => Flush(false).GetAwaiter().GetResult();

    #endregion

    #region Direct write

    /// <summary>
    /// Flushes the buffer, then writes <paramref name="buffer"/> straight to <see cref="Underlying"/> without copying
    /// it into the internal buffer.
    /// </summary>
    /// <remarks>
    /// In copy mode, a CopyData header whose length covers <paramref name="buffer"/> is sent first, so the direct
    /// bytes form that message's payload. Afterwards, the internal buffer is re-armed with a new CopyData header.
    /// </remarks>
    internal void DirectWrite(ReadOnlySpan<byte> buffer)
    {
        Flush();

        if (_copyMode)
        {
            // Flush has already written the CopyData header for us, but write the CopyData
            // header to the socket with the write length before we can start writing the data directly.
            Debug.Assert(WritePosition == 5);

            WritePosition = 1;
            WriteInt32(checked(buffer.Length + 4));
            WritePosition = 5;
            _copyMode = false;
            StartMessage(5);
            Flush();
            _copyMode = true;
            WriteCopyDataHeader(); // And ready the buffer after the direct write completes
        }
        else
        {
            Debug.Assert(WritePosition == 0);
            AdvanceMessageBytesFlushed(buffer.Length);
        }

        try
        {
            Underlying.Write(buffer);
        }
        catch (Exception e)
        {
            throw Connector.Break(new NpgsqlException("Exception while writing to stream", e));
        }
    }

    /// <summary>
    /// Sync/async variant of <see cref="DirectWrite(ReadOnlySpan{byte})"/>: flushes the buffer, then writes
    /// <paramref name="memory"/> straight to <see cref="Underlying"/>.
    /// </summary>
    /// <remarks>The direct write itself uses only <paramref name="cancellationToken"/>, not the write timeout source.</remarks>
    internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, CancellationToken cancellationToken = default)
    {
        await Flush(async, cancellationToken).ConfigureAwait(false);

        if (_copyMode)
        {
            // Flush has already written the CopyData header for us, but write the CopyData
            // header to the socket with the write length before we can start writing the data directly.
            Debug.Assert(WritePosition == 5);

            WritePosition = 1;
            WriteInt32(checked(memory.Length + 4));
            WritePosition = 5;
            _copyMode = false;
            StartMessage(5);
            await Flush(async, cancellationToken).ConfigureAwait(false);
            _copyMode = true;
            WriteCopyDataHeader(); // And ready the buffer after the direct write completes
        }
        else
        {
            Debug.Assert(WritePosition == 0);
            AdvanceMessageBytesFlushed(memory.Length);
        }

        try
        {
            if (async)
                await Underlying.WriteAsync(memory, cancellationToken).ConfigureAwait(false);
            else
                Underlying.Write(memory.Span);
        }
        catch (Exception e)
        {
            throw Connector.Break(new NpgsqlException("Exception while writing to stream", e));
        }
    }

    #endregion Direct write

    #region Write Simple

    // The fixed-size writers below emit values in big-endian (network) byte order.
    // They check bounds only in DEBUG builds (see CheckBounds), so callers must ensure enough space is left.

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void WriteByte(byte value)
    {
        CheckBounds<byte>();
        Buffer[WritePosition] = value;
        WritePosition += sizeof(byte);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void WriteInt16(short value)
    {
        CheckBounds<short>();
        Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value);
        WritePosition += sizeof(short);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void WriteUInt16(ushort value)
    {
        CheckBounds<ushort>();
        Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value);
        WritePosition += sizeof(ushort);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void WriteInt32(int value)
    {
        CheckBounds<int>();
        Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value);
        WritePosition += sizeof(int);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void WriteUInt32(uint value)
    {
        CheckBounds<uint>();
        Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value);
        WritePosition += sizeof(uint);
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public void WriteInt64(long value)
    {
        CheckBounds<long>();
        Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(value) : value);
        WritePosition += sizeof(long);
    }

    // DEBUG-only guard: throws if a value of type T would not fit in the remaining space.
    [Conditional("DEBUG")]
    void CheckBounds<T>() where T : unmanaged
    {
        if (Unsafe.SizeOf<T>() > WriteSpaceLeft)
            ThrowNotSpaceLeft();
    }

    static void ThrowNotSpaceLeft()
        => ThrowHelper.ThrowInvalidOperationException("There is not enough space left in the buffer.");

    /// <summary>
    /// Writes <paramref name="s"/> in <see cref="TextEncoding"/>, flushing as needed.
    /// </summary>
    /// <param name="s">The string to write.</param>
    /// <param name="byteLen">
    /// Encoded length of <paramref name="s"/> as computed by the caller.
    /// It is only used to choose a strategy and is not re-validated here.
    /// </param>
    /// <param name="async">Whether flushes use the asynchronous stream APIs.</param>
    /// <param name="cancellationToken">Cancellation token for any flush.</param>
    /// <returns>A completed task when the string fits in the remaining space; otherwise the long-write task.</returns>
    public Task WriteString(string s, int byteLen, bool async, CancellationToken cancellationToken = default)
    {
        if (byteLen <= WriteSpaceLeft)
        {
            WriteString(s);
            return Task.CompletedTask;
        }
        return WriteStringLong(this, async, s, byteLen, cancellationToken);

        static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s, int byteLen, CancellationToken cancellationToken)
        {
            Debug.Assert(byteLen > buffer.WriteSpaceLeft);
            if (byteLen <= buffer.Size)
            {
                // String can fit entirely in an empty buffer. Flush and retry rather than
                // going into the partial writing flow below
                await buffer.Flush(async, cancellationToken).ConfigureAwait(false);
                buffer.WriteString(s);
            }
            else
            {
                // Stream the string through the encoder in chunks. Flush whenever less than one maximal
                // encoded char fits, so Convert always makes progress.
                var encoder = buffer._textEncoder;
                encoder.Reset();
                var data = s.AsMemory();
                var minBufferSize = buffer.TextEncoding.GetMaxByteCount(1);

                bool completed;
                do
                {
                    if (buffer.WriteSpaceLeft < minBufferSize)
                        await buffer.Flush(async, cancellationToken).ConfigureAwait(false);
                    encoder.Convert(data.Span, buffer.Span, flush: true, out var charsUsed, out var bytesUsed, out completed);
                    data = data.Slice(charsUsed);
                    buffer.WritePosition += bytesUsed;
                } while (!completed);
            }
        }
    }

    /// <summary>Encodes <paramref name="s"/> into the buffer; the caller must ensure it fits (asserted in DEBUG).</summary>
    public void WriteString(string s)
    {
        Debug.Assert(TextEncoding.GetByteCount(s) <= WriteSpaceLeft);
        WritePosition += TextEncoding.GetBytes(s, 0, s.Length, Buffer, WritePosition);
    }

    /// <summary>Copies <paramref name="buf"/> into the buffer; the caller must ensure it fits (asserted in DEBUG).</summary>
    public void WriteBytes(ReadOnlySpan<byte> buf)
    {
        Debug.Assert(buf.Length <= WriteSpaceLeft);
        buf.CopyTo(Buffer.AsSpan(WritePosition));
        WritePosition += buf.Length;
    }

    public void WriteBytes(ReadOnlyMemory<byte> buf)
        => WriteBytes(buf.Span);

    public void WriteBytes(byte[] buf) => WriteBytes(buf.AsSpan());

    public void WriteBytes(byte[] buf, int offset, int count)
        => WriteBytes(new ReadOnlySpan<byte>(buf, offset, count));

    /// <summary>
    /// Writes <paramref name="bytes"/>, flushing as needed.
    /// </summary>
    /// <remarks>
    /// Values that fit in an empty buffer are written after at most one flush.
    /// Larger values are copied in chunks, with a flush whenever the buffer is full.
    /// </remarks>
    public Task WriteBytesRaw(ReadOnlyMemory<byte> bytes, bool async, CancellationToken cancellationToken = default)
    {
        if (bytes.Length <= WriteSpaceLeft)
        {
            WriteBytes(bytes);
            return Task.CompletedTask;
        }
        return WriteBytesLong(this, async, bytes, cancellationToken);

        static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, ReadOnlyMemory<byte> bytes, CancellationToken cancellationToken)
        {
            if (bytes.Length <= buffer.Size)
            {
                // value can fit entirely in an empty buffer. Flush and retry rather than
                // going into the partial writing flow below
                await buffer.Flush(async, cancellationToken).ConfigureAwait(false);
                buffer.WriteBytes(bytes);
            }
            else
            {
                var remaining = bytes.Length;
                do
                {
                    if (buffer.WriteSpaceLeft == 0)
                        await buffer.Flush(async, cancellationToken).ConfigureAwait(false);
                    var writeLen = Math.Min(remaining, buffer.WriteSpaceLeft);
                    var offset = bytes.Length - remaining;
                    buffer.WriteBytes(bytes.Slice(offset, writeLen));
                    remaining -= writeLen;
                }
                while (remaining > 0);
            }
        }
    }

    /// <summary>
    /// Writes <paramref name="s"/> as ASCII followed by a zero byte.
    /// ASCII-only content and available space are asserted in DEBUG only.
    /// </summary>
    public void WriteNullTerminatedString(string s)
    {
        AssertASCIIOnly(s);
        Debug.Assert(WriteSpaceLeft >= s.Length + 1);
        WritePosition += Encoding.ASCII.GetBytes(s, 0, s.Length, Buffer, WritePosition);
        WriteByte(0);
    }

    /// <summary>
    /// Writes the raw bytes of <paramref name="s"/> followed by a zero byte.
    /// ASCII-only content and available space are asserted in DEBUG only.
    /// </summary>
    public void WriteNullTerminatedString(byte[] s)
    {
        AssertASCIIOnly(s);
        Debug.Assert(WriteSpaceLeft >= s.Length + 1);
        WriteBytes(s);
        WriteByte(0);
    }

    #endregion

    #region Copy

    /// <summary>Enters copy mode: shrinks <see cref="Size"/> by 5 and writes the first CopyData header.</summary>
    internal void StartCopyMode()
    {
        _copyMode = true;
        Size -= 5;
        WriteCopyDataHeader();
    }

    /// <summary>Leaves copy mode: restores <see cref="Size"/> and clears the buffer.</summary>
    internal void EndCopyMode()
    {
        // EndCopyMode is usually called after a Flush which ended the last CopyData message.
        // That Flush also wrote the header for another CopyData which we clear here.
        _copyMode = false;
        Size += 5;
        Clear();
    }

    // Writes the CopyData message code plus a zero placeholder length (5 bytes total) at the start of the buffer.
    // Flush patches the real length in later.
    void WriteCopyDataHeader()
    {
        Debug.Assert(_copyMode);
        Debug.Assert(WritePosition == 0);
        WriteByte(FrontendMessageCode.CopyData);
        // Leave space for the message length
        WriteInt32(0);
    }

    #endregion

    #region Dispose

    /// <summary>Disposes the timeout token source; subsequent calls do nothing.</summary>
    public void Dispose()
    {
        if (_disposed)
            return;

        _timeoutCts.Dispose();
        _disposed = true;
    }

    #endregion

    #region Misc

    /// <summary>
    /// Starts length tracking for a new message of <paramref name="messageLength"/> bytes.
    /// </summary>
    /// <remarks>
    /// When <see cref="MessageLengthValidation"/> is enabled, the connector is broken via <c>Connector.Break</c>
    /// (wrapping an <see cref="OverflowException"/>) if the previous message's written byte count did not match
    /// its declared length.
    /// </remarks>
    internal void StartMessage(int messageLength)
    {
        if (!MessageLengthValidation)
            return;

        if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
            Throw();

        // Add negative WritePosition to compensate for previous message(s) written without flushing.
        _messageBytesFlushed = -WritePosition;
        _messageLength = messageLength;

        void Throw()
        {
            throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
        }
    }

    // Records that count bytes of the current message have been sent. When validation is enabled, fails if no
    // message was started, count is negative, or the running total would exceed the declared message length.
    void AdvanceMessageBytesFlushed(int count)
    {
        if (!MessageLengthValidation)
            return;

        if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
            Throw();

        _messageBytesFlushed += count;

        void Throw()
        {
            ArgumentOutOfRangeException.ThrowIfNegative(count);

            if (_messageLength is null)
                throw Connector.Break(new InvalidOperationException("No message was started"));

            if ((long)_messageBytesFlushed + count > _messageLength)
                throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
        }
    }

    /// <summary>Discards buffered data and forgets the current message length.</summary>
    internal void Clear()
    {
        WritePosition = 0;
        _messageLength = null;
    }

    /// <summary>
    /// Returns all contents currently written to the buffer (but not flushed).
    /// Useful for pre-generating messages.
    /// </summary>
    internal byte[] GetContents()
        => Buffer.AsSpan(0, WritePosition).ToArray();

    // DEBUG-only check that every char is 7-bit ASCII.
    [Conditional("DEBUG")]
    internal static void AssertASCIIOnly(string s)
    {
        foreach (var c in s)
            if (c >= 128)
                Debug.Fail("Method only supports ASCII strings");
    }

    // DEBUG-only check that every byte is 7-bit ASCII.
    [Conditional("DEBUG")]
    internal static void AssertASCIIOnly(byte[] s)
    {
        foreach (var c in s)
            if (c >= 128)
                Debug.Fail("Method only supports ASCII strings");
    }

    #endregion
}