using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Npgsql.BackendMessages;
using Npgsql.Internal;
using InfiniteTimeout = System.Threading.Timeout;
using static Npgsql.Util.Statics;
#pragma warning disable 1591
namespace Npgsql;
/// <summary>
/// Provides an API for a raw binary COPY operation, a high-performance data import/export mechanism to
/// a PostgreSQL table. Initiated by <see cref="NpgsqlConnection.BeginRawBinaryCopy(string)"/>.
/// </summary>
/// <remarks>
/// See https://www.postgresql.org/docs/current/static/sql-copy.html.
/// </remarks>
public sealed class NpgsqlRawCopyStream : Stream, ICancelable
{
#region Fields and Properties
// Connector that owns this COPY operation; set to null by Cleanup() once the operation has ended.
NpgsqlConnector _connector;
// The connector's read buffer, from which COPY data is handed to the caller.
NpgsqlReadBuffer _readBuf;
// The connector's write buffer, into which COPY data is written.
NpgsqlWriteBuffer _writeBuf;
// Number of bytes of the current CopyData message that have not yet been handed to the caller.
int _leftToReadInDataMsg;
// Lifecycle state of the stream; starts out Uninitialized and ends up Disposed after Cleanup().
CopyStreamState _state = CopyStreamState.Uninitialized;
// Direction flags, set by Init(): CopyOutResponse enables reading, CopyInResponse enables writing.
bool _canRead;
bool _canWrite;
// Whether the server's CopyIn/CopyOut response reported the binary data format.
internal bool IsBinary { get; private set; }
public override bool CanWrite => _canWrite;
public override bool CanRead => _canRead;
// Timeouts are supported; see ReadTimeout and WriteTimeout.
public override bool CanTimeout => true;
/// <summary>
/// Gets or sets the timeout of the write buffer, in milliseconds.
/// Assigning zero or a negative value selects an infinite timeout.
/// </summary>
public override int WriteTimeout
{
    get
    {
        var timeout = _writeBuf.Timeout;
        return (int) timeout.TotalMilliseconds;
    }
    set
    {
        if (value > 0)
            _writeBuf.Timeout = TimeSpan.FromMilliseconds(value);
        else
            _writeBuf.Timeout = InfiniteTimeout.InfiniteTimeSpan;
    }
}
/// <summary>
/// Gets or sets the timeout of the read buffer, in milliseconds.
/// Assigning zero or a negative value selects an infinite timeout.
/// </summary>
public override int ReadTimeout
{
    get
    {
        var timeout = _readBuf.Timeout;
        return (int) timeout.TotalMilliseconds;
    }
    set
    {
        if (value > 0)
            _readBuf.Timeout = TimeSpan.FromMilliseconds(value);
        else
            _readBuf.Timeout = InfiniteTimeout.InfiniteTimeSpan;
    }
}
/// <summary>
/// The copy binary format header signature
/// </summary>
// The 11 bytes are: "PGCOPY", '\n', 0xFF, '\r', '\n', 0x00.
internal static readonly byte[] BinarySignature =
[
(byte)'P',(byte)'G',(byte)'C',(byte)'O',(byte)'P',(byte)'Y',
(byte)'\n', 255, (byte)'\r', (byte)'\n', 0
];
// Logger used for COPY-related log messages; taken from the connector's logging configuration.
readonly ILogger _copyLogger;
// Tracing activity for this COPY operation; reset to null once it has been stopped or marked as failed.
Activity? _activity;
#endregion
#region Constructor / Initializer
/// <summary>
/// Creates a COPY stream bound to <paramref name="connector"/>, using the connector's own
/// read and write buffers and its COPY logger.
/// </summary>
internal NpgsqlRawCopyStream(NpgsqlConnector connector)
{
    (_connector, _readBuf, _writeBuf) = (connector, connector.ReadBuffer, connector.WriteBuffer);
    _copyLogger = connector.LoggingConfiguration.CopyLogger;
}
/// <summary>
/// Sends the COPY command to the server and configures the stream according to the server's response:
/// a CopyInResponse opens the stream for writing, a CopyOutResponse opens it for reading.
/// </summary>
/// <param name="copyCommand">The COPY command text to execute.</param>
/// <param name="async">Whether the I/O is performed asynchronously.</param>
/// <param name="forExport">
/// Selects the initial tracing operation name: <c>true</c> gives "COPY TO", <c>false</c> gives "COPY FROM"
/// and <c>null</c> gives "COPY".
/// </param>
/// <param name="cancellationToken">Token observed while sending the command and waiting for the response.</param>
/// <exception cref="InvalidOperationException">
/// The server completed the command without entering COPY mode.
/// </exception>
internal async Task Init(string copyCommand, bool async, bool? forExport, CancellationToken cancellationToken = default)
{
    Debug.Assert(_activity is null);

    string operationName;
    if (forExport.HasValue)
        operationName = forExport.Value ? "COPY TO" : "COPY FROM";
    else
        operationName = "COPY";
    _activity = _connector.TraceCopyStart(copyCommand, operationName);

    try
    {
        await _connector.WriteQuery(copyCommand, async, cancellationToken).ConfigureAwait(false);
        await _connector.Flush(async, cancellationToken).ConfigureAwait(false);

        using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false);

        var response = await _connector.ReadMessage(async).ConfigureAwait(false);
        var code = response.Code;

        if (code == BackendMessageCode.CopyInResponse)
        {
            // Import: the client is going to send data.
            _state = CopyStreamState.Ready;
            IsBinary = ((CopyInResponseMessage)response).IsBinary;
            _canWrite = true;
            _writeBuf.StartCopyMode();
            TraceSetImport();
        }
        else if (code == BackendMessageCode.CopyOutResponse)
        {
            // Export: the server is going to send data.
            _state = CopyStreamState.Ready;
            IsBinary = ((CopyOutResponseMessage)response).IsBinary;
            _canRead = true;
            TraceSetExport();
        }
        else if (code == BackendMessageCode.CommandComplete)
        {
            // The command ran to completion on the server without involving the client.
            throw new InvalidOperationException(
                "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " +
                "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " +
                "Note that your data has been successfully imported/exported.");
        }
        else
        {
            throw _connector.UnexpectedMessageReceived(code);
        }
    }
    catch (Exception e)
    {
        // Record the failure on the tracing activity before propagating it.
        TraceSetException(e);
        throw;
    }
}
#endregion
#region Write
/// <summary>
/// Writes <paramref name="count"/> bytes of <paramref name="buffer"/>, starting at
/// <paramref name="offset"/>, to the COPY stream.
/// </summary>
/// <param name="buffer">The array holding the data to write.</param>
/// <param name="offset">Index in <paramref name="buffer"/> of the first byte to write.</param>
/// <param name="count">Number of bytes to write.</param>
public override void Write(byte[] buffer, int offset, int count)
{
    ValidateArguments(buffer, offset, count);
    // ReadOnlySpan is generic and needs its element type; byte matches Stream.Write(ReadOnlySpan<byte>).
    Write(new ReadOnlySpan<byte>(buffer, offset, count));
}
/// <summary>
/// Asynchronously writes <paramref name="count"/> bytes of <paramref name="buffer"/>, starting at
/// <paramref name="offset"/>, to the COPY stream.
/// </summary>
/// <param name="buffer">The array holding the data to write.</param>
/// <param name="offset">Index in <paramref name="buffer"/> of the first byte to write.</param>
/// <param name="count">Number of bytes to write.</param>
/// <param name="cancellationToken">Token passed on to the memory-based overload.</param>
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    ValidateArguments(buffer, offset, count);
    // Memory is generic and needs its element type; it converts implicitly to ReadOnlyMemory<byte>.
    return WriteAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
