namespace Swan.Net.Dns
{
using Formatters;
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Text;
/// <summary>
/// DnsClient request inner classes.
/// </summary>
internal partial class DnsClient
{
/// <summary>
/// A DNS request bound to a specific server endpoint. Message state (id, flags, questions)
/// is delegated to an inner <see cref="DnsRequest"/>, and the configured
/// <see cref="IDnsRequestResolver"/> performs the actual exchange with the server.
/// </summary>
public class DnsClientRequest : IDnsRequest
{
    private readonly IDnsRequestResolver _resolver;
    private readonly IDnsRequest _request;

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsClientRequest"/> class.
    /// </summary>
    /// <param name="dns">The endpoint of the DNS server to query.</param>
    /// <param name="request">
    /// Optional request whose id, operation code, recursion flag and questions are copied.
    /// When <c>null</c>, a fresh <see cref="DnsRequest"/> is created.
    /// </param>
    /// <param name="resolver">
    /// Optional resolver used to send the request. When <c>null</c>, a
    /// <see cref="DnsUdpRequestResolver"/> is used.
    /// </param>
    public DnsClientRequest(IPEndPoint dns, IDnsRequest? request = null, IDnsRequestResolver? resolver = null)
    {
        Dns = dns;
        _request = request is null ? new DnsRequest() : new DnsRequest(request);
        _resolver = resolver ?? new DnsUdpRequestResolver();
    }

    /// <summary>Gets or sets the message identifier of the wrapped request.</summary>
    public int Id
    {
        get => _request.Id;
        set => _request.Id = value;
    }

    /// <summary>Gets or sets the operation code of the wrapped request.</summary>
    public DnsOperationCode OperationCode
    {
        get => _request.OperationCode;
        set => _request.OperationCode = value;
    }

    /// <summary>Gets or sets a value indicating whether recursion is requested.</summary>
    public bool RecursionDesired
    {
        get => _request.RecursionDesired;
        set => _request.RecursionDesired = value;
    }

    /// <summary>Gets the questions of the wrapped request.</summary>
    public IList<DnsQuestion> Questions => _request.Questions;

    /// <summary>Gets the size of the wrapped request as reported by the inner request.</summary>
    public int Size => _request.Size;

    /// <summary>Gets or sets the endpoint of the DNS server this request is sent to.</summary>
    public IPEndPoint Dns { get; set; }

    /// <summary>Serializes the wrapped request.</summary>
    /// <returns>The bytes produced by the inner request.</returns>
    public byte[] ToArray() => _request.ToArray();

    /// <inheritdoc />
    public override string ToString() => _request.ToString();

