using System.Buffers;
using System.Data;
using System.Diagnostics;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Channels;
using System.Transactions;
using Npgsql.BackendMessages;
using Npgsql.Logging;
using Npgsql.TypeMapping;
using Npgsql.Util;
using static Npgsql.Util.Statics;
namespace Npgsql.Internal;
///
/// Represents a connection to a PostgreSQL backend. Unlike NpgsqlConnection objects, which are
/// exposed to users, connectors are internal to Npgsql and are recycled by the connection pool.
///
public sealed partial class NpgsqlConnector : IDisposable
{
#region Fields and Properties
/// <summary>
/// The physical connection socket to the backend.
/// </summary>
Socket _socket = default!;
/// <summary>
/// The physical connection stream to the backend, without anything on top.
/// </summary>
NetworkStream _baseStream = default!;
/// <summary>
/// The physical connection stream to the backend, layered with an SSL/TLS stream if in secure mode.
/// </summary>
Stream _stream = default!;
/// <summary>
/// The parsed connection string.
/// </summary>
public NpgsqlConnectionStringBuilder Settings { get; }
// Client-supplied callbacks, copied from the NpgsqlConnection (or from another connector) when this
// connector is constructed; null when the user did not provide them.
ProvideClientCertificatesCallback? ProvideClientCertificatesCallback { get; }
RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; }
ProvidePasswordCallback? ProvidePasswordCallback { get; }
#pragma warning disable CA2252 // Experimental API
// Invoked at the end of a successful Open: the async variant only for async opens, the sync variant only
// for sync opens.
PhysicalOpenCallback? PhysicalOpenCallback { get; }
PhysicalOpenAsyncCallback? PhysicalOpenAsyncCallback { get; }
#pragma warning restore CA2252
/// <summary>
/// The encoding used for text exchanged with the backend (it is handed to the read and write buffers).
/// Set when the physical connection is opened, based on the Encoding connection string setting.
/// </summary>
public Encoding TextEncoding { get; private set; } = default!;
/// <summary>
/// Same as <see cref="TextEncoding"/>, except that it does not throw an exception if an invalid char is
/// encountered (exception fallback), but rather replaces it with a question mark character (replacement
/// fallback).
/// </summary>
internal Encoding RelaxedTextEncoding { get; private set; } = default!;
/// <summary>
/// Buffer used for reading data.
/// </summary>
internal NpgsqlReadBuffer ReadBuffer { get; private set; } = default!;
/// <summary>
/// If we read a data row that's bigger than the <see cref="ReadBuffer"/>, we allocate an oversize buffer.
/// The original (smaller) buffer is stored here, and restored when the connection is reset.
/// </summary>
NpgsqlReadBuffer? _origReadBuffer;
/// <summary>
/// Buffer used for writing data.
/// </summary>
internal NpgsqlWriteBuffer WriteBuffer { get; private set; } = default!;
/// <summary>
/// The secret key of the backend for this connector, used for query cancellation.
/// </summary>
int _backendSecretKey;
/// <summary>
/// The process ID of the backend for this connector.
/// </summary>
internal int BackendProcessId { get; private set; }
/// <summary>
/// Whether the backend reported a process ID (via BackendKeyData). Some PostgreSQL-like databases don't
/// send one, in which case PostgreSQL-level cancellation is not available.
/// </summary>
bool SupportsPostgresCancellation => BackendProcessId != 0;
/// <summary>
/// A unique ID identifying this connector, used for logging. Currently mapped to <see cref="BackendProcessId"/>.
/// </summary>
internal int Id => BackendProcessId;
/// <summary>
/// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...).
/// </summary>
public NpgsqlDatabaseInfo DatabaseInfo { get; internal set; } = default!;
/// <summary>
/// The type mapper used by this connector. Assigned in LoadDatabaseInfo; when multiplexing it is usually the
/// shared pool-wide mapper rather than a connector-specific one.
/// </summary>
internal ConnectorTypeMapper TypeMapper { get; set; } = default!;
/// <summary>
/// The current transaction status for this connector.
/// </summary>
internal TransactionStatus TransactionStatus { get; set; }
/// <summary>
/// A transaction object for this connector. Since only one transaction can be in progress at any given time,
/// this instance is recycled. To check whether a transaction is currently in progress on this connector,
/// see <see cref="TransactionStatus"/>.
/// </summary>
internal NpgsqlTransaction? Transaction { get; set; }
internal NpgsqlTransaction? UnboundTransaction { get; set; }
/// <summary>
/// The NpgsqlConnection that (currently) owns this connector. Null if the connector isn't
/// owned (i.e. idle in the pool).
/// </summary>
internal NpgsqlConnection? Connection { get; set; }
/// <summary>
/// The number of messages that were prepended to the current message chain, but not yet sent.
/// Note that this only tracks messages which produce a ReadyForQuery message.
/// </summary>
internal int PendingPrependedResponses { get; set; }
/// <summary>
/// A ManualResetEventSlim used to make sure a cancellation request doesn't run
/// while we're reading responses for the prepended query,
/// as we can't gracefully handle their cancellation. Starts in the signaled state.
/// </summary>
readonly ManualResetEventSlim ReadingPrependedMessagesMRE = new(initialState: true);
internal NpgsqlDataReader? CurrentReader;
internal PreparedStatementManager PreparedStatementManager { get; }
internal SqlQueryParser SqlQueryParser { get; } = new();
/// <summary>
/// If the connector is currently in COPY mode, holds a reference to the importer/exporter object.
/// Otherwise null.
/// </summary>
internal ICancelable? CurrentCopyOperation;
/// <summary>
/// Holds all run-time parameters received from the backend (via ParameterStatus messages).
/// </summary>
internal Dictionary PostgresParameters { get; }
/// <summary>
/// Holds all run-time parameters in raw, binary format for efficient handling without allocations.
/// </summary>
readonly List<(byte[] Name, byte[] Value)> _rawParameters = new();
/// <summary>
/// If this connector was broken, this contains the exception that caused the break.
/// </summary>
volatile Exception? _breakReason;
/// <summary>
/// Semaphore, used to synchronize DatabaseInfo between multiple connections, so it wouldn't be loaded in parallel.
/// It is created with an initial count of 1 but no explicit maximum count, so every Release must be paired
/// with a successful Wait/WaitAsync; an unpaired Release would silently admit a second concurrent loader.
/// </summary>
static readonly SemaphoreSlim DatabaseInfoSemaphore = new(1);
/// <summary>
/// <para>
/// Used by the pool to indicate that I/O is currently in progress on this connector, so that another write
/// isn't started concurrently. Note that since we have only one write loop, this is only ever used to
/// protect against over-capacity writes into a connector that's currently *asynchronously* writing.
/// </para>
/// <para>
/// The flag only covers writing. Reading may occur - and the connector may even be returned to the pool -
/// before it is released.
/// </para>
/// <para>
/// 1 means taken (set by FlagAsNotWritableForMultiplexing), 0 means free (reset by
/// FlagAsWritableForMultiplexing).
/// </para>
/// </summary>
internal volatile int MultiplexAsyncWritingLock;
/// <summary>
/// Takes <see cref="MultiplexAsyncWritingLock"/>, marking this multiplexing connector as unavailable for
/// further writes until FlagAsWritableForMultiplexing releases it.
/// </summary>
internal void FlagAsNotWritableForMultiplexing()
{
    Debug.Assert(Settings.Multiplexing);

    // Only expected while commands are in flight, or once the connector is already broken/closed.
    Debug.Assert(IsClosed || IsBroken || CommandsInFlightCount > 0,
        $"About to mark multiplexing connector as non-writable, but {nameof(CommandsInFlightCount)} is {CommandsInFlightCount}");

    // Unlike the release side, taking the flag is unconditional: no check is made that it was free.
    Interlocked.Exchange(ref MultiplexAsyncWritingLock, 1);
}
/// <summary>
/// Releases <see cref="MultiplexAsyncWritingLock"/>, making this multiplexing connector writable again.
/// </summary>
/// <exception cref="InvalidOperationException">
/// The lock was not taken when releasing, which indicates a bug in Npgsql.
/// </exception>
internal void FlagAsWritableForMultiplexing()
{
    Debug.Assert(Settings.Multiplexing);

    // Atomically flip 1 -> 0; observing anything other than 1 means the flag wasn't held.
    // InvalidOperationException (rather than the base Exception) still satisfies any catch (Exception).
    if (Interlocked.CompareExchange(ref MultiplexAsyncWritingLock, 0, 1) != 1)
        throw new InvalidOperationException("Multiplexing lock was not taken when releasing. Please report a bug.");
}
/// <summary>
/// The timeout for reading messages that are part of the user's command
/// (i.e. which aren't internal prepended commands).
/// </summary>
/// <remarks>Precision is milliseconds.</remarks>
internal int UserTimeout { private get; set; }
/// <summary>
/// A lock that's taken while a user action is in progress, e.g. a command being executed.
/// Only used when keepalive is enabled, otherwise null.
/// </summary>
SemaphoreSlim? _userLock;
/// <summary>
/// A lock that's taken while a cancellation is being delivered; new queries are blocked until the
/// cancellation is delivered. This reduces the chance that a cancellation meant for a previous
/// command will accidentally cancel a later one, see #615.
/// </summary>
object CancelLock { get; } = new();
/// <summary>
/// A lock that's taken to make sure no other concurrent operation is running.
/// Break takes it to set the state of the connector.
/// Anyone else should immediately check the state and exit
/// if the connector is closed.
/// </summary>
object SyncObj { get; } = new();
/// <summary>
/// A lock that's used to wait for the Cleanup to complete while breaking the connection.
/// </summary>
object CleanupLock { get; } = new();
// Application-level keepalive (Settings.KeepAlive > 0). The timer is created in the constructor with an
// infinite due time and period, and is only scheduled at the end of a successful Open.
readonly bool _isKeepAliveEnabled;
readonly Timer? _keepAliveTimer;
/// <summary>
/// The command currently being executed by the connector, null otherwise.
/// Used only for concurrent use error reporting purposes.
/// </summary>
NpgsqlCommand? _currentCommand;
// Set in Open when pooling is on, multiplexing and NoResetOnClose are off, and the server supports DISCARD.
bool _sendResetOnClose;
/// <summary>
/// The connector source (e.g. pool) from where this connector came, and to which it will be returned.
/// Note that in multi-host scenarios, this references the host-specific pool rather than the
/// multi-host one.
/// </summary>
readonly ConnectorSource _connectorSource;
internal string UserFacingConnectionString => _connectorSource.UserFacingConnectionString;
/// <summary>
/// Contains the UTC timestamp when this connector was opened, used to implement the maximum connection
/// lifetime (NOTE(review): presumably the ConnectionLifetime setting - confirm).
/// </summary>
internal DateTime OpenTimestamp { get; private set; }
internal int ClearCounter { get; set; }
volatile bool _postgresCancellationPerformed;
internal bool PostgresCancellationPerformed
{
get => _postgresCancellationPerformed;
private set => _postgresCancellationPerformed = value;
}
volatile bool _userCancellationRequested;
CancellationTokenRegistration _cancellationTokenRegistration;
internal bool UserCancellationRequested => _userCancellationRequested;
internal CancellationToken UserCancellationToken { get; set; }
internal bool AttemptPostgresCancellation { get; private set; }
// -1 ms is the same value as Timeout.InfiniteTimeSpan. NOTE(review): the name says "cancel immediately",
// so its meaning depends on how code outside this view interprets the value - confirm.
static readonly TimeSpan _cancelImmediatelyTimeout = TimeSpan.FromMilliseconds(-1);
// The client certificate loaded by RawOpen (if any); RawOpen disposes and clears it if opening fails.
X509Certificate2? _certificate;
static readonly NpgsqlLogger Log = NpgsqlLogManager.CreateLogger(nameof(NpgsqlConnector));
internal readonly Stopwatch QueryLogStopWatch = new();
/// <summary>
/// The endpoint the socket successfully connected to; set by Connect/ConnectAsync.
/// </summary>
internal EndPoint? ConnectedEndPoint { get; private set; }
#endregion
#region Constants
/// <summary>
/// The minimum timeout that can be set on internal commands such as COMMIT, ROLLBACK.
/// Used as the lower bound when the internal command timeout is derived from the command timeout.
/// </summary>
/// <remarks>Precision is seconds.</remarks>
internal const int MinimumInternalCommandTimeout = 3;
#endregion
#region Reusable Message Objects
// Reset message sent on close and the number of responses it produces. NOTE(review): built by
// GenerateResetMessage (called from Open when reset-on-close is enabled) - confirm the exact contents there.
byte[]? _resetWithoutDeallocateMessage;
int _resetWithoutDeallocateResponseCount;
// Backend
// Frequently received backend messages are allocated once per connector and reused.
readonly CommandCompleteMessage _commandCompleteMessage = new();
readonly ReadyForQueryMessage _readyForQueryMessage = new();
readonly ParameterDescriptionMessage _parameterDescriptionMessage = new();
readonly DataRowMessage _dataRowMessage = new();
readonly RowDescriptionMessage _rowDescriptionMessage = new();
// Since COPY is rarely used, allocate these lazily
CopyInResponseMessage? _copyInResponseMessage;
CopyOutResponseMessage? _copyOutResponseMessage;
CopyDataMessage? _copyDataMessage;
CopyBothResponseMessage? _copyBothResponseMessage;
#endregion
// The connector's reusable data reader, created in the constructor.
internal NpgsqlDataReader DataReader { get; set; }
internal NpgsqlDataReader? UnboundDataReader { get; set; }
#region Constructors
/// <summary>
/// Creates a connector for <paramref name="connectorSource"/>, capturing the user-supplied callbacks of
/// <paramref name="conn"/>.
/// </summary>
/// <param name="connectorSource">The source (e.g. pool) that owns this connector.</param>
/// <param name="conn">The connection whose certificate, password and physical-open callbacks are copied.</param>
internal NpgsqlConnector(ConnectorSource connectorSource, NpgsqlConnection conn)
    : this(connectorSource)
{
#pragma warning disable CA2252 // Experimental API
    PhysicalOpenAsyncCallback = conn.PhysicalOpenAsyncCallback;
    PhysicalOpenCallback = conn.PhysicalOpenCallback;
#pragma warning restore CA2252

    ProvidePasswordCallback = conn.ProvidePasswordCallback;
    UserCertificateValidationCallback = conn.UserCertificateValidationCallback;
    ProvideClientCertificatesCallback = conn.ProvideClientCertificatesCallback;
}
/// <summary>
/// Creates a new connector for the same source as <paramref name="connector"/>, copying its certificate and
/// password callbacks.
/// </summary>
/// <remarks>
/// The physical-open callbacks are not copied. Within this file they are only invoked from Open.
/// </remarks>
/// <param name="connector">The connector whose source and callbacks are reused.</param>
NpgsqlConnector(NpgsqlConnector connector)
    : this(connector._connectorSource)
{
    ProvidePasswordCallback = connector.ProvidePasswordCallback;
    UserCertificateValidationCallback = connector.UserCertificateValidationCallback;
    ProvideClientCertificatesCallback = connector.ProvideClientCertificatesCallback;
}
/// <summary>
/// Core constructor: captures the connector source and its settings and allocates per-connector helpers
/// (data reader, prepared statement manager, keepalive timer, multiplexing channel). No I/O is performed;
/// the connector starts in the Closed state.
/// </summary>
/// <param name="connectorSource">The source that owns this connector (asserted to own connectors).</param>
/// <exception cref="NotImplementedException">Keepalive was enabled together with multiplexing.</exception>
NpgsqlConnector(ConnectorSource connectorSource)
{
Debug.Assert(connectorSource.OwnsConnectors);
_connectorSource = connectorSource;
State = ConnectorState.Closed;
TransactionStatus = TransactionStatus.Idle;
// Settings are shared with the connector source, not copied.
Settings = connectorSource.Settings;
PostgresParameters = new Dictionary();
_isKeepAliveEnabled = Settings.KeepAlive > 0;
if (_isKeepAliveEnabled)
{
// The timer is created disabled (infinite due time and period); Open schedules it.
_userLock = new SemaphoreSlim(1, 1);
_keepAliveTimer = new Timer(PerformKeepAlive, null, Timeout.Infinite, Timeout.Infinite);
}
DataReader = new NpgsqlDataReader(this);
// TODO: Not just for automatic preparation anymore...
PreparedStatementManager = new PreparedStatementManager(this);
if (Settings.Multiplexing)
{
// Note: It's OK for this channel to be unbounded: each command enqueued to it is accompanied by sending
// it to PostgreSQL. If we overload it, a TCP zero window will make us block on the networking side
// anyway.
// Note: the in-flight channel can probably be single-writer, but that doesn't actually do anything
// at this point. And we currently rely on being able to complete the channel at any point (from
// Break). We may want to revisit this if an optimized, SingleWriter implementation is introduced.
var commandsInFlightChannel = Channel.CreateUnbounded(
new UnboundedChannelOptions { SingleReader = true });
CommandsInFlightReader = commandsInFlightChannel.Reader;
CommandsInFlightWriter = commandsInFlightChannel.Writer;
// TODO: Properly implement this
if (_isKeepAliveEnabled)
throw new NotImplementedException("Keepalive not yet implemented for multiplexing");
}
}
#endregion
#region Configuration settings
// Shortcuts to frequently used connection string settings; each simply forwards to Settings.
internal string Host { get => Settings.Host!; }
internal int Port { get => Settings.Port; }
internal string Database { get => Settings.Database!; }
string KerberosServiceName { get => Settings.KerberosServiceName; }
int ConnectionTimeout { get => Settings.Timeout; }
bool IntegratedSecurity { get => Settings.IntegratedSecurity; }
/// <summary>
/// The actual command timeout value that gets set on internal commands.
/// </summary>
/// <remarks>
/// Precision is milliseconds. A configured InternalCommandTimeout of -1 means "derive it from CommandTimeout",
/// never going below <see cref="MinimumInternalCommandTimeout"/> seconds; any other value is used as-is.
/// </remarks>
int InternalCommandTimeout
{
    get
    {
        var configuredSeconds = Settings.InternalCommandTimeout;
        if (configuredSeconds != -1)
        {
            // Todo: Decide what we really want here
            // This assertion can easily fail if InternalCommandTimeout is set to 1 or 2 in the connection string
            // We probably don't want to allow these values but in that case a Debug.Assert is the wrong way to enforce it.
            Debug.Assert(configuredSeconds == 0 || configuredSeconds >= MinimumInternalCommandTimeout);
            return configuredSeconds * 1000;
        }

        var derivedSeconds = Math.Max(Settings.CommandTimeout, MinimumInternalCommandTimeout);
        return derivedSeconds * 1000;
    }
}
#endregion Configuration settings
#region State management
int _state;