/// <summary>
/// Writes <paramref name="buffer"/> to the COPY stream, going through the write buffer when the data
/// fits and bypassing it otherwise.
/// </summary>
/// <param name="buffer">The bytes to write.</param>
/// <exception cref="ObjectDisposedException">The COPY operation has already ended.</exception>
/// <exception cref="InvalidOperationException">The stream is not open for writing.</exception>
public override void Write(ReadOnlySpan<byte> buffer)
{
    CheckDisposed();
    if (!CanWrite)
        throw new InvalidOperationException("Stream not open for writing");

    if (buffer.Length == 0)
        return;

    if (buffer.Length <= _writeBuf.WriteSpaceLeft)
    {
        _writeBuf.WriteBytes(buffer);
        return;
    }

    // Value is too big, flush.
    Flush();

    if (buffer.Length <= _writeBuf.WriteSpaceLeft)
    {
        _writeBuf.WriteBytes(buffer);
        return;
    }

    // Value is too big even after a flush - bypass the buffer and write directly.
    _writeBuf.DirectWrite(buffer);
}
/// <summary>
/// Asynchronously writes <paramref name="buffer"/> to the COPY stream, going through the write buffer
/// when the data fits and bypassing it otherwise.
/// </summary>
/// <param name="buffer">The bytes to write.</param>
/// <param name="cancellationToken">
/// Checked before any work is started and passed on to the flush and direct-write operations.
/// </param>
/// <exception cref="ObjectDisposedException">The COPY operation has already ended.</exception>
/// <exception cref="InvalidOperationException">The stream is not open for writing.</exception>
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
    // State and cancellation checks are done eagerly, outside the async state machine.
    CheckDisposed();
    if (!CanWrite)
        throw new InvalidOperationException("Stream not open for writing");
    cancellationToken.ThrowIfCancellationRequested();
    return WriteAsyncInternal(buffer, cancellationToken);

    async ValueTask WriteAsyncInternal(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
    {
        if (buffer.Length == 0)
            return;

        if (buffer.Length <= _writeBuf.WriteSpaceLeft)
        {
            _writeBuf.WriteBytes(buffer.Span);
            return;
        }

        // Value is too big, flush.
        await FlushAsync(true, cancellationToken).ConfigureAwait(false);

        if (buffer.Length <= _writeBuf.WriteSpaceLeft)
        {
            _writeBuf.WriteBytes(buffer.Span);
            return;
        }

        // Value is too big even after a flush - bypass the buffer and write directly.
        await _writeBuf.DirectWrite(buffer, true, cancellationToken).ConfigureAwait(false);
    }
}
/// <summary>
/// Flushes the write buffer, blocking until the flush has completed.
/// </summary>
public override void Flush()
{
    FlushAsync(async: false).GetAwaiter().GetResult();
}
/// <summary>
/// Asynchronously flushes the write buffer. Returns a canceled task without flushing when
/// <paramref name="cancellationToken"/> is already canceled.
/// </summary>
public override Task FlushAsync(CancellationToken cancellationToken)
    => cancellationToken.IsCancellationRequested
        ? Task.FromCanceled(cancellationToken)
        : FlushAsync(async: true, cancellationToken);