    /// <summary>
    /// Resolves this request into a response using the provided DNS information. The given
    /// request strategy is used to retrieve the response.
    /// </summary>
    /// <exception cref="DnsQueryException">
    /// Thrown if the response id does not match the request id, if the response code is not
    /// <see cref="DnsResponseCode.NoError"/>, or if the resolver raised an
    /// <see cref="ArgumentException"/> or <see cref="SocketException"/> (both are wrapped).
    /// </exception>
    /// <remarks>
    /// Any other exception raised by the resolver (for example an <see cref="IOException"/>)
    /// is not wrapped and propagates unchanged.
    /// </remarks>
    /// <returns>The response received from server.</returns>
    public async Task<DnsClientResponse> Resolve()
    {
        try
        {
            var response = await _resolver.Request(this).ConfigureAwait(false);

            if (response.Id != Id)
            {
                throw new DnsQueryException(response, "Mismatching request/response IDs");
            }

            if (response.ResponseCode != DnsResponseCode.NoError)
            {
                throw new DnsQueryException(response);
            }

            return response;
        }
        catch (Exception e) when (e is ArgumentException || e is SocketException)
        {
            // Parsing failures and socket failures are reported uniformly as a failed query.
            throw new DnsQueryException("Invalid response", e);
        }
    }
}
/// <summary>
/// A DNS request message: a <see cref="DnsHeader"/> followed by a list of questions.
/// </summary>
public class DnsRequest : IDnsRequest
{
    // Not readonly: DnsHeader is a struct that is mutated through its property setters.
    private DnsHeader header;

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsRequest"/> class as a standard query
    /// with no questions and a randomly chosen message identifier.
    /// </summary>
    public DnsRequest()
    {
        Questions = new List<DnsQuestion>();
        header = new DnsHeader
        {
            OperationCode = DnsOperationCode.Query,
            Response = false,
            Id = NextId(),
        };
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsRequest"/> class by copying the id,
    /// operation code, recursion flag and questions of another request.
    /// </summary>
    /// <param name="request">The request to copy.</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="request"/> is <c>null</c>.</exception>
    public DnsRequest(IDnsRequest request)
    {
        if (request is null)
        {
            throw new ArgumentNullException(nameof(request));
        }

        header = new DnsHeader();
        Questions = new List<DnsQuestion>(request.Questions);
        header.Response = false;
        Id = request.Id;
        OperationCode = request.OperationCode;
        RecursionDesired = request.RecursionDesired;
    }

    /// <summary>Gets the questions carried by this request.</summary>
    public IList<DnsQuestion> Questions { get; }

    /// <summary>Gets the header size plus the sum of the sizes of all questions.</summary>
    public int Size => header.Size + Questions.Sum(q => q.Size);

    /// <summary>Gets or sets the message identifier stored in the header.</summary>
    public int Id
    {
        get => header.Id;
        set => header.Id = value;
    }

    /// <summary>Gets or sets the operation code stored in the header.</summary>
    public DnsOperationCode OperationCode
    {
        get => header.OperationCode;
        set => header.OperationCode = value;
    }

    /// <summary>Gets or sets a value indicating whether recursion is requested.</summary>
    public bool RecursionDesired
    {
        get => header.RecursionDesired;
        set => header.RecursionDesired = value;
    }

    /// <summary>
    /// Serializes the request: the header bytes followed by the bytes of every question.
    /// </summary>
    /// <returns>The serialized message.</returns>
    public byte[] ToArray()
    {
        // The question count in the header must reflect the current list before serializing.
        UpdateHeader();

        using var result = new MemoryStream(Size);
        return result
            .Append(header.ToArray())
            .Append(Questions.Select(q => q.ToArray()))
            .ToArray();
    }

    /// <inheritdoc />
    public override string ToString()
    {
        UpdateHeader();
        return Json.Serialize(this, true);
    }

    /// <summary>
    /// Produces an unpredictable 16-bit message identifier in the range 0..65535.
    /// A cryptographic generator is used because a shared <see cref="System.Random"/> is not
    /// safe for concurrent use and predictable identifiers make response spoofing easier.
    /// </summary>
    private static int NextId()
    {
        var bytes = new byte[2];
        using (var rng = System.Security.Cryptography.RandomNumberGenerator.Create())
        {
            rng.GetBytes(bytes);
        }

        return (bytes[0] << 8) | bytes[1];
    }

    /// <summary>Copies the current number of questions into the header.</summary>
    private void UpdateHeader()
    {
        header.QuestionCount = Questions.Count;
    }
}
/// <summary>
/// Resolves DNS requests over TCP. Each message is preceded by a two-byte, big-endian
/// length prefix, both when sending the request and when reading the response.
/// </summary>
public class DnsTcpRequestResolver : IDnsRequestResolver
{
    /// <summary>
    /// Sends the request to <see cref="DnsClientRequest.Dns"/> over a new TCP connection and
    /// reads one length-prefixed response.
    /// </summary>
    /// <param name="request">The request to send.</param>
    /// <returns>The parsed response together with the raw response bytes.</returns>
    /// <exception cref="ArgumentException">Thrown if the serialized request exceeds 65535 bytes.</exception>
    /// <exception cref="IOException">Thrown if the stream ends before the full response is read.</exception>
    public async Task<DnsClientResponse> Request(DnsClientRequest request)
    {
        // Match the socket's address family to the server endpoint so IPv6 servers work too;
        // the parameterless constructor is IPv4-only on .NET Framework.
        using var tcp = new TcpClient(request.Dns.AddressFamily);

#if !NET461
        await tcp.Client.ConnectAsync(request.Dns).ConfigureAwait(false);
#else
        tcp.Client.Connect(request.Dns);
#endif

        var stream = tcp.GetStream();
        var payload = request.ToArray();

        if (payload.Length > ushort.MaxValue)
        {
            // The length prefix is 16 bits wide; a larger message would be silently truncated.
            throw new ArgumentException("Request is too large to be sent over TCP");
        }

        // Two-byte length prefix, most significant byte first.
        var prefix = new[] { (byte)(payload.Length >> 8), (byte)payload.Length };
        await stream.WriteAsync(prefix, 0, prefix.Length).ConfigureAwait(false);
        await stream.WriteAsync(payload, 0, payload.Length).ConfigureAwait(false);

        var responsePrefix = new byte[2];
        await Read(stream, responsePrefix).ConfigureAwait(false);

        var buffer = new byte[(responsePrefix[0] << 8) | responsePrefix[1]];
        await Read(stream, buffer).ConfigureAwait(false);

        var response = DnsResponse.FromArray(buffer);
        return new DnsClientResponse(request, response, buffer);
    }