/// <summary>
/// Gets the current state of the connector.
/// </summary>
/// <remarks>
/// Stored in an int field and written with Interlocked.Exchange; assigning the current value is a no-op.
/// </remarks>
internal ConnectorState State
{
    get => (ConnectorState)_state;
    set
    {
        var requested = (int)value;
        if (requested != _state)
            Interlocked.Exchange(ref _state, requested);
    }
}
/// <summary>
/// Returns whether the connector is open, regardless of any task it is currently performing.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">The connector is in a state this property doesn't know.</exception>
bool IsConnected
{
    get
    {
        // Read the state once so the value that failed to match is the one reported.
        var state = State;
        return state switch
        {
            ConnectorState.Ready or ConnectorState.Executing or ConnectorState.Fetching
                or ConnectorState.Waiting or ConnectorState.Copy or ConnectorState.Replication => true,
            ConnectorState.Closed or ConnectorState.Connecting or ConnectorState.Broken => false,
            // The single-string constructor treats its argument as the parameter *name*, so the message is passed
            // explicitly together with the offending value.
            _ => throw new ArgumentOutOfRangeException(nameof(State), state, "Unknown state: " + state)
        };
    }
}
/// <summary>Whether the connector is idle and ready for a new operation.</summary>
internal bool IsReady => State is ConnectorState.Ready;
/// <summary>Whether the connector is closed.</summary>
internal bool IsClosed => State is ConnectorState.Closed;
/// <summary>Whether the connector has been broken.</summary>
internal bool IsBroken => State is ConnectorState.Broken;
#endregion
#region Open
/// <summary>
/// Opens the physical connection to the server.
/// </summary>
/// <remarks>
/// Usually called by the RequestConnector method of the connection pool manager.
/// Steps, in order: connect, negotiate SSL, send startup and authenticate (OpenCore); load database type
/// information; prepare the reset-on-close message; invoke the user's physical-open callback; start the
/// multiplexing read loop (when multiplexing); and schedule the keepalive timer (when enabled).
/// If any step throws, the connector is broken via Break and the exception is rethrown.
/// </remarks>
/// <param name="timeout">The overall timeout for opening.</param>
/// <param name="async">Whether to perform I/O asynchronously.</param>
/// <param name="cancellationToken">A token to cancel the open operation.</param>
internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken)
{
Debug.Assert(State == ConnectorState.Closed);
State = ConnectorState.Connecting;
try
{
await OpenCore(this, Settings.SslMode, timeout, async, cancellationToken);
await LoadDatabaseInfo(forceReload: false, timeout, async, cancellationToken);
// Only pooled, non-multiplexed connectors reset session state on close, and only if the server supports
// DISCARD. The reset message is generated once, here.
if (Settings.Pooling && !Settings.Multiplexing && !Settings.NoResetOnClose && DatabaseInfo.SupportsDiscard)
{
_sendResetOnClose = true;
GenerateResetMessage();
}
OpenTimestamp = DateTime.UtcNow;
Log.Trace($"Opened connection to {Host}:{Port}");
#pragma warning disable CA2252 // Experimental API
// Only the callback matching the open mode is invoked: an async open never calls the sync callback, and
// vice versa.
if (async && PhysicalOpenAsyncCallback is not null)
await PhysicalOpenAsyncCallback(this);
else if (!async && PhysicalOpenCallback is not null)
PhysicalOpenCallback(this);
#pragma warning restore CA2252
if (Settings.Multiplexing)
{
// Start an infinite async loop, which processes incoming multiplexing traffic.
// It is intentionally not awaited and will run as long as the connector is alive.
// The CommandsInFlightWriter channel is completed in Cleanup, which should cause this task
// to complete.
_ = Task.Run(MultiplexingReadLoop, CancellationToken.None)
.ContinueWith(t =>
{
// Note that we *must* observe the exception if the task is faulted.
Log.Error("Exception bubbled out of multiplexing read loop", t.Exception!, Id);
}, TaskContinuationOptions.OnlyOnFaulted);
}
if (_isKeepAliveEnabled)
{
// Start the keep alive mechanism to work by scheduling the timer.
// Otherwise, it doesn't work for cases when no query executed during
// the connection lifetime in case of a new connector.
lock (SyncObj)
{
// Settings.KeepAlive is in seconds; Timer.Change takes milliseconds (due time and period).
var keepAlive = Settings.KeepAlive * 1000;
_keepAliveTimer!.Change(keepAlive, keepAlive);
}
}
}
catch (Exception e)
{
Break(e);
throw;
}
// Performs the raw connection (socket + optional SSL), sends the startup message, authenticates, then reads
// the optional BackendKeyData and the ReadyForQuery that ends the startup phase.
// If authentication fails with InvalidAuthorizationSpecification under Prefer (while secure) or Allow (while
// not secure), the connection is cleaned up and retried exactly once with the opposite SSL choice
// (Prefer -> Disable, Allow -> Require); the retry's mode no longer matches the catch filter.
static async Task OpenCore(
NpgsqlConnector conn,
SslMode sslMode,
NpgsqlTimeout timeout,
bool async,
CancellationToken cancellationToken,
bool isFirstAttempt = true)
{
await conn.RawOpen(sslMode, timeout, async, cancellationToken, isFirstAttempt);
var username = conn.GetUsername();
// NOTE(review): this mutates conn.Settings, which the constructor takes directly from the connector
// source, so the defaulted database name is visible to every connector sharing those settings.
if (conn.Settings.Database == null)
conn.Settings.Database = username;
timeout.CheckAndApply(conn);
conn.WriteStartupMessage(username);
await conn.Flush(async, cancellationToken);
var cancellationRegistration = conn.StartCancellableOperation(cancellationToken, attemptPgCancellation: false);
try
{
await conn.Authenticate(username, timeout, async, cancellationToken);
}
catch (PostgresException e)
when (e.SqlState == PostgresErrorCodes.InvalidAuthorizationSpecification &&
(sslMode == SslMode.Prefer && conn.IsSecure || sslMode == SslMode.Allow && !conn.IsSecure))
{
cancellationRegistration.Dispose();
Debug.Assert(!conn.IsBroken);
conn.Cleanup();
// If Prefer was specified and we failed (with SSL), retry without SSL.
// If Allow was specified and we failed (without SSL), retry with SSL
await OpenCore(
conn,
sslMode == SslMode.Prefer ? SslMode.Disable : SslMode.Require,
timeout,
async,
cancellationToken,
isFirstAttempt: false);
return;
}
using var _ = cancellationRegistration;
// We treat BackendKeyData as optional because some PostgreSQL-like database
// don't send it (CockroachDB, CrateDB)
var msg = await conn.ReadMessage(async);
if (msg.Code == BackendMessageCode.BackendKeyData)
{
var keyDataMsg = (BackendKeyDataMessage)msg;
conn.BackendProcessId = keyDataMsg.BackendProcessId;
conn._backendSecretKey = keyDataMsg.BackendSecretKey;
msg = await conn.ReadMessage(async);
}
if (msg.Code != BackendMessageCode.ReadyForQuery)
throw new NpgsqlException($"Received backend message {msg.Code} while expecting ReadyForQuery. Please file a bug.");
conn.State = ConnectorState.Ready;
}
}
/// <summary>
/// Sets up <see cref="TypeMapper"/>, then obtains the <see cref="NpgsqlDatabaseInfo"/> for this connector's
/// settings (from the process-wide cache, or by querying the database) and assigns it to
/// <see cref="DatabaseInfo"/> and to the type mapper.
/// </summary>
/// <param name="forceReload">If true, bypass the cache and always load from the database.</param>
/// <param name="timeout">Bounds both the wait for the shared loading semaphore and the load itself.</param>
/// <param name="async">Whether to perform I/O asynchronously.</param>
/// <param name="cancellationToken">A token to cancel waiting for the loading semaphore.</param>
internal async ValueTask LoadDatabaseInfo(bool forceReload, NpgsqlTimeout timeout, bool async,
    CancellationToken cancellationToken = default)
{
    // The type loading below will need to send queries to the database, and that depends on a type mapper being set up (even if its
    // empty). So we set up here, and then later inject the DatabaseInfo.
    // For multiplexing connectors, the type mapper is the shared pool-wide one (since when validating/binding parameters on
    // multiplexing there's no connector yet). However, in the very first multiplexing connection (bootstrap phase) we create
    // a connector-specific mapper, which will later become shared pool-wide one.
    TypeMapper =
        Settings.Multiplexing && ((MultiplexingConnectorPool)_connectorSource).MultiplexingTypeMapper is { } multiplexingTypeMapper
            ? multiplexingTypeMapper
            : new ConnectorTypeMapper(this);

    var key = new NpgsqlDatabaseInfoCacheKey(Settings);
    if (forceReload || !NpgsqlDatabaseInfo.Cache.TryGetValue(key, out var database))
    {
        var hasSemaphore = async
            ? await DatabaseInfoSemaphore.WaitAsync(timeout.CheckAndGetTimeLeft(), cancellationToken)
            : DatabaseInfoSemaphore.Wait(timeout.CheckAndGetTimeLeft(), cancellationToken);

        // We've timed out - calling Check, to throw the correct exception
        if (!hasSemaphore)
            timeout.Check();

        try
        {
            // Double-checked: another connection may have populated the cache while we waited.
            if (forceReload || !NpgsqlDatabaseInfo.Cache.TryGetValue(key, out database))
            {
                NpgsqlDatabaseInfo.Cache[key] = database = await NpgsqlDatabaseInfo.Load(this, timeout, async);
            }
        }
        finally
        {
            // Only release what was actually acquired. The semaphore has no maximum count, so an unpaired Release
            // would not throw; it would permanently raise the count and let loads run in parallel.
            if (hasSemaphore)
                DatabaseInfoSemaphore.Release();
        }
    }

    DatabaseInfo = database;
    TypeMapper.DatabaseInfo = database;
}
/// <summary>
/// Queries whether the server is a hot standby (<c>pg_is_in_recovery()</c>) and whether
/// <c>default_transaction_read_only</c> is enabled, records both on the connector, and returns the
/// resulting cluster state.
/// </summary>
/// <param name="timeout">The remaining time budget; converted to the batch timeout in whole seconds.</param>
/// <param name="async">Whether to perform I/O asynchronously.</param>
/// <param name="cancellationToken">A token to cancel the queries.</param>
/// <returns>The cluster state computed by UpdateClusterState from the two query results.</returns>
internal async ValueTask<ClusterState> QueryClusterState(
    NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken = default)
{
    using var batch = CreateBatch();
    batch.BatchCommands.Add(new NpgsqlBatchCommand("select pg_is_in_recovery()"));
    batch.BatchCommands.Add(new NpgsqlBatchCommand("SHOW default_transaction_read_only"));

    // Round the remaining time *up* to whole seconds. Truncation turned anything below one second into 0, which
    // for ADO.NET command timeouts means "no limit". For zero, negative and whole-second values, Math.Ceiling
    // yields the same integer as the previous truncation.
    batch.Timeout = (int)Math.Ceiling(timeout.CheckAndGetTimeLeft().TotalSeconds);

    var reader = async ? await batch.ExecuteReaderAsync(cancellationToken) : batch.ExecuteReader();
    try
    {
        if (async)
        {
            await reader.ReadAsync(cancellationToken);
            _isHotStandBy = reader.GetBoolean(0);
            await reader.NextResultAsync(cancellationToken);
            await reader.ReadAsync(cancellationToken);
        }
        else
        {
            reader.Read();
            _isHotStandBy = reader.GetBoolean(0);
            reader.NextResult();
            reader.Read();
        }

        // Any value other than exactly "off" is treated as read-only.
        _isTransactionReadOnly = reader.GetString(0) != "off";
        var clusterState = UpdateClusterState();
        Debug.Assert(clusterState.HasValue);
        return clusterState.Value;
    }
    finally
    {
        if (async)
            await reader.DisposeAsync();
        else
            reader.Dispose();
    }
}
/// <summary>
/// Builds the startup parameters for this connection and writes the StartupMessage into the write buffer.
/// </summary>
/// <param name="username">The user name sent as the "user" parameter.</param>
void WriteStartupMessage(string username)
{
    var clientEncoding = Settings.ClientEncoding ?? PostgresEnvironment.ClientEncoding ?? "UTF8";

    var startupParams = new Dictionary<string, string>
    {
        ["user"] = username,
        ["client_encoding"] = clientEncoding,
        ["database"] = Settings.Database!
    };

    // Optional parameters are only sent when non-empty; TimeZone is sent for any non-null value.
    if (Settings.ApplicationName is { Length: > 0 } applicationName)
        startupParams["application_name"] = applicationName;

    if (Settings.SearchPath is { Length: > 0 } searchPath)
        startupParams["search_path"] = searchPath;

    if ((Settings.Timezone ?? PostgresEnvironment.TimeZone) is { } timezone)
        startupParams["TimeZone"] = timezone;

    if ((Settings.Options ?? PostgresEnvironment.Options) is { Length: > 0 } options)
        startupParams["options"] = options;

    var replication = Settings.ReplicationMode switch
    {
        ReplicationMode.Logical => "database",
        ReplicationMode.Physical => "true",
        _ => null
    };
    if (replication is not null)
        startupParams["replication"] = replication;

    WriteStartup(startupParams);
}
/// <summary>
/// Determines the user name to connect with. The first non-empty value wins, checked in this order:
/// the Username setting, PostgresEnvironment.User, the Kerberos user name (non-Windows only), and the OS
/// user name.
/// </summary>
/// <returns>A non-empty user name.</returns>
/// <exception cref="NpgsqlException">None of the sources yielded a user name.</exception>
string GetUsername()
{
    if (Settings.Username is { Length: > 0 } configured)
        return configured;

    if (PostgresEnvironment.User is { Length: > 0 } fromEnvironment)
        return fromEnvironment;

    // The Kerberos lookup is skipped entirely on Windows.
    if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows) &&
        KerberosUsernameProvider.GetUsername(Settings.IncludeRealm) is { Length: > 0 } fromKerberos)
        return fromKerberos;

    if (Environment.UserName is { Length: > 0 } fromOperatingSystem)
        return fromOperatingSystem;

    throw new NpgsqlException("No username could be found, please specify one explicitly");
}
/// <summary>
/// Establishes the transport for this connector: connects the socket, sets up the text encodings and the
/// read/write buffers, and - depending on <paramref name="sslMode"/> - negotiates SSL/TLS with the server.
/// </summary>
/// <param name="sslMode">The SSL mode for this attempt (on a retry it may differ from the configured mode).</param>
/// <param name="timeout">The overall timeout, checked and applied between steps.</param>
/// <param name="async">Whether to perform I/O asynchronously.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <param name="isFirstAttempt">
/// False on a retry. This skips the TrustServerCertificate requirement for SslMode.Require, which is how a
/// retry after a failed Allow attempt negotiates SSL.
/// </param>
/// <remarks>
/// On any failure, the client certificate, streams and socket created so far are disposed and cleared before
/// the exception is rethrown.
/// </remarks>
async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken, bool isFirstAttempt = true)
{
    try
    {
        if (async)
            await ConnectAsync(timeout, cancellationToken);
        else
            Connect(timeout);

        _baseStream = new NetworkStream(_socket, true);
        _stream = _baseStream;

        if (Settings.Encoding == "UTF8")
        {
            TextEncoding = PGUtil.UTF8Encoding;
            RelaxedTextEncoding = PGUtil.RelaxedUTF8Encoding;
        }
        else
        {
            TextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ExceptionFallback, DecoderFallback.ExceptionFallback);
            RelaxedTextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ReplacementFallback, DecoderFallback.ReplacementFallback);
        }

        ReadBuffer = new NpgsqlReadBuffer(this, _stream, _socket, Settings.ReadBufferSize, TextEncoding, RelaxedTextEncoding);
        WriteBuffer = new NpgsqlWriteBuffer(this, _stream, _socket, Settings.WriteBufferSize, TextEncoding);

        timeout.CheckAndApply(this);

        IsSecure = false;

        // These modes send an SSLRequest first; any other mode (e.g. Disable, Allow) starts unencrypted.
        if (sslMode is SslMode.Prefer or SslMode.Require or SslMode.VerifyCA or SslMode.VerifyFull)
        {
            WriteSslRequest();
            await Flush(async, cancellationToken);

            // The server answers the SSLRequest with a single byte: 'S' (proceed with SSL) or 'N' (no SSL).
            await ReadBuffer.Ensure(1, async);
            var response = (char)ReadBuffer.ReadByte();
            timeout.CheckAndApply(this);

            switch (response)
            {
            default:
                throw new NpgsqlException($"Received unknown response {response} for SSLRequest (expecting S or N)");

            case 'N':
                // Only Prefer may silently continue without SSL.
                if (sslMode != SslMode.Prefer)
                    throw new NpgsqlException("SSL connection requested. No SSL enabled connection from this host is configured.");
                break;

            case 'S':
                var clientCertificates = new X509Certificate2Collection();
                var certPath = Settings.SslCertificate ?? PostgresEnvironment.SslCert ?? PostgresEnvironment.SslCertDefault;
                if (certPath != null)
                {
                    var password = Settings.SslPassword;

                    // Anything that isn't a .pfx file is loaded as a PEM certificate with a separate key file.
                    if (Path.GetExtension(certPath).ToUpperInvariant() != ".PFX")
                    {
#if NET5_0_OR_GREATER
                        // It's PEM time
                        var keyPath = Settings.SslKey ?? PostgresEnvironment.SslKey ?? PostgresEnvironment.SslKeyDefault;
                        _certificate = string.IsNullOrEmpty(password)
                            ? X509Certificate2.CreateFromPemFile(certPath, keyPath)
                            : X509Certificate2.CreateFromEncryptedPemFile(certPath, password, keyPath);
                        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
                        {
                            // Windows crypto API has a bug with pem certs
                            // See #3650
                            using var previousCert = _certificate;
                            _certificate = new X509Certificate2(_certificate.Export(X509ContentType.Pkcs12));
                        }
#else
                        throw new NotSupportedException("PEM certificates are only supported with .NET 5 and higher");
#endif
                    }

                    if (_certificate is null)
                        _certificate = new X509Certificate2(certPath, password);

                    clientCertificates.Add(_certificate);
                }

                ProvideClientCertificatesCallback?.Invoke(clientCertificates);

                // Choose how the server certificate is validated. In priority order: the user's callback (not allowed
                // with VerifyCA/VerifyFull or a root certificate); trust-the-server for Prefer/Require (revocation
                // checks off); a configured root certificate; otherwise VerifyCA or VerifyFull validation.
                var checkCertificateRevocation = Settings.CheckCertificateRevocation;

                RemoteCertificateValidationCallback? certificateValidationCallback;
                if (UserCertificateValidationCallback is not null)
                {
                    if (sslMode is SslMode.VerifyCA or SslMode.VerifyFull)
                        throw new ArgumentException(string.Format(NpgsqlStrings.CannotUseSslVerifyWithUserCallback, sslMode));

                    if (Settings.RootCertificate is not null)
                        throw new ArgumentException(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback);

                    certificateValidationCallback = UserCertificateValidationCallback;
                }
                else if (sslMode is SslMode.Prefer or SslMode.Require)
                {
                    if (isFirstAttempt && sslMode is SslMode.Require && !Settings.TrustServerCertificate)
                        throw new ArgumentException(NpgsqlStrings.CannotUseSslModeRequireWithoutTrustServerCertificate);

                    certificateValidationCallback = SslTrustServerValidation;
                    checkCertificateRevocation = false;
                }
                else if ((Settings.RootCertificate ?? PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is
                         { } certRootPath)
                {
                    certificateValidationCallback = SslRootValidation(certRootPath, sslMode == SslMode.VerifyFull);
                }
                else if (sslMode == SslMode.VerifyCA)
                {
                    certificateValidationCallback = SslVerifyCAValidation;
                }
                else
                {
                    Debug.Assert(sslMode == SslMode.VerifyFull);
                    certificateValidationCallback = SslVerifyFullValidation;
                }

                timeout.CheckAndApply(this);

                // Declared outside the try so that a failed handshake can dispose it in the catch below.
                SslStream? sslStream = null;
                try
                {
                    sslStream = new SslStream(_stream, leaveInnerStreamOpen: false, certificateValidationCallback);

                    var sslProtocols = SslProtocols.None;
                    // On .NET Framework SslProtocols.None can be disabled, see #3718
#if NETSTANDARD2_0
                    sslProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;
#endif

                    if (async)
                        await sslStream.AuthenticateAsClientAsync(Host, clientCertificates,
                            sslProtocols, checkCertificateRevocation);
                    else
                        sslStream.AuthenticateAsClient(Host, clientCertificates,
                            sslProtocols, checkCertificateRevocation);

                    _stream = sslStream;
                }
                catch (Exception e)
                {
                    // A failed SslStream is never published to _stream, so the cleanup in the outer catch can't
                    // reach it. Dispose it here so its TLS state doesn't linger until finalization. Because it
                    // doesn't leave the inner stream open, this also closes the network stream; the outer cleanup
                    // disposing that stream again is harmless.
                    sslStream?.Dispose();
                    throw new NpgsqlException("Exception while performing SSL handshake", e);
                }

                ReadBuffer.Underlying = _stream;
                WriteBuffer.Underlying = _stream;
                IsSecure = true;
                Log.Trace("SSL negotiation successful");
                break;
            }

            // Anything beyond the single response byte arrived before encryption was established.
            if (ReadBuffer.ReadBytesLeft > 0)
                throw new NpgsqlException("Additional unencrypted data received after SSL negotiation - this should never happen, and may be an indication of a man-in-the-middle attack.");
        }

        Log.Trace($"Socket connected to {Host}:{Port}");
    }
    catch
    {
        _certificate?.Dispose();
        _certificate = null;

        _stream?.Dispose();
        _stream = null!;

        _baseStream?.Dispose();
        _baseStream = null!;

        _socket?.Dispose();
        _socket = null!;

        throw;
    }
}
/// <summary>
/// Synchronously connects the socket to the server, trying each resolved endpoint in order until one succeeds.
/// </summary>
/// <remarks>
/// DNS resolution is neither cancellable nor subject to the timeout. The remaining time is then split equally
/// between the endpoints. Each attempt uses a non-blocking connect followed by Socket.Select, so that it is
/// bounded by its share.
/// </remarks>
/// <param name="timeout">The overall connection timeout.</param>
/// <exception cref="NpgsqlException">
/// The host resolved to no addresses, or the attempt on the last endpoint failed. Failures on earlier endpoints
/// are only logged.
/// </exception>
void Connect(NpgsqlTimeout timeout)
{
    // Note that there aren't any timeout-able or cancellable DNS methods
    var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)
        ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) }
        : Dns.GetHostAddresses(Host).Select(a => new IPEndPoint(a, Port)).ToArray();

    timeout.Check();

    // Without this guard, an empty resolution result would divide by zero below (when a timeout is set), or skip
    // the loop and return without ever assigning _socket.
    if (endpoints.Length == 0)
        throw new NpgsqlException($"Failed to connect to {Host}: no IP addresses could be resolved");

    // Give each endpoint an equal share of the remaining time.
    // Socket.Select takes microseconds; TimeSpan ticks are 100ns, hence the division by 10.
    var perEndpointTimeout = -1; // Default to infinity
    if (timeout.IsSet)
        perEndpointTimeout = (int)(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length / 10);

    for (var i = 0; i < endpoints.Length; i++)
    {
        var endpoint = endpoints[i];
        Log.Trace($"Attempting to connect to {endpoint}");

        var protocolType =
            endpoint.AddressFamily == AddressFamily.InterNetwork ||
            endpoint.AddressFamily == AddressFamily.InterNetworkV6
                ? ProtocolType.Tcp
                : ProtocolType.IP;

        var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType)
        {
            Blocking = false
        };

        try
        {
            try
            {
                socket.Connect(endpoint);
            }
            catch (SocketException e)
            {
                // A non-blocking connect reports "in progress" as WouldBlock; anything else is a real failure.
                if (e.SocketErrorCode != SocketError.WouldBlock)
                    throw;
            }

            // Wait until the socket is writable (connected) or reports an error, up to this endpoint's share.
            var write = new List<Socket> { socket };
            var error = new List<Socket> { socket };
            Socket.Select(null, write, error, perEndpointTimeout);

            var errorCode = (int)socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.Error)!;
            if (errorCode != 0)
                throw new SocketException(errorCode);

            // Select removes sockets that aren't ready from the lists; an empty write list means we timed out.
            if (write.Count == 0)
                throw new TimeoutException("Timeout during connection attempt");

            socket.Blocking = true;
            SetSocketOptions(socket);
            _socket = socket;
            ConnectedEndPoint = endpoint;
            return;
        }
        catch (Exception e)
        {
            try { socket.Dispose(); }
            catch
            {
                // ignored
            }

            Log.Trace($"Failed to connect to {endpoint}", e);

            if (i == endpoints.Length - 1)
                throw new NpgsqlException($"Failed to connect to {endpoint}", e);
        }
    }
}
/// <summary>
/// Asynchronously connects the socket to the server, trying each resolved endpoint in order until one succeeds.
/// </summary>
/// <remarks>
/// The time remaining after DNS resolution is split equally between the endpoints. A timed-out attempt is
/// reported as a TimeoutException (wrapped in the final NpgsqlException). Cancellation through
/// <paramref name="cancellationToken"/> propagates as an OperationCanceledException.
/// </remarks>
/// <param name="timeout">The overall connection timeout.</param>
/// <param name="cancellationToken">A token to cancel the connection attempt.</param>
/// <exception cref="NpgsqlException">
/// The host resolved to no addresses, or the attempt on the last endpoint failed.
/// </exception>
async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken)
{
    // DNS resolution honours the timeout and cancellation token: natively on .NET 6+, otherwise by abandoning
    // the lookup (see GetHostAddressesAsync below).
    var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)
        ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) }
        : (await GetHostAddressesAsync(timeout, cancellationToken))
            .Select(a => new IPEndPoint(a, Port)).ToArray();

    // Without this guard, an empty resolution result would divide by zero when splitting the timeout below.
    if (endpoints.Length == 0)
        throw new NpgsqlException($"Failed to connect to {Host}: no IP addresses could be resolved");

    // Give each IP an equal share of the remaining time
    var perIpTimeout = timeout;
    if (timeout.IsSet)
        perIpTimeout = new NpgsqlTimeout(new TimeSpan(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length));

    for (var i = 0; i < endpoints.Length; i++)
    {
        var endpoint = endpoints[i];
        Log.Trace($"Attempting to connect to {endpoint}");

        var protocolType =
            endpoint.AddressFamily == AddressFamily.InterNetwork ||
            endpoint.AddressFamily == AddressFamily.InterNetworkV6
                ? ProtocolType.Tcp
                : ProtocolType.IP;

        var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType);
        try
        {
            await OpenSocketConnectionAsync(socket, endpoint, perIpTimeout, cancellationToken);
            SetSocketOptions(socket);
            _socket = socket;
            ConnectedEndPoint = endpoint;
            return;
        }
        catch (Exception e)
        {
            try
            {
                socket.Dispose();
            }
            catch
            {
                // ignored
            }

            // User cancellation aborts the whole connect; a cancellation that wasn't requested by the user
            // is the per-endpoint timeout firing.
            cancellationToken.ThrowIfCancellationRequested();

            if (e is OperationCanceledException)
                e = new TimeoutException("Timeout during connection attempt");

            Log.Trace($"Failed to connect to {endpoint}", e);

            if (i == endpoints.Length - 1)
                throw new NpgsqlException($"Failed to connect to {endpoint}", e);
        }
    }

    // Resolves Host, honouring the timeout and cancellation token.
    Task<IPAddress[]> GetHostAddressesAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken)
    {
        // .NET 6.0 added cancellation support to GetHostAddressesAsync, which allows us to implement real
        // cancellation and timeout. On older TFMs, we fake-cancel the operation, i.e. stop waiting
        // and raise the exception, but the actual connection task is left running.
#if NET6_0_OR_GREATER
        var task = TaskExtensions.ExecuteWithTimeout(
            ct => Dns.GetHostAddressesAsync(Host, ct),
            timeout, cancellationToken);
#else
        var task = Dns.GetHostAddressesAsync(Host);
#endif

        // As the cancellation support of GetHostAddressesAsync is not guaranteed on all platforms
        // we apply the fake-cancel mechanism in all cases.
        return task.WithCancellationAndTimeout(timeout, cancellationToken);
    }

    // Connects a single socket to a single endpoint within the given per-IP timeout.
    static Task OpenSocketConnectionAsync(Socket socket, EndPoint endpoint, NpgsqlTimeout perIpTimeout, CancellationToken cancellationToken)
    {
        // .NET 5.0 added cancellation support to ConnectAsync, which allows us to implement real
        // cancellation and timeout. On older TFMs, we fake-cancel the operation, i.e. stop waiting
        // and raise the exception, but the actual connection task is left running.
#if NET5_0_OR_GREATER
        return TaskExtensions.ExecuteWithTimeout(
            ct => socket.ConnectAsync(endpoint, ct).AsTask(),
            perIpTimeout, cancellationToken);
#else
        return socket.ConnectAsync(endpoint)
            .WithCancellationAndTimeout(perIpTimeout, cancellationToken);
#endif
    }
}
/// <summary>
/// Applies the socket-level options from the connection string to a freshly connected socket:
/// Nagle disabled for TCP, buffer sizes, and TCP keepalive (including time/interval when configured).
/// </summary>
/// <param name="socket">The connected socket to configure.</param>
/// <exception cref="ArgumentException">TcpKeepAliveInterval is set without TcpKeepAliveTime.</exception>
void SetSocketOptions(Socket socket)
{
    if (socket.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6)
        socket.NoDelay = true;

    var receiveBufferSize = Settings.SocketReceiveBufferSize;
    if (receiveBufferSize > 0)
        socket.ReceiveBufferSize = receiveBufferSize;

    var sendBufferSize = Settings.SocketSendBufferSize;
    if (sendBufferSize > 0)
        socket.SendBufferSize = sendBufferSize;

    if (Settings.TcpKeepAlive)
        socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);

    var keepAliveTime = Settings.TcpKeepAliveTime;
    var keepAliveInterval = Settings.TcpKeepAliveInterval;

    if (keepAliveInterval > 0 && keepAliveTime == 0)
        throw new ArgumentException("If TcpKeepAliveInterval is defined, TcpKeepAliveTime must be defined as well");

    if (keepAliveTime <= 0)
        return;

    // Both values are configured in seconds; the interval defaults to the keepalive time.
    var timeSeconds = keepAliveTime;
    var intervalSeconds = keepAliveInterval > 0 ? keepAliveInterval : keepAliveTime;

