Kit.Core/LibExternal/Npgsql/Internal/NpgsqlConnector.FrontendMes...

469 lines
19 KiB
C#

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Util;
// ReSharper disable VariableHidesOuterVariable
namespace Npgsql.Internal;
partial class NpgsqlConnector
{
/// <summary>
/// Writes a Describe message for the given prepared statement or portal.
/// If the write buffer cannot hold the whole message, it is flushed first.
/// </summary>
/// <param name="statementOrPortal">Whether <paramref name="name"/> refers to a statement or a portal.</param>
/// <param name="name">The statement/portal name; expected to be ASCII (asserted in debug builds).</param>
/// <returns>A completed task when no flush was needed; otherwise the flush-then-write task.</returns>
internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bool async, CancellationToken cancellationToken = default)
{
    Debug.Assert(name.All(c => c < 128));

    // name.Length is used as a byte count, which relies on the name being ASCII (one byte per char).
    var messageLength = sizeof(byte)      // Message code
                      + sizeof(int)       // Length
                      + sizeof(byte)      // Statement or portal
                      + name.Length + 1;  // Statement/portal name plus null terminator

    if (WriteBuffer.WriteSpaceLeft >= messageLength)
    {
        WriteMessage(messageLength, statementOrPortal, name);
        return Task.CompletedTask;
    }

    return FlushThenWrite(messageLength, statementOrPortal, name, async, cancellationToken);

    async Task FlushThenWrite(int length, StatementOrPortal kind, string target, bool async, CancellationToken cancellationToken)
    {
        await Flush(async, cancellationToken);
        Debug.Assert(length <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {length} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})");
        WriteMessage(length, kind, target);
    }

    void WriteMessage(int length, StatementOrPortal kind, string target)
    {
        WriteBuffer.WriteByte(FrontendMessageCode.Describe);
        WriteBuffer.WriteInt32(length - 1); // Length field excludes the message code byte
        WriteBuffer.WriteByte((byte)kind);
        WriteBuffer.WriteNullTerminatedString(target);
    }
}
/// <summary>
/// Writes a Sync message (message code and length only, no payload),
/// flushing the write buffer first if it lacks room for it.
/// </summary>
/// <returns>A completed task when no flush was needed; otherwise the flush-then-write task.</returns>
internal Task WriteSync(bool async, CancellationToken cancellationToken = default)
{
    const int messageLength = sizeof(byte)  // Message code
                            + sizeof(int);  // Length

    if (WriteBuffer.WriteSpaceLeft >= messageLength)
    {
        WriteMessage();
        return Task.CompletedTask;
    }

    return FlushThenWrite(async, cancellationToken);

    async Task FlushThenWrite(bool async, CancellationToken cancellationToken)
    {
        await Flush(async, cancellationToken);
        Debug.Assert(messageLength <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {messageLength} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})");
        WriteMessage();
    }

    void WriteMessage()
    {
        WriteBuffer.WriteByte(FrontendMessageCode.Sync);
        WriteBuffer.WriteInt32(messageLength - 1); // Length field excludes the message code byte
    }
}
/// <summary>
/// Writes an Execute message for the unnamed (empty-named) portal,
/// flushing the write buffer first if it lacks room for the whole message.
/// </summary>
/// <param name="maxRows">Row limit, written into the message as-is.</param>
/// <returns>A completed task when no flush was needed; otherwise the flush-then-write task.</returns>
internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellationToken = default)
{
    // Note: non-empty portal currently not supported
    const int len = sizeof(byte) + // Message code
                    sizeof(int) +  // Length
                    sizeof(byte) + // Null-terminated portal name (always empty for now)
                    sizeof(int);   // Max number of rows
    if (WriteBuffer.WriteSpaceLeft < len)
        return FlushAndWrite(maxRows, async, cancellationToken);
    Write(maxRows);
    return Task.CompletedTask;

    async Task FlushAndWrite(int maxRows, bool async, CancellationToken cancellationToken)
    {
        await Flush(async, cancellationToken);
        // Assert against the computed length (as the sibling Write* methods do) so the check
        // stays in sync with the message layout above instead of relying on a duplicated literal.
        Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})");
        Write(maxRows);
    }

    void Write(int maxRows)
    {
        WriteBuffer.WriteByte(FrontendMessageCode.Execute);
        WriteBuffer.WriteInt32(len - 1); // Length field excludes the message code byte
        WriteBuffer.WriteByte(0); // Portal is always empty for now
        WriteBuffer.WriteInt32(maxRows);
    }
}
/// <summary>
/// Writes a Parse message carrying <paramref name="sql"/> under <paramref name="statementName"/>,
/// followed by the type OID of each input parameter.
/// </summary>
/// <param name="sql">Query text; its byte length is measured with <c>TextEncoding</c>.</param>
/// <param name="statementName">Prepared statement name; expected to be ASCII (asserted in debug builds).</param>
/// <param name="inputParameters">Parameters whose handler's PostgreSQL type OID is sent, in order.</param>
/// <remarks>
/// If measuring the SQL's encoded length throws, the connector is broken via <c>Break</c> before rethrowing.
/// </remarks>
internal async Task WriteParse(string sql, string statementName, List<NpgsqlParameter> inputParameters, bool async, CancellationToken cancellationToken = default)
{
    Debug.Assert(statementName.All(c => c < 128));

    int sqlByteCount;
    try
    {
        sqlByteCount = TextEncoding.GetByteCount(sql);
    }
    catch (Exception ex)
    {
        Break(ex);
        throw;
    }

    // Code, length and null-terminated statement name are written below without further
    // space checks, so make room for all of them first.
    var headerSize = sizeof(byte) + sizeof(int) + statementName.Length + sizeof(byte);
    if (WriteBuffer.WriteSpaceLeft < headerSize)
    {
        await Flush(async, cancellationToken);
    }

    var messageLength = headerSize
                      + sqlByteCount + sizeof(byte)            // SQL query bytes plus null terminator
                      + sizeof(ushort)                         // Number of parameters
                      + inputParameters.Count * sizeof(int);   // Parameter OIDs

    WriteBuffer.WriteByte(FrontendMessageCode.Parse);
    WriteBuffer.WriteInt32(messageLength - 1); // Length field excludes the message code byte
    WriteBuffer.WriteNullTerminatedString(statementName);

    await WriteBuffer.WriteString(sql, sqlByteCount, async, cancellationToken);

    // Query null terminator (1 byte) + parameter count (2 bytes).
    if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(ushort))
    {
        await Flush(async, cancellationToken);
    }
    WriteBuffer.WriteByte(0);
    WriteBuffer.WriteUInt16((ushort)inputParameters.Count);

    foreach (var parameter in inputParameters)
    {
        if (WriteBuffer.WriteSpaceLeft < sizeof(int))
        {
            await Flush(async, cancellationToken);
        }
        WriteBuffer.WriteInt32((int)parameter.Handler!.PostgresType.OID);
    }
}
/// <summary>
/// Writes a Bind message that binds <paramref name="parameters"/> to the prepared statement
/// <paramref name="statement"/> using the unnamed portal, followed by the result format codes.
/// </summary>
/// <param name="parameters">Parameters to bind. Each one's length is computed via
/// <c>ValidateAndGetLength</c> before any byte of the message is written.</param>
/// <param name="portal">Must be empty (asserted in debug builds); only the unnamed portal is written.</param>
/// <param name="statement">Prepared statement name; expected to be ASCII, since its character count
/// is used as its byte count.</param>
/// <param name="allResultTypesAreUnknown">Used only when <paramref name="unknownResultTypeList"/> is null:
/// a single result format code is sent, 0 when this is true and 1 otherwise.</param>
/// <param name="unknownResultTypeList">Optional per-column flags; each entry produces result format
/// code 0 when true and 1 otherwise.</param>
internal async Task WriteBind(
List<NpgsqlParameter> parameters,
string portal,
string statement,
bool allResultTypesAreUnknown,
bool[]? unknownResultTypeList,
bool async,
CancellationToken cancellationToken = default)
{
Debug.Assert(statement.All(c => c < 128));
Debug.Assert(portal.All(c => c < 128));
// Leading fixed part of the message. These fields are written below without any further
// space checks, so room for all of them is made up front.
var headerLength =
sizeof(byte) + // Message code
sizeof(int) + // Message length
sizeof(byte) + // Portal is always empty (only a null terminator)
statement.Length + sizeof(byte) + // Statement name plus null terminator
sizeof(ushort); // Number of parameter format codes that follow
if (WriteBuffer.WriteSpaceLeft < headerLength)
{
Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
await Flush(async, cancellationToken);
}
// First pass over the parameters: sum their format codes and their total value length, both of
// which are needed for the message length field before anything is written. Each LengthCache
// is rewound before measuring, and rewound again before the write pass further down.
var formatCodesSum = 0;
var paramsLength = 0;
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
{
var param = parameters[paramIndex];
formatCodesSum += (int)param.FormatCode;
param.LengthCache?.Rewind();
paramsLength += param.ValidateAndGetLength();
}
// Relies on format codes being 0 (text) or 1 (binary): a sum of 0 means every parameter is text
// (send no codes), a sum equal to the count means every parameter is binary (send one code),
// anything else means a mix (send one code per parameter).
var formatCodeListLength = formatCodesSum == 0 ? 0 : formatCodesSum == parameters.Count ? 1 : parameters.Count;
var messageLength = headerLength +
sizeof(short) * formatCodeListLength + // List of format codes
sizeof(short) + // Number of parameters
sizeof(int) * parameters.Count + // Parameter lengths
paramsLength + // Parameter values
sizeof(short) + // Number of result format codes
sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes
// Header fields; space for these was ensured above. The length field excludes the message code byte.
WriteBuffer.WriteByte(FrontendMessageCode.Bind);
WriteBuffer.WriteInt32(messageLength - 1);
Debug.Assert(portal == string.Empty);
WriteBuffer.WriteByte(0); // Portal is always empty
WriteBuffer.WriteNullTerminatedString(statement);
WriteBuffer.WriteInt16(formatCodeListLength);
// 0 length implicitly means all-text, 1 means all-binary, >1 means mix-and-match
if (formatCodeListLength == 1)
{
if (WriteBuffer.WriteSpaceLeft < 2)
await Flush(async, cancellationToken);
WriteBuffer.WriteInt16((short)FormatCode.Binary);
}
else if (formatCodeListLength > 1)
{
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
{
if (WriteBuffer.WriteSpaceLeft < 2)
await Flush(async, cancellationToken);
WriteBuffer.WriteInt16((short)parameters[paramIndex].FormatCode);
}
}
if (WriteBuffer.WriteSpaceLeft < 2)
await Flush(async, cancellationToken);
WriteBuffer.WriteUInt16((ushort)parameters.Count);
// Second pass: write each parameter's length and value, rewinding its LengthCache again first.
for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
{
var param = parameters[paramIndex];
param.LengthCache?.Rewind();
await param.WriteWithLength(WriteBuffer, async, cancellationToken);
}
// Result format codes: one per entry when a list is given, otherwise a single code.
if (unknownResultTypeList != null)
{
// NOTE(review): a single flush only makes room if 2 + 2 * Length fits in an empty buffer;
// presumably result column counts keep this small — confirm.
if (WriteBuffer.WriteSpaceLeft < 2 + unknownResultTypeList.Length * 2)
await Flush(async, cancellationToken);
WriteBuffer.WriteInt16(unknownResultTypeList.Length);
foreach (var t in unknownResultTypeList)
WriteBuffer.WriteInt16(t ? 0 : 1);
}
else
{
if (WriteBuffer.WriteSpaceLeft < 4)
await Flush(async, cancellationToken);
WriteBuffer.WriteInt16(1);
WriteBuffer.WriteInt16(allResultTypesAreUnknown ? 0 : 1);
}
}
/// <summary>
/// Writes a Close message for the named prepared statement or portal,
/// flushing the write buffer first if it lacks room for the whole message.
/// </summary>
/// <param name="type">Whether <paramref name="name"/> refers to a statement or a portal.</param>
/// <param name="name">The statement/portal name; expected to be ASCII because its character
/// count is used as its byte count (asserted in debug builds).</param>
/// <returns>A completed task when no flush was needed; otherwise the flush-then-write task.</returns>
internal Task WriteClose(StatementOrPortal type, string name, bool async, CancellationToken cancellationToken = default)
{
    // name.Length is used as a byte count below, which only holds for ASCII names; this is the
    // same invariant WriteDescribe, WriteParse and WriteBind assert for their names.
    Debug.Assert(name.All(c => c < 128));
    var len = sizeof(byte) + // Message code
              sizeof(int) +  // Length
              sizeof(byte) + // Statement or portal
              name.Length + sizeof(byte); // Statement or portal name plus null terminator
    if (WriteBuffer.WriteSpaceLeft < len)
        return FlushAndWrite(len, type, name, async, cancellationToken);
    Write(len, type, name);
    return Task.CompletedTask;

    async Task FlushAndWrite(int len, StatementOrPortal type, string name, bool async, CancellationToken cancellationToken)
    {
        await Flush(async, cancellationToken);
        Debug.Assert(len <= WriteBuffer.WriteSpaceLeft, $"Message of type {GetType().Name} has length {len} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})");
        Write(len, type, name);
    }

    void Write(int len, StatementOrPortal type, string name)
    {
        WriteBuffer.WriteByte(FrontendMessageCode.Close);
        WriteBuffer.WriteInt32(len - 1); // Length field excludes the message code byte
        WriteBuffer.WriteByte((byte)type);
        WriteBuffer.WriteNullTerminatedString(name);
    }
}
/// <summary>Synchronously writes a Query message by blocking on the async overload with <c>async: false</c>.</summary>
internal void WriteQuery(string sql)
{
    WriteQuery(sql, false).GetAwaiter().GetResult();
}