    /// <summary>
    /// Reads from the stream until <paramref name="buffer"/> is completely filled.
    /// </summary>
    /// <param name="stream">The stream to read from.</param>
    /// <param name="buffer">The buffer to fill.</param>
    /// <exception cref="IOException">Thrown if the stream ends before the buffer is full.</exception>
    private static async Task Read(Stream stream, byte[] buffer)
    {
        var remaining = buffer.Length;
        var offset = 0;

        while (remaining > 0)
        {
            var size = await stream.ReadAsync(buffer, offset, remaining).ConfigureAwait(false);
            if (size <= 0)
            {
                break;
            }

            offset += size;
            remaining -= size;
        }

        if (remaining > 0)
        {
            throw new IOException("Unexpected end of stream");
        }
    }
}
/// <summary>
/// Resolves DNS requests over UDP and hands the request to a fallback resolver when the
/// server marks its response as truncated.
/// </summary>
public class DnsUdpRequestResolver : IDnsRequestResolver
{
    // Send/receive timeout applied to the UDP socket, in milliseconds.
    private const int TimeoutMilliseconds = 7000;

    // Size of the scratch buffer used for each socket read.
    private const int ReceiveBufferSize = 1024;

    private readonly IDnsRequestResolver _fallback;

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsUdpRequestResolver"/> class.
    /// </summary>
    /// <param name="fallback">
    /// The resolver used when a response is truncated. A <c>null</c> value is treated like the
    /// parameterless constructor and uses a <see cref="DnsNullRequestResolver"/>.
    /// </param>
    public DnsUdpRequestResolver(IDnsRequestResolver fallback)
    {
        _fallback = fallback ?? new DnsNullRequestResolver();
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsUdpRequestResolver"/> class whose
    /// fallback is a <see cref="DnsNullRequestResolver"/>.
    /// </summary>
    public DnsUdpRequestResolver()
    {
        _fallback = new DnsNullRequestResolver();
    }

    /// <summary>
    /// Sends the request to <see cref="DnsClientRequest.Dns"/> over UDP and parses the reply.
    /// </summary>
    /// <param name="request">The request to send.</param>
    /// <returns>
    /// The parsed response, or the fallback resolver's result when the response is truncated.
    /// </returns>
    public async Task<DnsClientResponse> Request(DnsClientRequest request)
    {
        var dns = request.Dns;

        // The parameterless UdpClient constructor creates an IPv4 socket; use the endpoint's
        // own address family so IPv6 DNS servers are reachable as well.
        using var udp = new UdpClient(dns.AddressFamily);

        udp.Client.SendTimeout = TimeoutMilliseconds;
        udp.Client.ReceiveTimeout = TimeoutMilliseconds;

#if !NET461
        await udp.Client.ConnectAsync(dns).ConfigureAwait(false);
#else
        udp.Client.Connect(dns);
#endif

        // Serialize once so the bytes and the byte count sent are guaranteed to agree.
        var payload = request.ToArray();
        await udp.SendAsync(payload, payload.Length).ConfigureAwait(false);

        var bufferList = new List<byte>();
        var tempBuffer = new byte[ReceiveBufferSize];

        do
        {
            // NOTE(review): the blocking Receive is kept on purpose - Socket.ReceiveTimeout
            // only affects synchronous receive calls, so this is what enforces the timeout.
            var receiveCount = udp.Client.Receive(tempBuffer);
            bufferList.AddRange(tempBuffer.Take(receiveCount));
        }
        while (udp.Client.Available > 0 || bufferList.Count == 0);

        var buffer = bufferList.ToArray();
        var response = DnsResponse.FromArray(buffer);

        return response.IsTruncated
            ? await _fallback.Request(request).ConfigureAwait(false)
            : new DnsClientResponse(request, response, buffer);
    }
}
/// <summary>
/// A resolver that never produces a response; every request fails immediately.
/// </summary>
public class DnsNullRequestResolver : IDnsRequestResolver
{
    /// <summary>Always fails.</summary>
    /// <param name="request">The request (ignored).</param>
    /// <returns>Never returns normally.</returns>
    /// <exception cref="DnsQueryException">Always thrown, before any task is created.</exception>
    public Task<DnsClientResponse> Request(DnsClientRequest request)
        => throw new DnsQueryException("Request failed");
}
/// <summary>
/// The fixed 12-byte DNS message header: a 16-bit id, two flag bytes and four 16-bit
/// section counters. Marked big-endian and laid out sequentially with no padding.
/// </summary>
[StructEndianness(Endianness.Big)]
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct DnsHeader
{
    /// <summary>The size of the header in bytes.</summary>
    public const int SIZE = 12;

    // NOTE(review): the layout is sequential, so the declaration order of these fields is
    // presumably the byte order produced by ToBytes/ToStruct - do not reorder them.
    private ushort id;
    private byte flag0;
    private byte flag1;

    // Question count: number of questions in the Question section
    private ushort questionCount;

    // Answer record count: number of records in the Answer section
    private ushort answerCount;

    // Authority record count: number of records in the Authority section
    private ushort authorityCount;

    // Additional record count: number of records in the Additional section
    private ushort additionalCount;

    /// <summary>Gets or sets the message identifier (stored as 16 bits).</summary>
    public int Id
    {
        get => id;
        set => id = (ushort)value;
    }

    /// <summary>Gets or sets the number of questions (stored as 16 bits).</summary>
    public int QuestionCount
    {
        get => questionCount;
        set => questionCount = (ushort)value;
    }

    /// <summary>Gets or sets the number of answer records (stored as 16 bits).</summary>
    public int AnswerRecordCount
    {
        get => answerCount;
        set => answerCount = (ushort)value;
    }

    /// <summary>Gets or sets the number of authority records (stored as 16 bits).</summary>
    public int AuthorityRecordCount
    {
        get => authorityCount;
        set => authorityCount = (ushort)value;
    }

    /// <summary>Gets or sets the number of additional records (stored as 16 bits).</summary>
    public int AdditionalRecordCount
    {
        get => additionalCount;
        set => additionalCount = (ushort)value;
    }

    /// <summary>Gets or sets a value indicating whether the message is a response.</summary>
    public bool Response
    {
        get => Qr == 1;
        set => Qr = Convert.ToByte(value);
    }

    /// <summary>Gets or sets the operation code.</summary>
    public DnsOperationCode OperationCode
    {
        get => (DnsOperationCode)Opcode;
        set => Opcode = (byte)value;
    }

    /// <summary>Gets or sets a value indicating whether the authoritative-answer flag is set.</summary>
    public bool AuthorativeServer
    {
        get => Aa == 1;
        set => Aa = Convert.ToByte(value);
    }

    /// <summary>Gets or sets a value indicating whether the truncation flag is set.</summary>
    public bool Truncated
    {
        get => Tc == 1;
        set => Tc = Convert.ToByte(value);
    }

    /// <summary>Gets or sets a value indicating whether the recursion-desired flag is set.</summary>
    public bool RecursionDesired
    {
        get => Rd == 1;
        set => Rd = Convert.ToByte(value);
    }

    /// <summary>Gets or sets a value indicating whether the recursion-available flag is set.</summary>
    public bool RecursionAvailable
    {
        get => Ra == 1;
        set => Ra = Convert.ToByte(value);
    }

    /// <summary>Gets or sets the response code.</summary>
    public DnsResponseCode ResponseCode
    {
        get => (DnsResponseCode)RCode;
        set => RCode = (byte)value;
    }

    /// <summary>Gets the size of the header in bytes (always <see cref="SIZE"/>).</summary>
    public int Size => SIZE;

    // Query/Response Flag
    private byte Qr
    {
        get => Flag0.GetBitValueAt(7);
        set => Flag0 = Flag0.SetBitValueAt(7, 1, value);
    }

    // Operation Code
    private byte Opcode
    {
        get => Flag0.GetBitValueAt(3, 4);
        set => Flag0 = Flag0.SetBitValueAt(3, 4, value);
    }

    // Authoritative Answer Flag
    private byte Aa
    {
        get => Flag0.GetBitValueAt(2);
        set => Flag0 = Flag0.SetBitValueAt(2, 1, value);
    }

    // Truncation Flag
    private byte Tc
    {
        get => Flag0.GetBitValueAt(1);
        set => Flag0 = Flag0.SetBitValueAt(1, 1, value);
    }

    // Recursion Desired
    private byte Rd
    {
        get => Flag0.GetBitValueAt(0);
        set => Flag0 = Flag0.SetBitValueAt(0, 1, value);
    }

    // Recursion Available
    private byte Ra
    {
        get => Flag1.GetBitValueAt(7);
        set => Flag1 = Flag1.SetBitValueAt(7, 1, value);
    }

    // Zero (Reserved); the setter intentionally ignores the assigned value.
    private byte Z
    {
        get => Flag1.GetBitValueAt(4, 3);
        set { }
    }

    // Response Code
    private byte RCode
    {
        get => Flag1.GetBitValueAt(0, 4);
        set => Flag1 = Flag1.SetBitValueAt(0, 4, value);
    }

    // First flag byte: holds the Qr, Opcode, Aa, Tc and Rd bits.
    private byte Flag0
    {
        get => flag0;
        set => flag0 = value;
    }

    // Second flag byte: holds the Ra, Z and RCode bits.
    private byte Flag1
    {
        get => flag1;
        set => flag1 = value;
    }

    /// <summary>
    /// Reads a header from the first <see cref="SIZE"/> bytes of the given array.
    /// </summary>
    /// <param name="header">The bytes to read from; must hold at least <see cref="SIZE"/> bytes.</param>
    /// <returns>The parsed header.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="header"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">Thrown if fewer than <see cref="SIZE"/> bytes are supplied.</exception>
    public static DnsHeader FromArray(byte[] header)
    {
        if (header is null)
        {
            throw new ArgumentNullException(nameof(header));
        }

        if (header.Length < SIZE)
        {
            throw new ArgumentException("Header length too small");
        }

        return header.ToStruct<DnsHeader>(0, SIZE);
    }

    /// <summary>Serializes this header.</summary>
    /// <returns>The header bytes.</returns>
    public byte[] ToArray() => this.ToBytes();

    /// <inheritdoc />
    public override string ToString()
        => Json.SerializeExcluding(this, true, nameof(Size));
}
/// <summary>
/// A domain name represented as a sequence of labels, with conversion to and from the
/// length-prefixed label encoding used inside DNS messages.
/// </summary>
public class DnsDomain : IComparable<DnsDomain>
{
    private readonly string[] _labels;

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsDomain"/> class from dotted notation.
    /// A single trailing dot (fully-qualified form) is accepted and ignored.
    /// </summary>
    /// <param name="domain">The domain name, e.g. <c>example.com</c>.</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="domain"/> is <c>null</c>.</exception>
    public DnsDomain(string domain)
        : this(SplitLabels(domain))
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsDomain"/> class from its labels.
    /// </summary>
    /// <param name="labels">The labels, ordered from left to right.</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="labels"/> is <c>null</c>.</exception>
    public DnsDomain(string[] labels)
    {
        _labels = labels ?? throw new ArgumentNullException(nameof(labels));
    }