#if NETSTANDARD2_0 || NETSTANDARD2_1
    // KeepAliveValues expects three uints: the on/off flag, the time (ms) and the interval (ms).
    // For the following see https://msdn.microsoft.com/en-us/library/dd877220.aspx
    const int fieldSize = sizeof(uint);
    var keepAliveValues = new byte[fieldSize * 3];
    BitConverter.GetBytes((uint)1).CopyTo(keepAliveValues, 0);
    BitConverter.GetBytes((uint)(timeSeconds * 1000)).CopyTo(keepAliveValues, fieldSize);
    BitConverter.GetBytes((uint)(intervalSeconds * 1000)).CopyTo(keepAliveValues, fieldSize * 2);

    int result;
    try
    {
        result = socket.IOControl(IOControlCode.KeepAliveValues, keepAliveValues, null);
    }
    catch (PlatformNotSupportedException)
    {
        throw new PlatformNotSupportedException("Setting TCP Keepalive Time and TCP Keepalive Interval is supported only on Windows, Mono and .NET Core 3.1+. " +
            "TCP keepalives can still be used on other systems but are enabled via the TcpKeepAlive option or configured globally for the machine, see the relevant docs.");
    }

    if (result != 0)
        throw new NpgsqlException($"Got non-zero value when trying to set TCP keepalive: {result}");
#else
    socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
    socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, timeSeconds);
    socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, intervalSeconds);
#endif
}
#endregion
#region I/O
// Reader and writer of the in-flight commands channel. Both are created in the constructor only when
// multiplexing is enabled, and are null otherwise.
readonly ChannelReader? CommandsInFlightReader;
internal readonly ChannelWriter? CommandsInFlightWriter;
// Number of commands currently in flight (multiplexing); FlagAsNotWritableForMultiplexing asserts it is
// positive unless the connector is broken or closed.
internal volatile int CommandsInFlightCount;
internal ManualResetValueTaskSource