/// <summary>
/// Writes a Query message containing <paramref name="sql"/> followed by a null terminator.
/// The SQL byte length is measured with <c>TextEncoding</c>.
/// </summary>
internal async Task WriteQuery(string sql, bool async, CancellationToken cancellationToken = default)
{
    var sqlByteCount = TextEncoding.GetByteCount(sql);

    const int headerSize = sizeof(byte) + sizeof(int); // Message code + length
    if (WriteBuffer.WriteSpaceLeft < headerSize)
    {
        await Flush(async, cancellationToken);
    }

    // The length field counts itself, the query bytes and the terminator, but not the message code.
    var lengthField = sizeof(int) + sqlByteCount + sizeof(byte);
    WriteBuffer.WriteByte(FrontendMessageCode.Query);
    WriteBuffer.WriteInt32(lengthField);

    await WriteBuffer.WriteString(sql, sqlByteCount, async, cancellationToken);

    if (WriteBuffer.WriteSpaceLeft < sizeof(byte))
    {
        await Flush(async, cancellationToken);
    }
    WriteBuffer.WriteByte(0); // Null terminator
}
/// <summary>Synchronously writes a CopyDone message by blocking on the async overload with <c>async: false</c>.</summary>
internal void WriteCopyDone()
{
    WriteCopyDone(false).GetAwaiter().GetResult();
}