    /// <summary>
    /// Gets the encoded size: every label's length, one length byte per label, plus the
    /// terminating zero byte.
    /// </summary>
    public int Size => _labels.Sum(l => l.Length) + _labels.Length + 1;

    /// <summary>Parses a domain name located at <paramref name="offset"/>.</summary>
    /// <param name="message">The complete message bytes.</param>
    /// <param name="offset">The index at which the name starts.</param>
    /// <returns>The parsed domain.</returns>
    public static DnsDomain FromArray(byte[] message, int offset)
        => FromArray(message, offset, out _);

    /// <summary>
    /// Parses a domain name located at <paramref name="offset"/>, following compression
    /// pointers (a length byte whose two highest bits are set, combined with the next byte).
    /// </summary>
    /// <param name="message">The complete message bytes.</param>
    /// <param name="offset">The index at which the name starts.</param>
    /// <param name="endOffset">
    /// Receives the index just past the name at its original location: after the terminating
    /// zero byte, or after the first pointer when one was followed.
    /// </param>
    /// <returns>The parsed domain.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="message"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">
    /// Thrown if the name runs past the end of the message, if a length byte has an unexpected
    /// bit pattern, or if the compression pointers form a loop.
    /// </exception>
    public static DnsDomain FromArray(byte[] message, int offset, out int endOffset)
    {
        if (message is null)
        {
            throw new ArgumentNullException(nameof(message));
        }

        var labels = new List<byte[]>();
        var endOffsetAssigned = false;
        var pointerJumps = 0;
        endOffset = 0;

        while (true)
        {
            // The bytes come from the network, so every index is validated before use.
            if (offset < 0 || offset >= message.Length)
            {
                throw new ArgumentException("Domain name extends beyond the end of the message");
            }

            var lengthOrPointer = message[offset++];
            if (lengthOrPointer == 0)
            {
                break;
            }

            var marker = lengthOrPointer.GetBitValueAt(6, 2);

            // Two highest bits are set (pointer)
            if (marker == 3)
            {
                if (offset >= message.Length)
                {
                    throw new ArgumentException("Domain name extends beyond the end of the message");
                }

                // A chain that never revisits a position cannot contain more jumps than the
                // message has bytes; exceeding that means the pointers loop.
                if (++pointerJumps > message.Length)
                {
                    throw new ArgumentException("Compression pointer loop detected in domain name");
                }

                if (!endOffsetAssigned)
                {
                    endOffsetAssigned = true;

                    // Skip the second pointer byte; parsing resumes there for the caller.
                    endOffset = offset + 1;
                }

                ushort pointer = lengthOrPointer.GetBitValueAt(0, 6);
                offset = (pointer << 8) | message[offset];
                continue;
            }

            if (marker != 0)
            {
                throw new ArgumentException("Unexpected bit pattern in label length");
            }

            var length = lengthOrPointer;
            if (offset + length > message.Length)
            {
                throw new ArgumentException("Domain name extends beyond the end of the message");
            }

            var label = new byte[length];
            Array.Copy(message, offset, label, 0, length);
            labels.Add(label);
            offset += length;
        }

        if (!endOffsetAssigned)
        {
            endOffset = offset;
        }

        return new DnsDomain(labels.Select(l => l.ToText(Encoding.ASCII)).ToArray());
    }