/// <summary>
/// Shared flush implementation: verifies the stream has not been disposed, then flushes the write buffer.
/// </summary>
Task FlushAsync(bool async, CancellationToken cancellationToken = default)
{
    CheckDisposed();

    var pendingFlush = _writeBuf.Flush(async, cancellationToken);
    return pendingFlush;
}
#endregion
#region Read
/// <summary>
/// Reads up to <paramref name="count"/> bytes of COPY data into <paramref name="buffer"/>, starting at
/// <paramref name="offset"/>.
/// </summary>
/// <param name="buffer">The array that receives the data.</param>
/// <param name="offset">Index in <paramref name="buffer"/> at which to start storing data.</param>
/// <param name="count">Maximum number of bytes to read.</param>
/// <returns>The number of bytes read; 0 once the COPY data has been fully consumed.</returns>
public override int Read(byte[] buffer, int offset, int count)
{
    ValidateArguments(buffer, offset, count);
    // Span is generic and needs its element type; byte matches Stream.Read(Span<byte>).
    return Read(new Span<byte>(buffer, offset, count));
}
/// <summary>
/// Asynchronously reads up to <paramref name="count"/> bytes of COPY data into
/// <paramref name="buffer"/>, starting at <paramref name="offset"/>.
/// </summary>
/// <param name="buffer">The array that receives the data.</param>
/// <param name="offset">Index in <paramref name="buffer"/> at which to start storing data.</param>
/// <param name="count">Maximum number of bytes to read.</param>
/// <param name="cancellationToken">Token passed on to the memory-based overload.</param>
/// <returns>The number of bytes read; 0 once the COPY data has been fully consumed.</returns>
// The return type must be Task<int> for this to override Stream.ReadAsync(byte[], int, int, CancellationToken).
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    ValidateArguments(buffer, offset, count);
    return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