/// <summary>Writes a CopyDone message (code and length only), flushing first if the buffer lacks room.</summary>
internal async Task WriteCopyDone(bool async, CancellationToken cancellationToken = default)
{
    const int messageLength = sizeof(byte)  // Message code
                            + sizeof(int);  // Length
    if (messageLength > WriteBuffer.WriteSpaceLeft)
    {
        await Flush(async, cancellationToken);
    }

    WriteBuffer.WriteByte(FrontendMessageCode.CopyDone);
    WriteBuffer.WriteInt32(messageLength - 1); // Length field excludes the message code byte
}
/// <summary>
/// Writes a CopyFail message whose error-message field is always empty (only its null terminator),
/// flushing first if the buffer lacks room.
/// </summary>
internal async Task WriteCopyFail(bool async, CancellationToken cancellationToken = default)
{
    // Note: error message not supported for now
    const int messageLength = sizeof(byte)  // Message code
                            + sizeof(int)   // Length
                            + sizeof(byte); // Empty error message (null terminator only)
    if (messageLength > WriteBuffer.WriteSpaceLeft)
    {
        await Flush(async, cancellationToken);
    }

    WriteBuffer.WriteByte(FrontendMessageCode.CopyFail);
    WriteBuffer.WriteInt32(messageLength - 1); // Length field excludes the message code byte
    WriteBuffer.WriteByte(0); // Empty error message
}
/// <summary>
/// Writes a CancelRequest message for the given backend. This message has no message-code byte;
/// its length field counts itself. Flushes synchronously first if the buffer lacks room.
/// </summary>
/// <param name="backendProcessId">Backend process id; must be non-zero (asserted in debug builds).</param>
/// <param name="backendSecretKey">Backend secret key.</param>
internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
{
    // 1234 in the high 16 bits, 5678 in the low 16 bits.
    const int cancelRequestCode = 1234 << 16 | 5678;
    const int messageLength = 4 * sizeof(int); // Length, request code, process id, secret key

    Debug.Assert(backendProcessId != 0);

    if (messageLength > WriteBuffer.WriteSpaceLeft)
    {
        Flush(false).GetAwaiter().GetResult();
    }

    WriteBuffer.WriteInt32(messageLength);
    WriteBuffer.WriteInt32(cancelRequestCode);
    WriteBuffer.WriteInt32(backendProcessId);
    WriteBuffer.WriteInt32(backendSecretKey);
}
/// <summary>Writes a Terminate message (code and length only), synchronously flushing first if needed.</summary>
internal void WriteTerminate()
{
    const int messageLength = sizeof(byte)  // Message code
                            + sizeof(int);  // Length
    if (messageLength > WriteBuffer.WriteSpaceLeft)
    {
        Flush(false).GetAwaiter().GetResult();
    }

    WriteBuffer.WriteByte(FrontendMessageCode.Terminate);
    WriteBuffer.WriteInt32(messageLength - 1); // Length field excludes the message code byte
}
/// <summary>
/// Writes an SSLRequest message: a length field (counting itself) followed by the SSL request code,
/// with no message-code byte. Flushes synchronously first if the buffer lacks room.
/// </summary>
internal void WriteSslRequest()
{
    // 1234 in the high 16 bits, 5679 in the low 16 bits (= 80877103).
    const int sslRequestCode = 1234 << 16 | 5679;
    const int messageLength = sizeof(int)  // Length
                            + sizeof(int); // SSL request code
    if (messageLength > WriteBuffer.WriteSpaceLeft)
    {
        Flush(false).GetAwaiter().GetResult();
    }

    WriteBuffer.WriteInt32(messageLength);
    WriteBuffer.WriteInt32(sslRequestCode);
}
/// <summary>
/// Writes a StartupMessage: length, protocol version 3.0, each parameter as a null-terminated
/// key followed by a null-terminated value, and a final zero byte.
/// </summary>
/// <param name="parameters">Startup parameters; key and value byte lengths are measured as UTF-8.</param>
/// <exception cref="InvalidOperationException">The message would not fit in the write buffer.</exception>
internal void WriteStartup(Dictionary<string, string> parameters)
{
    const int protocolVersion3 = 3 << 16; // 196608: major version 3 in the high 16 bits, minor 0

    var len = sizeof(int) +  // Length
              sizeof(int) +  // Protocol version
              sizeof(byte);  // Trailing zero byte
    foreach (var kv in parameters)
        len += PGUtil.UTF8Encoding.GetByteCount(kv.Key) + 1 +
               PGUtil.UTF8Encoding.GetByteCount(kv.Value) + 1;

    // The whole message is written below without intermediate flushes, so it must fit in the buffer.
    // Should really never happen, just in case. A specific exception type is used instead of the
    // reserved System.Exception; it still derives from Exception, so existing catch sites keep working.
    if (len > WriteBuffer.Size)
        throw new InvalidOperationException("Startup message bigger than buffer");

    WriteBuffer.WriteInt32(len);
    WriteBuffer.WriteInt32(protocolVersion3);
    foreach (var kv in parameters)
    {
        WriteBuffer.WriteString(kv.Key);
        WriteBuffer.WriteByte(0);
        WriteBuffer.WriteString(kv.Value);
        WriteBuffer.WriteByte(0);
    }
    WriteBuffer.WriteByte(0); // Terminates the parameter list
}
#region Authentication
/// <summary>Writes a Password message whose payload is the whole of <paramref name="payload"/>.</summary>
internal Task WritePassword(byte[] payload, bool async, CancellationToken cancellationToken = default)
{
    return WritePassword(payload, 0, payload.Length, async, cancellationToken);
}