    /// <summary>Builds the reverse-lookup domain name for an IP address.</summary>
    /// <param name="ip">The address to build the name for.</param>
    /// <returns>A name ending in <c>.in-addr.arpa</c> or <c>.ip6.arpa</c>.</returns>
    public static DnsDomain PointerName(IPAddress ip)
        => new DnsDomain(FormatReverseIP(ip));

    /// <summary>
    /// Encodes the domain as length-prefixed ASCII labels followed by a zero byte.
    /// </summary>
    /// <returns>The encoded name.</returns>
    public byte[] ToArray()
    {
        var result = new byte[Size];
        var offset = 0;

        foreach (var label in _labels)
        {
            var bytes = Encoding.ASCII.GetBytes(label);
            result[offset++] = (byte)bytes.Length;
            bytes.CopyTo(result, offset);
            offset += bytes.Length;
        }

        result[offset] = 0;
        return result;
    }

    /// <summary>Returns the labels joined with dots.</summary>
    public override string ToString()
        => string.Join(".", _labels);

    /// <summary>Compares the dotted string forms using ordinal comparison.</summary>
    /// <param name="other">The domain to compare with; <c>null</c> sorts before any instance.</param>
    /// <returns>A negative, zero or positive value, as for <see cref="string.Compare(string, string, StringComparison)"/>.</returns>
    public int CompareTo(DnsDomain other)
        => other is null
            ? 1
            : string.Compare(ToString(), other.ToString(), StringComparison.Ordinal);

