|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.IO;
using System.Threading;
using Microsoft.Win32.SafeHandles;
namespace System.Net.Sockets
{
public partial class SocketAsyncEventArgs : EventArgs, IDisposable
{
private IntPtr _acceptedFileDescriptor;
private int _socketAddressSize;
private SocketFlags _receivedFlags;
private Action<int, Memory<byte>, SocketFlags, SocketError>? _transferCompletionCallback;
partial void InitializeInternals();
partial void FreeInternals();
partial void SetupMultipleBuffers();
partial void CompleteCore();
private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, Memory<byte> socketAddress, SocketError socketError)
{
CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketError);
CompletionCallback(0, SocketFlags.None, socketError);
}
private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, Memory<byte> socketAddress, SocketError socketError)
{
_acceptedFileDescriptor = acceptedFileDescriptor;
if (socketError == SocketError.Success)
{
_acceptAddressBufferCount = socketAddress.Length;
}
else
{
_acceptAddressBufferCount = 0;
}
}
internal unsafe SocketError DoOperationAccept(Socket _ /*socket*/, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken)
{
if (!_buffer.Equals(default))
{
throw new PlatformNotSupportedException(SR.net_sockets_accept_receive_notsupported);
}
_acceptedFileDescriptor = (IntPtr)(-1);
Debug.Assert(acceptHandle == null, $"Unexpected acceptHandle: {acceptHandle}");
IntPtr acceptedFd;
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, out int socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken);
if (socketError != SocketError.IOPending)
{
CompleteAcceptOperation(acceptedFd, new Memory<byte>(_acceptBuffer, 0, socketAddressLen), socketError);
FinishOperationSync(socketError, 0, SocketFlags.None);
}
return socketError;
}
private void ConnectCompletionCallback(int bytesTransferred, Memory<byte> socketAddress, SocketFlags receivedFlags, SocketError socketError)
{
CompletionCallback(bytesTransferred, SocketFlags.None, socketError);
}
internal unsafe SocketError DoOperationConnectEx(Socket _ /*socket*/, SafeSocketHandle handle)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, _buffer.Slice(_offset, _count), out int sentBytes);
if (socketError != SocketError.IOPending)
{
FinishOperationSync(socketError, sentBytes, SocketFlags.None);
}
return socketError;
}
internal unsafe SocketError DoOperationConnect(SafeSocketHandle handle)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, ConnectCompletionCallback, Memory<byte>.Empty, out int _);
if (socketError != SocketError.IOPending)
{
FinishOperationSync(socketError, 0, SocketFlags.None);
}
return socketError;
}
internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken _ /*cancellationToken*/)
{
SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket);
FinishOperationSync(socketError, 0, SocketFlags.None);
return socketError;
}
private Action<int, Memory<byte>, SocketFlags, SocketError> TransferCompletionCallback =>
_transferCompletionCallback ??= TransferCompletionCallbackCore;
private void TransferCompletionCallbackCore(int bytesTransferred, Memory<byte> socketAddress, SocketFlags receivedFlags, SocketError socketError)
{
CompleteTransferOperation(socketAddress, socketAddress.Length, receivedFlags);
CompletionCallback(bytesTransferred, receivedFlags, socketError);
}
private void CompleteTransferOperation(Memory<byte> _, int socketAddressSize, SocketFlags receivedFlags)
{
_socketAddressSize = socketAddressSize;
_receivedFlags = receivedFlags;
}
internal unsafe SocketError DoOperationReceive(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
SocketFlags flags;
int bytesReceived;
SocketError errorCode;
if (_bufferList == null)
{
// TCP has no out-going receive flags. We can use different syscalls which give better performance.
bool noReceivedFlags = _currentSocket!.ProtocolType == ProtocolType.Tcp;
if (noReceivedFlags)
{
errorCode = handle.AsyncContext.ReceiveAsync(_buffer.Slice(_offset, _count), _socketFlags, out bytesReceived, TransferCompletionCallback, cancellationToken);
flags = SocketFlags.None;
}
else
{
errorCode = handle.AsyncContext.ReceiveAsync(_buffer.Slice(_offset, _count), _socketFlags, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken);
}
}
else
{
errorCode = handle.AsyncContext.ReceiveAsync(_bufferListInternal!, _socketFlags, out bytesReceived, out flags, TransferCompletionCallback);
}
if (errorCode != SocketError.IOPending)
{
CompleteTransferOperation(null, 0, flags);
FinishOperationSync(errorCode, bytesReceived, flags);
}
return errorCode;
}
internal unsafe SocketError DoOperationReceiveFrom(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
SocketFlags flags;
SocketError errorCode;
int bytesReceived;
int socketAddressLen;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.ReceiveFromAsync(_buffer.Slice(_offset, _count), _socketFlags, _socketAddress!.Buffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback, cancellationToken);
}
else
{
errorCode = handle.AsyncContext.ReceiveFromAsync(_bufferListInternal!, _socketFlags, _socketAddress!.Buffer, out socketAddressLen, out bytesReceived, out flags, TransferCompletionCallback);
}
if (errorCode != SocketError.IOPending)
{
CompleteTransferOperation(_socketAddress.Buffer, socketAddressLen, flags);
FinishOperationSync(errorCode, bytesReceived, flags);
}
return errorCode;
}
private void ReceiveMessageFromCompletionCallback(int bytesTransferred, Memory<byte> socketAddress, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation, SocketError errorCode)
{
CompleteReceiveMessageFromOperation(socketAddress, socketAddress.Length, receivedFlags, ipPacketInformation);
CompletionCallback(bytesTransferred, receivedFlags, errorCode);
}
private void CompleteReceiveMessageFromOperation(Memory<byte> socketAddress, int socketAddressSize, SocketFlags receivedFlags, IPPacketInformation ipPacketInformation)
{
Debug.Assert(socketAddress.Length == socketAddressSize);
_socketAddressSize = socketAddress.Length;
_receivedFlags = receivedFlags;
_receiveMessageFromPacketInfo = ipPacketInformation;
}
internal unsafe SocketError DoOperationReceiveMessageFrom(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receiveMessageFromPacketInfo = default(IPPacketInformation);
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
bool isIPv4, isIPv6;
Socket.GetIPProtocolInformation(socket.AddressFamily, _socketAddress!, out isIPv4, out isIPv6);
int socketAddressSize = _socketAddress!.Buffer.Length;
int bytesReceived;
SocketFlags receivedFlags;
IPPacketInformation ipPacketInformation;
SocketError socketError = handle.AsyncContext.ReceiveMessageFromAsync(_buffer.Slice(_offset, _count), _bufferListInternal, _socketFlags, _socketAddress.Buffer, out socketAddressSize, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, ReceiveMessageFromCompletionCallback, cancellationToken);
if (socketError != SocketError.IOPending)
{
_socketAddress.Size = socketAddressSize;
CompleteReceiveMessageFromOperation(_socketAddress.Buffer, socketAddressSize, receivedFlags, ipPacketInformation);
FinishOperationSync(socketError, bytesReceived, receivedFlags);
}
return socketError;
}
internal unsafe SocketError DoOperationSend(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
int bytesSent;
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.SendAsync(_buffer, _offset, _count, _socketFlags, out bytesSent, TransferCompletionCallback, cancellationToken);
}
else
{
errorCode = handle.AsyncContext.SendAsync(_bufferListInternal!, _socketFlags, out bytesSent, TransferCompletionCallback);
}
if (errorCode != SocketError.IOPending)
{
CompleteTransferOperation(null, 0, SocketFlags.None);
FinishOperationSync(errorCode, bytesSent, SocketFlags.None);
}
return errorCode;
}
internal SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle _1 /*handle*/, CancellationToken cancellationToken)
{
Debug.Assert(_sendPacketsElements != null);
SendPacketsElement[] elements = (SendPacketsElement[])_sendPacketsElements.Clone();
SafeFileHandle[] fileHandles = new SafeFileHandle[elements.Length];
// Open all files synchronously ahead of time so that any exceptions are propagated
// to the caller, to match Windows behavior.
try
{
for (int i = 0; i < elements.Length; i++)
{
string? path = elements[i]?.FilePath;
if (path != null)
{
fileHandles[i] = File.OpenHandle(path, FileMode.Open, FileAccess.Read, FileShare.Read, FileOptions.Asynchronous);
}
}
}
catch (Exception exc)
{
// Clean up any files that were already opened.
foreach (SafeFileHandle s in fileHandles)
{
s?.Dispose();
}
// Windows differentiates the directory not being found from the file not being found.
// Approximate this by checking to see if the directory exists; this is only best-effort,
// as there are various things that could affect this, e.g. directory creation racing with
// this check, but it's good enough for most situations.
if (exc is FileNotFoundException fnfe)
{
string? dirname = Path.GetDirectoryName(fnfe.FileName);
if (!string.IsNullOrEmpty(dirname) && !Directory.Exists(dirname))
{
throw new DirectoryNotFoundException(fnfe.Message);
}
}
// Otherwise propagate the original error.
throw;
}
_ = SocketPal.SendPacketsAsync(socket, SendPacketsFlags, elements, fileHandles, cancellationToken, (bytesTransferred, error) =>
{
if (error == SocketError.Success)
{
FinishOperationAsyncSuccess((int)bytesTransferred, SocketFlags.None);
}
else
{
FinishOperationAsyncFailure(error, (int)bytesTransferred, SocketFlags.None);
}
});
return SocketError.IOPending;
}
internal SocketError DoOperationSendTo(SafeSocketHandle handle, CancellationToken cancellationToken)
{
_receivedFlags = System.Net.Sockets.SocketFlags.None;
_socketAddressSize = 0;
int bytesSent = 0;
SocketError errorCode;
if (_bufferList == null)
{
errorCode = handle.AsyncContext.SendToAsync(_buffer, _offset, _count, _socketFlags, _socketAddress!.Buffer, ref bytesSent, TransferCompletionCallback, cancellationToken);
}
else
{
errorCode = handle.AsyncContext.SendToAsync(_bufferListInternal!, _socketFlags, _socketAddress!.Buffer, out bytesSent, TransferCompletionCallback);
}
if (errorCode != SocketError.IOPending)
{
CompleteTransferOperation(_socketAddress.Buffer, _socketAddress.Size, SocketFlags.None);
FinishOperationSync(errorCode, bytesSent, SocketFlags.None);
}
return errorCode;
}
internal void LogBuffer(int size)
{
// This should only be called if tracing is enabled. However, there is the potential for a race
// condition where tracing is disabled between a calling check and here, in which case the assert
// may fire erroneously.
Debug.Assert(NetEventSource.Log.IsEnabled());
if (_bufferList == null)
{
NetEventSource.DumpBuffer(this, _buffer, _offset, size);
}
else if (_acceptBuffer != null)
{
NetEventSource.DumpBuffer(this, _acceptBuffer, 0, size);
}
}
private SocketError FinishOperationAccept(SocketAddress remoteSocketAddress)
{
new ReadOnlySpan<byte>(_acceptBuffer, 0, _acceptAddressBufferCount).CopyTo(remoteSocketAddress.Buffer.Span);
remoteSocketAddress.Size = _acceptAddressBufferCount;
// on macOS accept can sometimes return empty remote address even when it returns successfully.
Socket acceptedSocket = _currentSocket!.CreateAcceptSocket(
SocketPal.CreateSocket(_acceptedFileDescriptor),
remoteSocketAddress.Size > 0 ? _currentSocket._rightEndPoint!.Create(remoteSocketAddress) : null);
if (_acceptSocket is null)
{
// Store the accepted socket
_acceptSocket = acceptedSocket;
}
else
{
// Copy state from the accepted socket into the caller-supplied socket and then dispose of the original.
_acceptSocket.DisposeHandle();
_acceptSocket.CopyStateFromSource(acceptedSocket);
acceptedSocket.ClearHandle();
acceptedSocket.Dispose();
}
return SocketError.Success;
}
private static SocketError FinishOperationConnect()
{
// No-op for *nix.
return SocketError.Success;
}
private void UpdateReceivedSocketAddress(SocketAddress socketAddress)
{
socketAddress.Size = _socketAddressSize;
}
partial void FinishOperationReceiveMessageFrom();
partial void FinishOperationSendPackets();
private void CompletionCallback(int bytesTransferred, SocketFlags flags, SocketError socketError)
{
if (socketError == SocketError.Success)
{
FinishOperationAsyncSuccess(bytesTransferred, flags);
}
else
{
if (_currentSocket!.Disposed)
{
socketError = SocketError.OperationAborted;
}
FinishOperationAsyncFailure(socketError, bytesTransferred, flags);
}
}
}
}
|