/// <summary>
/// Reads COPY data into <paramref name="span"/>. May return fewer bytes than the span can hold.
/// </summary>
/// <param name="span">The destination for the data.</param>
/// <returns>The number of bytes read; 0 once the COPY data has been fully consumed.</returns>
/// <exception cref="ObjectDisposedException">The COPY operation has already ended.</exception>
/// <exception cref="InvalidOperationException">The stream is not open for reading.</exception>
public override int Read(Span<byte> span)
{
    CheckDisposed();
    if (!CanRead)
        throw new InvalidOperationException("Stream not open for reading");

    // ReadCore only decides how many bytes may be taken; the bytes themselves are copied out below.
    var count = ReadCore(span.Length, false).GetAwaiter().GetResult();
    if (count > 0)
        _readBuf.ReadBytes(span.Slice(0, count));
    return count;
}
/// <summary>
/// Asynchronously reads COPY data into <paramref name="buffer"/>. May return fewer bytes than the
/// buffer can hold.
/// </summary>
/// <param name="buffer">The destination for the data.</param>
/// <param name="cancellationToken">
/// Checked before any work is started and passed on to the read operation.
/// </param>
/// <returns>The number of bytes read; 0 once the COPY data has been fully consumed.</returns>
/// <exception cref="ObjectDisposedException">The COPY operation has already ended.</exception>
/// <exception cref="InvalidOperationException">The stream is not open for reading.</exception>
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken)
{
    // State and cancellation checks are done eagerly, outside the async state machine.
    CheckDisposed();
    if (!CanRead)
        throw new InvalidOperationException("Stream not open for reading");
    cancellationToken.ThrowIfCancellationRequested();
    return ReadAsyncInternal();

    async ValueTask<int> ReadAsyncInternal()
    {
        // ReadCore only decides how many bytes may be taken; the bytes themselves are copied out below.
        var count = await ReadCore(buffer.Length, true, cancellationToken).ConfigureAwait(false);
        if (count > 0)
            _readBuf.ReadBytes(buffer.Slice(0, count).Span);
        return count;
    }
}
/// <summary>
/// Makes sure COPY data is available in the read buffer and computes how many bytes the caller may
/// take from it. The bytes themselves are not copied here; the caller reads them from the read buffer.
/// </summary>
/// <param name="count">Maximum number of bytes the caller wants.</param>
/// <param name="async">Whether the I/O is performed asynchronously.</param>
/// <param name="cancellationToken">Token registered for the duration of the read.</param>
/// <returns>
/// The number of bytes the caller may read, never more than <paramref name="count"/>;
/// 0 once CopyDone has been received.
/// </returns>
async ValueTask<int> ReadCore(int count, bool async, CancellationToken cancellationToken = default)
{
    if (_state == CopyStreamState.Consumed)
        return 0;

    using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false);

    // Loop instead of reading a single message: a CopyData message with an empty payload must not be
    // reported as 0 bytes read, because callers interpret 0 as end-of-stream.
    while (_leftToReadInDataMsg == 0)
    {
        IBackendMessage msg;
        try
        {
            // We've consumed the current DataMessage (or haven't yet received the first),
            // read the next message
            msg = await _connector.ReadMessage(async).ConfigureAwait(false);
        }
        catch (Exception e)
        {
            // A failed read ends the operation, unless it has already been cleaned up.
            if (_state != CopyStreamState.Disposed)
            {
                TraceSetException(e);
                Cleanup();
            }
            throw;
        }

        switch (msg.Code)
        {
        case BackendMessageCode.CopyData:
            _leftToReadInDataMsg = ((CopyDataMessage)msg).Length;
            break;
        case BackendMessageCode.CopyDone:
            // After CopyDone the server sends CommandComplete followed by ReadyForQuery.
            Expect<CommandCompleteMessage>(await _connector.ReadMessage(async).ConfigureAwait(false), _connector);
            Expect<ReadyForQueryMessage>(await _connector.ReadMessage(async).ConfigureAwait(false), _connector);
            _state = CopyStreamState.Consumed;
            return 0;
        default:
            throw _connector.UnexpectedMessageReceived(msg.Code);
        }
    }

    Debug.Assert(_leftToReadInDataMsg > 0);

    // If our buffer is empty, read in more. Otherwise return whatever is there, even if the
    // user asked for more (normal socket behavior)
    if (_readBuf.ReadBytesLeft == 0)
        await _readBuf.ReadMore(async).ConfigureAwait(false);

    Debug.Assert(_readBuf.ReadBytesLeft > 0);

    // Never hand out more than is buffered, nor more than remains in the current data message.
    var maxCount = Math.Min(_readBuf.ReadBytesLeft, _leftToReadInDataMsg);
    if (count > maxCount)
        count = maxCount;

    _leftToReadInDataMsg -= count;
    return count;
}
#endregion
#region Cancel
/// <summary>
/// Cancels and terminates an ongoing operation. Any data already written will be discarded.
/// </summary>
public void Cancel()
{
    Cancel(async: false).GetAwaiter().GetResult();
}
/// <summary>
/// Asynchronously cancels and terminates an ongoing operation. Any data already written will be discarded.
/// </summary>
public Task CancelAsync()
{
    return Cancel(async: true);
}
/// <summary>
/// Shared implementation of the synchronous and asynchronous cancel methods.
/// When the stream is open for writing, a CopyFail is sent and the resulting server error is awaited;
/// otherwise a PostgreSQL cancellation is requested through the connector.
/// </summary>
/// <param name="async">Whether the I/O is performed asynchronously.</param>
async Task Cancel(bool async)
{
    CheckDisposed();

    if (!CanWrite)
    {
        _connector.PerformPostgresCancellation();
        return;
    }

    // Drop whatever is still buffered, then tell the server that the copy failed.
    _writeBuf.EndCopyMode();
    _writeBuf.Clear();
    await _connector.WriteCopyFail(async).ConfigureAwait(false);
    await _connector.Flush(async).ConfigureAwait(false);

    try
    {
        var unexpected = await _connector.ReadMessage(async).ConfigureAwait(false);
        // The CopyFail should immediately trigger an exception from the read above.
        throw _connector.Break(
            new NpgsqlException("Expected ErrorResponse when cancelling COPY but got: " + unexpected.Code));
    }
    catch (PostgresException e)
    {
        // TODO: NpgsqlBinaryImporter doesn't cleanup on cancellation
        // And instead relies on users disposing the object
        // We probably should do the same here
        Cleanup();

        if (e.SqlState == PostgresErrorCodes.QueryCanceled)
        {
            // The expected outcome of a cancellation: finish the trace normally.
            TraceStop();
            return;
        }

        TraceSetException(e);
        throw;
    }
}
#endregion
#region Dispose
/// <summary>
/// Synchronously ends the COPY operation; does nothing when <paramref name="disposing"/> is false.
/// </summary>
protected override void Dispose(bool disposing)
{
    DisposeAsync(disposing, async: false).GetAwaiter().GetResult();
}
/// <summary>
/// Asynchronously ends the COPY operation.
/// </summary>
public override ValueTask DisposeAsync()
{
    return DisposeAsync(disposing: true, async: true);
}
/// <summary>
/// Shared dispose implementation. For an import, flushes buffered data, sends CopyDone and waits for
/// the server to confirm; errors are propagated to the caller. For an export, skips any unread data up
/// to ReadyForQuery; errors are logged and not rethrown. Cleanup always runs afterwards.
/// </summary>
/// <param name="disposing">When false, nothing is done.</param>
/// <param name="async">Whether the I/O is performed asynchronously.</param>
async ValueTask DisposeAsync(bool disposing, bool async)
{
    if (_state == CopyStreamState.Disposed || !disposing)
        return;

    try
    {
        _connector.CurrentCopyOperation = null;

        if (CanWrite)
        {
            try
            {
                await FlushAsync(async).ConfigureAwait(false);
                _writeBuf.EndCopyMode();
                await _connector.WriteCopyDone(async).ConfigureAwait(false);
                await _connector.Flush(async).ConfigureAwait(false);
                // The server answers CopyDone with CommandComplete followed by ReadyForQuery.
                Expect<CommandCompleteMessage>(await _connector.ReadMessage(async).ConfigureAwait(false), _connector);
                Expect<ReadyForQueryMessage>(await _connector.ReadMessage(async).ConfigureAwait(false), _connector);
                TraceStop();
            }
            catch (Exception e)
            {
                TraceSetException(e);
                throw;
            }
        }
        else
        {
            try
            {
                // Nothing is left on the wire if the export was read to the end or never started.
                if (_state != CopyStreamState.Consumed && _state != CopyStreamState.Uninitialized)
                {
                    if (_leftToReadInDataMsg > 0)
                    {
                        await _readBuf.Skip(async, _leftToReadInDataMsg).ConfigureAwait(false);
                    }
                    _connector.SkipUntil(BackendMessageCode.ReadyForQuery);
                }
                TraceStop();
            }
            catch (OperationCanceledException e) when (e.InnerException is PostgresException { SqlState: PostgresErrorCodes.QueryCanceled })
            {
                // The export was cancelled by the user: log it and finish the trace normally.
                LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id);
                TraceStop();
            }
            catch (Exception e)
            {
                // Failures while draining an export are logged and recorded on the trace, but not rethrown.
                LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e);
                TraceSetException(e);
            }
        }
    }
    finally
    {
        Cleanup();
    }
}
#pragma warning disable CS8625
/// <summary>
/// Ends the user action on the connector, detaches this stream from it and marks the stream as disposed.
/// Must not be called on a stream that is already disposed.
/// </summary>
void Cleanup()
{
    Debug.Assert(_state != CopyStreamState.Disposed);

    var connector = _connector;
    LogMessages.CopyOperationCompleted(_copyLogger, connector.Id);
    connector.EndUserAction();
    connector.CurrentCopyOperation = null;

    // Drop every reference to connector-owned objects so the stream cannot touch them again.
    _connector = null;
    _readBuf = null;
    _writeBuf = null;
    _state = CopyStreamState.Disposed;
}
#pragma warning restore CS8625
/// <summary>
/// Throws <see cref="ObjectDisposedException"/> when the COPY operation has already ended.
/// </summary>
void CheckDisposed()
{
    if (_state != CopyStreamState.Disposed)
        return;

    throw new ObjectDisposedException(nameof(NpgsqlRawCopyStream), "The COPY operation has already ended.");
}
#endregion
#region Unsupported
/// <summary>Always false: a COPY stream cannot seek.</summary>
public override bool CanSeek
{
    get { return false; }
}