    /// <inheritdoc />
    public override bool Equals(object obj)
        => obj is DnsDomain domain && CompareTo(domain) == 0;

    /// <inheritdoc />
    public override int GetHashCode() => ToString().GetHashCode();

    /// <summary>
    /// Splits dotted notation into labels. A trailing dot is dropped because the encoded
    /// form already ends with a zero byte; keeping the empty label would emit a second one
    /// and corrupt the message.
    /// </summary>
    private static string[] SplitLabels(string domain)
    {
        if (domain is null)
        {
            throw new ArgumentNullException(nameof(domain));
        }

        var name = domain.EndsWith(".", StringComparison.Ordinal)
            ? domain.Substring(0, domain.Length - 1)
            : domain;

        return name.Length == 0 ? Array.Empty<string>() : name.Split('.');
    }

    /// <summary>
    /// Formats the reverse-lookup name: reversed decimal octets for 4-byte addresses,
    /// otherwise reversed hexadecimal nibbles.
    /// </summary>
    private static string FormatReverseIP(IPAddress ip)
    {
        var address = ip.GetAddressBytes();

        if (address.Length == 4)
        {
            return string.Join(".", address.Reverse().Select(b => b.ToString())) + ".in-addr.arpa";
        }

        var nibbles = new byte[address.Length * 2];

        for (var i = 0; i < address.Length; i++)
        {
            var b = address[i];
            nibbles[2 * i] = b.GetBitValueAt(4, 4);
            nibbles[(2 * i) + 1] = b.GetBitValueAt(0, 4);
        }

        return string.Join(".", nibbles.Reverse().Select(b => b.ToString("x"))) + ".ip6.arpa";
    }
}
/// <summary>
/// A single entry of the question section: a domain name followed by a 16-bit record type
/// and a 16-bit record class.
/// </summary>
public class DnsQuestion : IDnsMessageEntry
{
    private readonly DnsRecordType _type;
    private readonly DnsRecordClass _klass;