/// <summary>
/// Writes a Password message carrying <paramref name="count"/> bytes of <paramref name="payload"/>
/// starting at <paramref name="offset"/>. This method is also used for SASL responses.
/// </summary>
/// <remarks>
/// A payload that does not fit in the remaining buffer space is sent with <c>DirectWrite</c>
/// after flushing the buffered header, instead of being copied into the buffer.
/// </remarks>
internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
{
    const int headerSize = sizeof(byte) + sizeof(int); // Message code + length
    if (WriteBuffer.WriteSpaceLeft < headerSize)
    {
        await WriteBuffer.Flush(async, cancellationToken);
    }

    WriteBuffer.WriteByte(FrontendMessageCode.Password);
    WriteBuffer.WriteInt32(sizeof(int) + count); // Length field counts itself plus the payload

    if (count > WriteBuffer.WriteSpaceLeft)
    {
        // Too big for the remaining space: flush what is buffered and send the payload directly.
        await WriteBuffer.Flush(async, cancellationToken);
        await WriteBuffer.DirectWrite(new ReadOnlyMemory<byte>(payload, offset, count), async, cancellationToken);
        return;
    }

    // The entire payload fits, so copy it into the write buffer as usual.
    WriteBuffer.WriteBytes(payload, offset, count);
}
/// <summary>
/// Writes a SASLInitialResponse message (sent with the Password message code) that names the chosen
/// SASL <paramref name="mechanism"/> and carries the optional <paramref name="initialResponse"/>.
/// </summary>
/// <param name="mechanism">Mechanism name, written null-terminated; its byte length is measured as UTF-8.</param>
/// <param name="initialResponse">Initial client response, or <c>null</c> to send a length of -1 with no payload.</param>
/// <remarks>
/// If the payload does not fit in the remaining buffer space, the buffered header is flushed and the payload
/// is sent with <c>DirectWrite</c>, mirroring <c>WritePassword</c>. This avoids overflowing the write buffer.
/// </remarks>
internal async Task WriteSASLInitialResponse(string mechanism, byte[]? initialResponse, bool async, CancellationToken cancellationToken = default)
{
    var len = sizeof(byte) + // Message code
              sizeof(int) +  // Length
              PGUtil.UTF8Encoding.GetByteCount(mechanism) + sizeof(byte) + // Mechanism plus null terminator
              sizeof(int) +  // Initial response length
              (initialResponse?.Length ?? 0); // Initial response payload
    if (WriteBuffer.WriteSpaceLeft < len)
        await WriteBuffer.Flush(async, cancellationToken);

    WriteBuffer.WriteByte(FrontendMessageCode.Password);
    WriteBuffer.WriteInt32(len - 1); // Length field excludes the message code byte
    WriteBuffer.WriteString(mechanism);
    WriteBuffer.WriteByte(0); // null terminator

    if (initialResponse is null)
    {
        WriteBuffer.WriteInt32(-1);
        return;
    }

    WriteBuffer.WriteInt32(initialResponse.Length);
    if (initialResponse.Length <= WriteBuffer.WriteSpaceLeft)
    {
        WriteBuffer.WriteBytes(initialResponse);
        return;
    }

    // Payload is larger than what remains even after the flush above: send the buffered header,
    // then write the payload directly instead of overflowing the buffer.
    await WriteBuffer.Flush(async, cancellationToken);
    await WriteBuffer.DirectWrite(new ReadOnlyMemory<byte>(initialResponse), async, cancellationToken);
}
/// <summary>
/// Writes a SASLResponse message. It is sent as a Password message carrying <paramref name="payload"/>.
/// </summary>
internal Task WriteSASLResponse(byte[] payload, bool async, CancellationToken cancellationToken = default)
{
    return WritePassword(payload, async, cancellationToken);
}
#endregion Authentication
/// <summary>
/// Copies an already-encoded message (<paramref name="data"/>, written verbatim) into the write buffer,
/// flushing first if the buffer lacks room for it.
/// </summary>
/// <returns>A completed task when no flush was needed; otherwise the flush-then-copy task.</returns>
internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
{
    if (data.Length <= WriteBuffer.WriteSpaceLeft)
    {
        WriteBuffer.WriteBytes(data, 0, data.Length);
        return Task.CompletedTask;
    }

    return FlushThenCopy(data, async, cancellationToken);

    async Task FlushThenCopy(byte[] bytes, bool async, CancellationToken cancellationToken)
    {
        await Flush(async, cancellationToken);
        Debug.Assert(bytes.Length <= WriteBuffer.WriteSpaceLeft, $"Pregenerated message has length {bytes.Length} which is bigger than the buffer ({WriteBuffer.WriteSpaceLeft})");
        WriteBuffer.WriteBytes(bytes, 0, bytes.Length);
    }
}
/// <summary>Synchronously flushes the write buffer by blocking on <c>WriteBuffer.Flush(false)</c>.</summary>
internal void Flush()
{
    WriteBuffer.Flush(false).GetAwaiter().GetResult();
}

/// <summary>Flushes the write buffer, passing <paramref name="async"/> and the token through to <c>WriteBuffer.Flush</c>.</summary>
internal Task Flush(bool async, CancellationToken cancellationToken = default)
{
    return WriteBuffer.Flush(async, cancellationToken);
}
}