File: System\Net\NetworkInformation\Ping.RawSocket.cs
Web Access
Project: src\src\libraries\System.Net.Ping\src\System.Net.Ping.csproj (System.Net.Ping)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.Net.NetworkInformation
{
    public partial class Ping
    {
        private const int IcmpHeaderLengthInBytes = 8;
        private const int MinIpHeaderLengthInBytes = 20;
        private const int MaxIpHeaderLengthInBytes = 60;
        private const int IpV6HeaderLengthInBytes = 40;
        private static readonly ushort DontFragment = OperatingSystem.IsFreeBSD() ? (ushort)IPAddress.HostToNetworkOrder((short)0x4000) : (ushort)0x4000;
 
        private static unsafe SocketConfig GetSocketConfig(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
        {
            // Use a random value as the identifier. This doesn't need to be perfectly random
            // or very unpredictable, rather just good enough to avoid unexpected conflicts.
            ushort id = (ushort)Random.Shared.Next(ushort.MaxValue + 1);
            IpHeader iph = default;
 
            bool ipv4 = address.AddressFamily == AddressFamily.InterNetwork;
            bool sendIpHeader = ipv4 && options != null && SendIpHeader;
            int totalLength = 0;
 
            if (sendIpHeader)
            {
                iph.VersionAndLength = 0x45;
                totalLength = sizeof(IpHeader) + checked(sizeof(IcmpHeader) + buffer.Length);
                // On OSX this strangely must be host byte order.
                iph.TotalLength = OperatingSystem.IsFreeBSD() ? (ushort)IPAddress.HostToNetworkOrder((short)totalLength) : (ushort)totalLength;
                iph.Protocol = 1; // ICMP
                iph.Ttl = (byte)options!.Ttl;
                iph.Flags = (ushort)(options.DontFragment ? DontFragment : 0);
#pragma warning disable 618
                iph.DestinationAddress = (uint)address.Address;
#pragma warning restore 618
                // No need to fill in SourceAddress or checksum.
                // If left blank, kernel will fill it in - at least on OSX.
            }
 
            return new SocketConfig(
                new IPEndPoint(address, 0), timeout, options,
                ipv4, ipv4 ? ProtocolType.Icmp : ProtocolType.IcmpV6, id,
                CreateSendMessageBuffer(iph, new IcmpHeader()
                {
                    Type = ipv4 ? (byte)IcmpV4MessageType.EchoRequest : (byte)IcmpV6MessageType.EchoRequest,
                    Identifier = id,
                }, buffer, totalLength));
        }
 
        private static Socket GetRawSocket(SocketConfig socketConfig)
        {
            IPEndPoint ep = (IPEndPoint)socketConfig.EndPoint;
            AddressFamily addrFamily = ep.Address.AddressFamily;
            SocketType socketType = RawSocketPermissions.CanUseRawSockets(addrFamily) ?
                SocketType.Raw :
                SocketType.Dgram; // macOS/iOS has ability to send ICMP echo without RAW
 
            Socket socket = new Socket(addrFamily, socketType, socketConfig.ProtocolType);
            socket.ReceiveTimeout = socketConfig.Timeout;
            socket.SendTimeout = socketConfig.Timeout;
            if (addrFamily == AddressFamily.InterNetworkV6 && !SupportsDualMode)
            {
                socket.DualMode = false;
            }
 
            if (socketConfig.Options != null)
            {
                if (socketConfig.Options.Ttl > 0)
                {
                    socket.Ttl = (short)socketConfig.Options.Ttl;
                }
 
                if (addrFamily == AddressFamily.InterNetwork)
                {
                    if (SendIpHeader)
                    {
                        // some platforms like OSX don't support DontFragment so we construct IP header instead.
                        socket.SetSocketOption(SocketOptionLevel.IP, SocketOptionName.HeaderIncluded, 1);
                    }
                    else
                    {
                        socket.DontFragment = socketConfig.Options.DontFragment;
                    }
                }
            }
 
#pragma warning disable 618
            // Disable warning about obsolete property. We could use GetAddressBytes but that allocates.
            // IPv4 multicast address starts with 1110 bits so mask rest and test if we get correct value e.g. 0xe0.
            if (NeedsConnect && !ep.Address.IsIPv6Multicast && !(addrFamily == AddressFamily.InterNetwork && (ep.Address.Address & 0xf0) == 0xe0))
            {
                // If it is not multicast, use Connect to scope responses only to the target address.
                socket.Connect(socketConfig.EndPoint);
                unsafe
                {
                    int opt = 1;
                    // setsockopt(fd, IPPROTO_IP, IP_RECVERR, &value, sizeof(int))
                    socket.SetRawSocketOption(0, 11, new ReadOnlySpan<byte>(&opt, sizeof(int)));
                }
            }
#pragma warning restore 618
 
            return socket;
        }
 
        private static bool TryGetPingReply(
            SocketConfig socketConfig, byte[] receiveBuffer, int bytesReceived, long startingTimestamp, ref int ipHeaderLength,
            [NotNullWhen(true)] out PingReply? reply)
        {
            byte type, code;
            reply = null;
 
            if (socketConfig.IsIpv4)
            {
                // Determine actual size of IP header
                byte ihl = (byte)(receiveBuffer[0] & 0x0f); // Internet Header Length
                ipHeaderLength = 4 * ihl;
                if (bytesReceived - ipHeaderLength < IcmpHeaderLengthInBytes)
                {
                    return false; // Not enough bytes to reconstruct actual IP header + ICMP header.
                }
            }
 
            int icmpHeaderOffset = ipHeaderLength;
            int dataOffset = ipHeaderLength + IcmpHeaderLengthInBytes;
 
            // Skip IP header.
            IcmpHeader receivedHeader = MemoryMarshal.Read<IcmpHeader>(receiveBuffer.AsSpan(icmpHeaderOffset));
            ushort identifier;
            type = receivedHeader.Type;
            code = receivedHeader.Code;
 
            // Validate the ICMP header and get the identifier
            if (socketConfig.IsIpv4)
            {
                if (type == (byte)IcmpV4MessageType.EchoReply)
                {
                    // Reply packet has the identifier in the ICMP header.
                    identifier = receivedHeader.Identifier;
                }
                else if (type == (byte)IcmpV4MessageType.DestinationUnreachable ||
                         type == (byte)IcmpV4MessageType.TimeExceeded ||
                         type == (byte)IcmpV4MessageType.ParameterProblemBadIPHeader ||
                         type == (byte)IcmpV4MessageType.SourceQuench ||
                         type == (byte)IcmpV4MessageType.RedirectMessage)
                {
                    // Original IP+ICMP request is in the payload. Read the ICMP header from
                    // the payload to get identifier.
 
                    if (dataOffset + MinIpHeaderLengthInBytes + IcmpHeaderLengthInBytes > bytesReceived)
                    {
                        return false;
                    }
 
                    byte ihl = (byte)(receiveBuffer[dataOffset] & 0x0f); // Internet Header Length
                    int payloadIpHeaderLength = 4 * ihl;
 
                    if (bytesReceived - dataOffset - payloadIpHeaderLength < IcmpHeaderLengthInBytes)
                    {
                        return false; // Not enough bytes to reconstruct actual IP header + ICMP header.
                    }
 
                    IcmpHeader originalRequestHeader = MemoryMarshal.Read<IcmpHeader>(receiveBuffer.AsSpan(dataOffset + payloadIpHeaderLength));
                    identifier = originalRequestHeader.Identifier;
 
                    // Update the date offset to point past the payload IP+ICMP headers. While the specification
                    // doesn't indicate there should be any additional data the reality is that we often get the
                    // original packet data back.
                    dataOffset += payloadIpHeaderLength + IcmpHeaderLengthInBytes;
                }
                else
                {
                    return false;
                }
            }
            else
            {
                if (type == (byte)IcmpV6MessageType.EchoReply)
                {
                    // Reply packet has the identifier in the ICMP header.
                    identifier = receivedHeader.Identifier;
                }
                else if (type == (byte)IcmpV6MessageType.DestinationUnreachable ||
                         type == (byte)IcmpV6MessageType.TimeExceeded ||
                         type == (byte)IcmpV6MessageType.ParameterProblem ||
                         type == (byte)IcmpV6MessageType.PacketTooBig)
                {
                    // Original IP+ICMP request is in the payload. Read the ICMP header from
                    // the payload to get identifier.
 
                    if (bytesReceived - dataOffset < IpV6HeaderLengthInBytes + IcmpHeaderLengthInBytes)
                    {
                        return false; // Not enough bytes to reconstruct actual IP header + ICMP header.
                    }
 
                    IcmpHeader originalRequestHeader = MemoryMarshal.Read<IcmpHeader>(receiveBuffer.AsSpan(dataOffset + IpV6HeaderLengthInBytes));
                    identifier = originalRequestHeader.Identifier;
 
                    // Update the date offset to point past the payload IP+ICMP headers. While the specification
                    // doesn't indicate there should be any additional data the reality is that we often get the
                    // original packet data back.
                    dataOffset += IpV6HeaderLengthInBytes + IcmpHeaderLengthInBytes;
                }
                else
                {
                    return false;
                }
            }
 
            if (socketConfig.Identifier != identifier)
            {
                return false;
            }
 
            long roundTripTime = (long)Stopwatch.GetElapsedTime(startingTimestamp).TotalMilliseconds;
 
            // We want to return a buffer with the actual data we sent out, not including the header data.
            byte[] dataBuffer = new byte[bytesReceived - dataOffset];
            Buffer.BlockCopy(receiveBuffer, dataOffset, dataBuffer, 0, dataBuffer.Length);
 
            IPStatus status = socketConfig.IsIpv4
                ? IcmpV4MessageConstants.MapV4TypeToIPStatus(type, code)
                : IcmpV6MessageConstants.MapV6TypeToIPStatus(type, code);
 
            IPAddress address = ((IPEndPoint)socketConfig.EndPoint).Address;
            reply = new PingReply(address, socketConfig.Options, status, roundTripTime, dataBuffer);
            return true;
        }
 
        private static unsafe PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
        {
            SocketConfig socketConfig = GetSocketConfig(address, buffer, timeout, options);
            using (Socket socket = GetRawSocket(socketConfig))
            {
                Span<byte> socketAddress = stackalloc byte[SocketAddress.GetMaximumAddressSize(address.AddressFamily)];
                int ipHeaderLength = socketConfig.IsIpv4 ? MinIpHeaderLengthInBytes : 0;
                try
                {
                    socket.SendTo(socketConfig.SendBuffer, SocketFlags.None, socketConfig.EndPoint);
 
                    byte[] receiveBuffer = new byte[2 * (MaxIpHeaderLengthInBytes + IcmpHeaderLengthInBytes) + buffer.Length];
 
                    // Read from the socket in a loop. We may receive messages that are not echo replies, or that are not in response
                    // to the echo request we just sent. We need to filter such messages out, and continue reading until our timeout.
                    // For example, when pinging the local host, we need to filter out our own echo requests that the socket reads.
                    long startingTimestamp = Stopwatch.GetTimestamp();
                    while (Stopwatch.GetElapsedTime(startingTimestamp).TotalMilliseconds < timeout)
                    {
                        int bytesReceived = socket.ReceiveFrom(receiveBuffer, SocketFlags.None, ref socketConfig.EndPoint);
 
                        if (bytesReceived - ipHeaderLength < IcmpHeaderLengthInBytes)
                        {
                            continue; // Not enough bytes to reconstruct IP header + ICMP header.
                        }
 
                        if (TryGetPingReply(socketConfig, receiveBuffer, bytesReceived, startingTimestamp, ref ipHeaderLength, out PingReply? reply))
                        {
                            return reply;
                        }
                    }
                }
                catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
                {
                }
                catch (SocketException ex) when (ex.SocketErrorCode == SocketError.MessageSize)
                {
                    return CreatePingReply(IPStatus.PacketTooBig);
                }
                catch (SocketException ex) when (ex.SocketErrorCode == SocketError.HostUnreachable)
                {
                    // This happens on Linux where we explicitly subscribed to error messages
                    // We should be able to get more info by getting extended socket error from error queue.
 
                    Interop.Sys.MessageHeader header = default;
 
                    SocketError result;
                    fixed (byte* sockAddr = &MemoryMarshal.GetReference(socketAddress))
                    {
                        header.SocketAddress = sockAddr;
                        header.SocketAddressLen = socketAddress.Length;
                        header.IOVectors = null;
                        header.IOVectorCount = 0;
 
                        result = Interop.Sys.ReceiveSocketError(socket.SafeHandle, &header);
                    }
 
                    if (result == SocketError.Success && header.SocketAddressLen > 0)
                    {
                        return CreatePingReply(IPStatus.TtlExpired, IPEndPointExtensions.GetIPAddress(socketAddress.Slice(0, header.SocketAddressLen)));
                    }
                }
 
                // We have exceeded our timeout duration, and no reply has been received.
                return CreatePingReply(IPStatus.TimedOut);
            }
        }
 
        private async Task<PingReply> SendIcmpEchoRequestOverRawSocketAsync(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
        {
            SocketConfig socketConfig = GetSocketConfig(address, buffer, timeout, options);
            using Socket socket = GetRawSocket(socketConfig);
            int ipHeaderLength = socketConfig.IsIpv4 ? MinIpHeaderLengthInBytes : 0;
 
            try
            {
                CancellationToken timeoutOrCancellationToken = _timeoutOrCancellationSource!.Token;
 
                await socket.SendToAsync(
                    socketConfig.SendBuffer.AsMemory(),
                    SocketFlags.None,
                    socketConfig.EndPoint,
                    timeoutOrCancellationToken)
                    .ConfigureAwait(false);
 
                byte[] receiveBuffer = new byte[2 * (MaxIpHeaderLengthInBytes + IcmpHeaderLengthInBytes) + buffer.Length];
 
                // Read from the socket in a loop. We may receive messages that are not echo replies, or that are not in response
                // to the echo request we just sent. We need to filter such messages out, and continue reading until our timeout.
                // For example, when pinging the local host, we need to filter out our own echo requests that the socket reads.
                long startingTimestamp = Stopwatch.GetTimestamp();
                while (!timeoutOrCancellationToken.IsCancellationRequested)
                {
                    SocketReceiveFromResult receiveResult = await socket.ReceiveFromAsync(
                        receiveBuffer.AsMemory(),
                        SocketFlags.None,
                        socketConfig.EndPoint,
                        timeoutOrCancellationToken)
                        .ConfigureAwait(false);
 
                    int bytesReceived = receiveResult.ReceivedBytes;
                    if (bytesReceived - ipHeaderLength < IcmpHeaderLengthInBytes)
                    {
                        continue; // Not enough bytes to reconstruct IP header + ICMP header.
                    }
 
                    if (TryGetPingReply(socketConfig, receiveBuffer, bytesReceived, startingTimestamp, ref ipHeaderLength, out PingReply? reply))
                    {
                        return reply;
                    }
                }
            }
            catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
            {
            }
            catch (SocketException ex) when (ex.SocketErrorCode == SocketError.MessageSize)
            {
                return CreatePingReply(IPStatus.PacketTooBig);
            }
            catch (OperationCanceledException) when (!_canceled)
            {
            }
 
            // We have exceeded our timeout duration, and no reply has been received.
            return CreatePingReply(IPStatus.TimedOut);
        }
 
        private static PingReply CreatePingReply(IPStatus status, IPAddress? address = null, long rtt = 0)
        {
            // Documentation indicates that you should only pay attention to the IPStatus value when
            // its value is not "Success", but the rest of these values match that of the Windows implementation.
            return new PingReply(address ?? new IPAddress(0), null, status, rtt, Array.Empty<byte>());
        }
 
#if DEBUG
        static Ping()
        {
            Debug.Assert(Marshal.SizeOf<IcmpHeader>() == 8, "The size of an ICMP Header must be 8 bytes.");
        }
#endif
 
        [StructLayout(LayoutKind.Sequential)]
        internal struct IpHeader
        {
            internal byte VersionAndLength;
            internal byte Tos;
            internal ushort TotalLength;
 
            internal ushort Identifier;
            internal ushort Flags;
 
            internal byte Ttl;
            internal byte Protocol;
            internal ushort HeaderChecksum;
 
            internal uint SourceAddress;
            internal uint DestinationAddress;
        };
 
        // Must be 8 bytes total.
        [StructLayout(LayoutKind.Sequential)]
        internal struct IcmpHeader
        {
            public byte Type;
            public byte Code;
            public ushort HeaderChecksum;
            public ushort Identifier;
            public ushort SequenceNumber;
        }
 
        // Since this is private should be safe to trust that the calling code
        // will behave. To get a little performance boost raw fields are exposed
        // and no validation is performed.
        private sealed class SocketConfig
        {
            public SocketConfig(EndPoint endPoint, int timeout, PingOptions? options, bool isIPv4, ProtocolType protocolType, ushort id, byte[] sendBuffer)
            {
                EndPoint = endPoint;
                Timeout = timeout;
                Options = options;
                IsIpv4 = isIPv4;
                ProtocolType = protocolType;
                Identifier = id;
                SendBuffer = sendBuffer;
            }
 
            public EndPoint EndPoint;
            public readonly int Timeout;
            public readonly PingOptions? Options;
            public readonly ushort Identifier;
            public readonly bool IsIpv4;
            public readonly ProtocolType ProtocolType;
            public readonly byte[] SendBuffer;
        }
 
        private static unsafe byte[] CreateSendMessageBuffer(IpHeader ipHeader, IcmpHeader icmpHeader, byte[] payload, int totalLength = 0)
        {
            int icmpHeaderSize = sizeof(IcmpHeader);
            int offset = 0;
            int packetSize = totalLength != 0 ? totalLength : checked(icmpHeaderSize + payload.Length);
            byte[] result = new byte[packetSize];
 
            if (totalLength != 0)
            {
                int ipHeaderSize = sizeof(IpHeader);
                new Span<byte>(&ipHeader, sizeof(IpHeader)).CopyTo(result);
                offset = ipHeaderSize;
            }
 
            //byte[] result = new byte[headerSize + payload.Length];
            Marshal.Copy(new IntPtr(&icmpHeader), result, offset, icmpHeaderSize);
            payload.CopyTo(result, offset + icmpHeaderSize);
 
            // offset now still points to beginning of ICMP header.
            ushort checksum = ComputeBufferChecksum(result.AsSpan(offset));
            // Jam the checksum into the buffer.
            result[offset + 2] = (byte)(checksum >> 8);
            result[offset + 3] = (byte)(checksum & (0xFF));
 
            return result;
        }
 
        private static ushort ComputeBufferChecksum(ReadOnlySpan<byte> buffer)
        {
            // This is using the "deferred carries" approach outlined in RFC 1071.
            uint sum = 0;
            for (int i = 0; i < buffer.Length; i += 2)
            {
                // Combine each pair of bytes into a 16-bit number and add it to the sum
                ushort element0 = (ushort)((buffer[i] << 8) & 0xFF00);
                ushort element1 = (i + 1 < buffer.Length)
                    ? (ushort)(buffer[i + 1] & 0x00FF)
                    : (ushort)0; // If there's an odd number of bytes, pad by one octet of zeros.
                ushort combined = (ushort)(element0 | element1);
                sum += (uint)combined;
            }
 
            // Add back the "carry bits" which have risen to the upper 16 bits of the sum.
            while ((sum >> 16) != 0)
            {
                var partialSum = sum & 0xFFFF;
                var carries = sum >> 16;
                sum = partialSum + carries;
            }
 
            return unchecked((ushort)~sum);
        }
    }
}