    /// <summary>
    /// Initializes a new instance of the <see cref="DnsQuestion"/> class.
    /// </summary>
    /// <param name="domain">The domain name being asked about.</param>
    /// <param name="type">The record type; defaults to <see cref="DnsRecordType.A"/>.</param>
    /// <param name="klass">The record class; defaults to <see cref="DnsRecordClass.IN"/>.</param>
    public DnsQuestion(
        DnsDomain domain,
        DnsRecordType type = DnsRecordType.A,
        DnsRecordClass klass = DnsRecordClass.IN)
    {
        Name = domain;
        _type = type;
        _klass = klass;
    }

    /// <summary>Gets the domain name of the question.</summary>
    public DnsDomain Name { get; }

    /// <summary>Gets the requested record type.</summary>
    public DnsRecordType Type => _type;

    /// <summary>Gets the requested record class.</summary>
    public DnsRecordClass Class => _klass;

    /// <summary>Gets the encoded size: the name's size plus the 4-byte type/class tail.</summary>
    public int Size => Name.Size + Tail.SIZE;

    /// <summary>Parses <paramref name="questionCount"/> consecutive questions.</summary>
    /// <param name="message">The complete message bytes.</param>
    /// <param name="offset">The index of the first question.</param>
    /// <param name="questionCount">The number of questions to read.</param>
    /// <returns>The parsed questions, in message order.</returns>
    public static IList<DnsQuestion> GetAllFromArray(byte[] message, int offset, int questionCount) =>
        GetAllFromArray(message, offset, questionCount, out _);