/// <summary>Not supported.</summary>
public override long Seek(long offset, SeekOrigin origin)
{
    throw new NotSupportedException();
}

/// <summary>Not supported.</summary>
public override void SetLength(long value)
{
    throw new NotSupportedException();
}

/// <summary>Not supported.</summary>
public override long Length
{
    get { throw new NotSupportedException(); }
}

/// <summary>Not supported.</summary>
public override long Position
{
    get { throw new NotSupportedException(); }
    set { throw new NotSupportedException(); }
}
#endregion
#region Input validation
/// <summary>
/// Validates the usual (buffer, offset, count) triple of the array-based Stream methods.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="offset"/> or <paramref name="count"/> is negative.
/// </exception>
static void ValidateArguments(byte[] buffer, int offset, int count)
{
    ArgumentNullException.ThrowIfNull(buffer);
    ArgumentOutOfRangeException.ThrowIfNegative(offset);
    ArgumentOutOfRangeException.ThrowIfNegative(count);

    // Number of elements from offset to the end of the array; negative when offset is past the end.
    var available = buffer.Length - offset;
    if (count <= available)
        return;

    ThrowHelper.ThrowArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection.");
}
#endregion
#region Tracing
/// <summary>
/// Sets the operation of the current tracing activity, if there is one, to "COPY FROM".
/// </summary>
private void TraceSetImport()
{
    if (_activity is not { } activity)
        return;

    NpgsqlActivitySource.SetOperation(activity, "COPY FROM");
}
/// <summary>
/// Sets the operation of the current tracing activity, if there is one, to "COPY TO".
/// </summary>
private void TraceSetExport()
{
    if (_activity is not { } activity)
        return;

    NpgsqlActivitySource.SetOperation(activity, "COPY TO");
}
/// <summary>
/// Stops the current tracing activity, if there is one, and forgets it so it is not reported twice.
/// </summary>
private void TraceStop()
{
    if (_activity is not { } activity)
        return;

    NpgsqlActivitySource.CopyStop(activity);
    _activity = null;
}
/// <summary>
/// Records <paramref name="e"/> on the current tracing activity, if there is one, and forgets the
/// activity so it is not reported twice.
/// </summary>
private void TraceSetException(Exception e)
{
    if (_activity is not { } activity)
        return;

    NpgsqlActivitySource.SetException(activity, e);
    _activity = null;
}
#endregion
#region Enums
// Lifecycle states of a raw COPY stream.
enum CopyStreamState
{
// Initial state, before Init has received the server's COPY response.
Uninitialized,
// The server answered with CopyInResponse or CopyOutResponse; the stream can be written to or read from.
Ready,
// CopyDone was received while reading; further reads return 0.
Consumed,
// Cleanup has run; the COPY operation has ended and the stream can no longer be used.
Disposed
}
#endregion Enums
}
/// <summary>
/// Writer for a text import, initiated by <see cref="NpgsqlConnection.BeginTextImport(string)"/>.
/// </summary>
/// <remarks>
/// See https://www.postgresql.org/docs/current/static/sql-copy.html.
/// </remarks>
public sealed class NpgsqlCopyTextWriter : StreamWriter, ICancelable
{
    internal NpgsqlCopyTextWriter(NpgsqlConnector connector, NpgsqlRawCopyStream underlying) : base(underlying)
    {
        if (!underlying.IsBinary)
            return;

        // A binary COPY stream cannot be used for text; the connector is broken as a result.
        throw connector.Break(new Exception("Can't use a binary copy stream for text writing"));
    }