    /// <summary>Parses <paramref name="questionCount"/> consecutive questions.</summary>
    /// <param name="message">The complete message bytes.</param>
    /// <param name="offset">The index of the first question.</param>
    /// <param name="questionCount">The number of questions to read.</param>
    /// <param name="endOffset">Receives the index just past the last question read.</param>
    /// <returns>The parsed questions, in message order.</returns>
    public static IList<DnsQuestion> GetAllFromArray(
        byte[] message,
        int offset,
        int questionCount,
        out int endOffset)
    {
        IList<DnsQuestion> questions = new List<DnsQuestion>(questionCount);

        for (var i = 0; i < questionCount; i++)
        {
            // Each question starts where the previous one ended.
            questions.Add(FromArray(message, offset, out offset));
        }

        endOffset = offset;
        return questions;
    }

    /// <summary>Parses one question starting at <paramref name="offset"/>.</summary>
    /// <param name="message">The complete message bytes.</param>
    /// <param name="offset">The index at which the question starts.</param>
    /// <param name="endOffset">Receives the index just past the question.</param>
    /// <returns>The parsed question.</returns>
    public static DnsQuestion FromArray(byte[] message, int offset, out int endOffset)
    {
        var domain = DnsDomain.FromArray(message, offset, out offset);
        var tail = message.ToStruct<Tail>(offset, Tail.SIZE);

        endOffset = offset + Tail.SIZE;
        return new DnsQuestion(domain, tail.Type, tail.Class);
    }

    /// <summary>Serializes the question: the encoded name followed by the type/class tail.</summary>
    /// <returns>The serialized question.</returns>
    public byte[] ToArray()
    {
        using var result = new MemoryStream(Size);
        return result
            .Append(Name.ToArray())
            .Append(new Tail { Type = Type, Class = Class }.ToBytes())
            .ToArray();
    }

    /// <inheritdoc />
    public override string ToString()
        => Json.SerializeOnly(this, true, nameof(Name), nameof(Type), nameof(Class));

    /// <summary>
    /// The fixed 4-byte trailer of a question: record type then record class, both 16 bits
    /// and marked big-endian.
    /// </summary>
    [StructEndianness(Endianness.Big)]
    [StructLayout(LayoutKind.Sequential, Pack = 2)]
    private struct Tail
    {
        public const int SIZE = 4;

        // Declaration order matters for the sequential layout: type first, then class.
        private ushort type;
        private ushort klass;

        public DnsRecordType Type
        {
            get => (DnsRecordType)type;
            set => type = (ushort)value;
        }

        public DnsRecordClass Class
        {
            get => (DnsRecordClass)klass;
            set => klass = (ushort)value;
        }
    }
}
}
}