    // The raw COPY stream this writer was constructed over.
    NpgsqlRawCopyStream CopyStream => (NpgsqlRawCopyStream)BaseStream;

    /// <summary>
    /// Gets or sets a value, in milliseconds, that determines how long the text writer will attempt to write before timing out.
    /// </summary>
    public int Timeout
    {
        get
        {
            return CopyStream.WriteTimeout;
        }
        set
        {
            // Both directions of the underlying stream share the same timeout value.
            var copyStream = CopyStream;
            copyStream.ReadTimeout = value;
            copyStream.WriteTimeout = value;
        }
    }

    /// <summary>
    /// Cancels and terminates an ongoing import. Any data already written will be discarded.
    /// </summary>
    public void Cancel()
    {
        CopyStream.Cancel();
    }

    /// <summary>
    /// Asynchronously cancels and terminates an ongoing import. Any data already written will be discarded.
    /// </summary>
    public Task CancelAsync()
    {
        return CopyStream.CancelAsync();
    }
}
/// <summary>
/// Reader for a text export, initiated by <see cref="NpgsqlConnection.BeginTextExport(string)"/>.
/// </summary>
/// <remarks>
/// See https://www.postgresql.org/docs/current/static/sql-copy.html.
/// </remarks>
public sealed class NpgsqlCopyTextReader : StreamReader, ICancelable
{
    internal NpgsqlCopyTextReader(NpgsqlConnector connector, NpgsqlRawCopyStream underlying) : base(underlying)
    {
        if (!underlying.IsBinary)
            return;

        // A binary COPY stream cannot be used for text; the connector is broken as a result.
        throw connector.Break(new Exception("Can't use a binary copy stream for text reading"));
    }

    // The raw COPY stream this reader was constructed over.
    NpgsqlRawCopyStream CopyStream => (NpgsqlRawCopyStream)BaseStream;

    /// <summary>
    /// Gets or sets a value, in milliseconds, that determines how long the text reader will attempt to read before timing out.
    /// </summary>
    public int Timeout
    {
        get
        {
            return CopyStream.ReadTimeout;
        }
        set
        {
            // Both directions of the underlying stream share the same timeout value.
            var copyStream = CopyStream;
            copyStream.ReadTimeout = value;
            copyStream.WriteTimeout = value;
        }
    }

    /// <summary>
    /// Cancels and terminates an ongoing export.
    /// </summary>
    public void Cancel()
    {
        CopyStream.Cancel();
    }

    /// <summary>
    /// Asynchronously cancels and terminates an ongoing export.
    /// </summary>
    public Task CancelAsync()
    {
        return CopyStream.CancelAsync();
    }

    /// <summary>
    /// Disposes the reader synchronously and returns an already-completed task.
    /// </summary>
    public ValueTask DisposeAsync()
    {
        Dispose();
        return default(ValueTask);
    }
}