#nullable enable
using Swan.Formatters;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;

namespace Swan.Net.Dns {
    /// <summary>
    /// DnsClient Request inner class.
    /// </summary>
    internal partial class DnsClient {
        /// <summary>
        /// A DNS request bound to a server endpoint and to the resolver strategy used to send it.
        /// </summary>
        public class DnsClientRequest : IDnsRequest {
            private readonly IDnsRequestResolver _resolver;
            private readonly IDnsRequest _request;

            /// <summary>
            /// Creates a request for the given server.
            /// </summary>
            /// <param name="dns">Endpoint of the DNS server to query.</param>
            /// <param name="request">Optional request to copy (questions, id, operation code, recursion flag); a fresh query is created when null.</param>
            /// <param name="resolver">Optional transport strategy; defaults to <see cref="DnsUdpRequestResolver"/>.</param>
            public DnsClientRequest(IPEndPoint dns, IDnsRequest? request = null, IDnsRequestResolver? resolver = null) {
                this.Dns = dns;
                // Work on a copy so the caller's request object is not shared with this one.
                this._request = request == null ? new DnsRequest() : new DnsRequest(request);
                this._resolver = resolver ?? new DnsUdpRequestResolver();
            }

            public Int32 Id {
                get => this._request.Id;
                set => this._request.Id = value;
            }

            public DnsOperationCode OperationCode {
                get => this._request.OperationCode;
                set => this._request.OperationCode = value;
            }

            public Boolean RecursionDesired {
                get => this._request.RecursionDesired;
                set => this._request.RecursionDesired = value;
            }

            public IList<DnsQuestion> Questions => this._request.Questions;

            public Int32 Size => this._request.Size;

            public IPEndPoint Dns { get; set; }

            public Byte[] ToArray() => this._request.ToArray();

            public override String ToString() => this._request.ToString()!;

            /// <summary>
            /// Resolves this request into a response using the provided DNS information. The given
            /// request strategy is used to retrieve the response.
            /// </summary>
            /// <exception cref="DnsQueryException">
            /// Thrown if the response id does not match the request id, if the server answers with a
            /// response code other than NoError, or if a malformed response (ArgumentException) or a
            /// socket failure (SocketException) occurs; the latter two are wrapped as the inner exception.
            /// </exception>
            /// <exception cref="IOException">Thrown if an IO error occurs (not wrapped).</exception>
            /// <returns>The response received from server.</returns>
            public async Task<DnsClientResponse> Resolve() {
                try {
                    DnsClientResponse response = await this._resolver.Request(this).ConfigureAwait(false);

                    if(response.Id != this.Id) {
                        throw new DnsQueryException(response, "Mismatching request/response IDs");
                    }

                    if(response.ResponseCode != DnsResponseCode.NoError) {
                        throw new DnsQueryException(response);
                    }

                    return response;
                } catch(Exception e) when(e is ArgumentException || e is SocketException) {
                    throw new DnsQueryException("Invalid response", e);
                }
            }
        }

        /// <summary>
        /// A plain DNS request message: a header plus a list of questions.
        /// </summary>
        public class DnsRequest : IDnsRequest {
            private DnsHeader header;

            /// <summary>
            /// Creates an empty standard query with a random transaction id.
            /// </summary>
            public DnsRequest() {
                this.Questions = new List<DnsQuestion>();
                this.header = new DnsHeader {
                    OperationCode = DnsOperationCode.Query,
                    Response = false,
                    Id = NewId(),
                };
            }

            /// <summary>
            /// Creates a copy of <paramref name="request"/> (questions list, id, operation code, recursion flag).
            /// </summary>
            public DnsRequest(IDnsRequest request) {
                this.header = new DnsHeader();
                this.Questions = new List<DnsQuestion>(request.Questions);

                this.header.Response = false;

                this.Id = request.Id;
                this.OperationCode = request.OperationCode;
                this.RecursionDesired = request.RecursionDesired;
            }

            public IList<DnsQuestion> Questions { get; }

            public Int32 Size => this.header.Size + this.Questions.Sum(q => q.Size);

            public Int32 Id {
                get => this.header.Id;
                set => this.header.Id = value;
            }

            public DnsOperationCode OperationCode {
                get => this.header.OperationCode;
                set => this.header.OperationCode = value;
            }

            public Boolean RecursionDesired {
                get => this.header.RecursionDesired;
                set => this.header.RecursionDesired = value;
            }

            public Byte[] ToArray() {
                this.UpdateHeader();
                using MemoryStream result = new MemoryStream(this.Size);

                return result
                    .Append(this.header.ToArray())
                    .Append(this.Questions.Select(q => q.ToArray()))
                    .ToArray();
            }

            public override String ToString() {
                this.UpdateHeader();

                return Json.Serialize(this, true);
            }

            // Keeps the header's question count in sync with the question list before serializing.
            private void UpdateHeader() => this.header.QuestionCount = this.Questions.Count;

            // Produces a transaction id over the full 16-bit range [0, 65535]. A cryptographic RNG is
            // used because it is thread-safe and hard to predict (guessable ids make spoofing easier).
            private static Int32 NewId() {
                Byte[] bytes = new Byte[2];
                using(RandomNumberGenerator rng = RandomNumberGenerator.Create()) {
                    rng.GetBytes(bytes);
                }

                return (bytes[0] << 8) | bytes[1];
            }
        }

        /// <summary>
        /// Sends a request over TCP using the 2-byte big-endian length prefix framing.
        /// </summary>
        public class DnsTcpRequestResolver : IDnsRequestResolver {
            public async Task<DnsClientResponse> Request(DnsClientRequest request) {
                // Match the socket family to the server so IPv6 endpoints work as well as IPv4 ones.
                using TcpClient tcp = new TcpClient(request.Dns.AddressFamily);

                await tcp.Client.ConnectAsync(request.Dns).ConfigureAwait(false);

                NetworkStream stream = tcp.GetStream();
                Byte[] buffer = request.ToArray();
                Byte[] length = BitConverter.GetBytes((UInt16)buffer.Length);

                // The length prefix is sent big-endian.
                if(BitConverter.IsLittleEndian) {
                    Array.Reverse(length);
                }

                await stream.WriteAsync(length, 0, length.Length).ConfigureAwait(false);
                await stream.WriteAsync(buffer, 0, buffer.Length).ConfigureAwait(false);

                buffer = new Byte[2];
                await Read(stream, buffer).ConfigureAwait(false);

                if(BitConverter.IsLittleEndian) {
                    Array.Reverse(buffer);
                }

                buffer = new Byte[BitConverter.ToUInt16(buffer, 0)];
                await Read(stream, buffer).ConfigureAwait(false);

                DnsResponse response = DnsResponse.FromArray(buffer);

                return new DnsClientResponse(request, response, buffer);
            }

            // Fills the whole buffer from the stream; throws IOException if the stream ends first.
            private static async Task Read(Stream stream, Byte[] buffer) {
                Int32 length = buffer.Length;
                Int32 offset = 0;
                Int32 size;

                while(length > 0 && (size = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false)) > 0) {
                    offset += size;
                    length -= size;
                }

                if(length > 0) {
                    throw new IOException("Unexpected end of stream");
                }
            }
        }

        /// <summary>
        /// Sends a request over UDP; truncated responses are handed to a fallback resolver.
        /// </summary>
        public class DnsUdpRequestResolver : IDnsRequestResolver {
            // Send/receive timeout applied to the UDP socket, in milliseconds.
            private const Int32 TimeoutMilliseconds = 7000;

            // Largest possible UDP payload; a smaller buffer would lose the tail of an oversized datagram.
            private const Int32 MaxDatagramSize = UInt16.MaxValue;

            private readonly IDnsRequestResolver _fallback;

            public DnsUdpRequestResolver(IDnsRequestResolver fallback) => this._fallback = fallback;

            /// <summary>
            /// Creates a resolver whose fallback is <see cref="DnsNullRequestResolver"/> (truncated responses fail).
            /// </summary>
            public DnsUdpRequestResolver() => this._fallback = new DnsNullRequestResolver();

            public async Task<DnsClientResponse> Request(DnsClientRequest request) {
                IPEndPoint dns = request.Dns;

                // Match the socket family to the server so IPv6 endpoints work as well as IPv4 ones.
                using UdpClient udp = new UdpClient(dns.AddressFamily);

                udp.Client.SendTimeout = TimeoutMilliseconds;
                udp.Client.ReceiveTimeout = TimeoutMilliseconds;

                await udp.Client.ConnectAsync(dns).ConfigureAwait(false);
                _ = await udp.SendAsync(request.ToArray(), request.Size).ConfigureAwait(false);

                Byte[] buffer = ReceiveDatagram(udp.Client);
                DnsResponse response = DnsResponse.FromArray(buffer);

                return response.IsTruncated
                    ? await this._fallback.Request(request).ConfigureAwait(false)
                    : new DnsClientResponse(request, response, buffer);
            }

            // Receives exactly one non-empty datagram: a DNS message over UDP is a single datagram, so
            // bytes from separate datagrams must never be concatenated. The synchronous Receive is used
            // because Socket.ReceiveTimeout only applies to synchronous receive calls.
            private static Byte[] ReceiveDatagram(Socket socket) {
                Byte[] datagram = new Byte[MaxDatagramSize];
                Int32 received;

                do {
                    received = socket.Receive(datagram);
                } while(received == 0);

                Byte[] result = new Byte[received];
                Array.Copy(datagram, result, received);
                return result;
            }
        }

        /// <summary>
        /// Resolver that always fails; used as the default fallback for truncated UDP responses.
        /// </summary>
        public class DnsNullRequestResolver : IDnsRequestResolver {
            public Task<DnsClientResponse> Request(DnsClientRequest request) => throw new DnsQueryException("Request failed");
        }

        // 12 bytes message header
        [StructEndianness(Endianness.Big)]
        [StructLayout(LayoutKind.Sequential, Pack = 1)]
        public struct DnsHeader {
            public const Int32 SIZE = 12;

            // Sequential layout puts fields in declaration order. Since ToArray/FromArray serialize the
            // struct as laid out (assumed: Swan's ToBytes/ToStruct marshal it), the fields must follow
            // the wire order: ID, two flag bytes, then the four 16-bit counts.
            private UInt16 id;

            // First flag byte: QR (bit 7), Opcode (bits 6-3), AA (bit 2), TC (bit 1), RD (bit 0)
            private Byte flag0;

            // Second flag byte: RA (bit 7), Z (bits 6-4), RCODE (bits 3-0)
            private Byte flag1;

            // Question count: number of questions in the Question section
            private UInt16 questionCount;

            // Answer record count: number of records in the Answer section
            private UInt16 answerCount;

            // Authority record count: number of records in the Authority section
            private UInt16 authorityCount;

            // Additional record count: number of records in the Additional section
            private UInt16 addtionalCount;

            public Int32 Id {
                get => this.id;
                set => this.id = (UInt16)value;
            }

            public Int32 QuestionCount {
                get => this.questionCount;
                set => this.questionCount = (UInt16)value;
            }

            public Int32 AnswerRecordCount {
                get => this.answerCount;
                set => this.answerCount = (UInt16)value;
            }

            public Int32 AuthorityRecordCount {
                get => this.authorityCount;
                set => this.authorityCount = (UInt16)value;
            }

            public Int32 AdditionalRecordCount {
                get => this.addtionalCount;
                set => this.addtionalCount = (UInt16)value;
            }

            public Boolean Response {
                get => this.Qr == 1;
                set => this.Qr = Convert.ToByte(value);
            }

            public DnsOperationCode OperationCode {
                get => (DnsOperationCode)this.Opcode;
                set => this.Opcode = (Byte)value;
            }

            public Boolean AuthorativeServer {
                get => this.Aa == 1;
                set => this.Aa = Convert.ToByte(value);
            }

            public Boolean Truncated {
                get => this.Tc == 1;
                set => this.Tc = Convert.ToByte(value);
            }

            public Boolean RecursionDesired {
                get => this.Rd == 1;
                set => this.Rd = Convert.ToByte(value);
            }

            public Boolean RecursionAvailable {
                get => this.Ra == 1;
                set => this.Ra = Convert.ToByte(value);
            }

            public DnsResponseCode ResponseCode {
                get => (DnsResponseCode)this.RCode;
                set => this.RCode = (Byte)value;
            }

            public Int32 Size => SIZE;

            // Query/Response Flag
            private Byte Qr {
                get => this.flag0.GetBitValueAt(7);
                set => this.flag0 = this.flag0.SetBitValueAt(7, 1, value);
            }

            // Operation Code
            private Byte Opcode {
                get => this.flag0.GetBitValueAt(3, 4);
                set => this.flag0 = this.flag0.SetBitValueAt(3, 4, value);
            }

            // Authoritative Answer Flag
            private Byte Aa {
                get => this.flag0.GetBitValueAt(2);
                set => this.flag0 = this.flag0.SetBitValueAt(2, 1, value);
            }

            // Truncation Flag
            private Byte Tc {
                get => this.flag0.GetBitValueAt(1);
                set => this.flag0 = this.flag0.SetBitValueAt(1, 1, value);
            }

            // Recursion Desired
            private Byte Rd {
                get => this.flag0.GetBitValueAt(0);
                set => this.flag0 = this.flag0.SetBitValueAt(0, 1, value);
            }

            // Recursion Available
            private Byte Ra {
                get => this.flag1.GetBitValueAt(7);
                set => this.flag1 = this.flag1.SetBitValueAt(7, 1, value);
            }

            // Zero (Reserved); read-only, writes are ignored
            private Byte Z {
                get => this.flag1.GetBitValueAt(4, 3);
                set { }
            }

            // Response Code
            private Byte RCode {
                get => this.flag1.GetBitValueAt(0, 4);
                set => this.flag1 = this.flag1.SetBitValueAt(0, 4, value);
            }

            /// <summary>
            /// Parses the first <see cref="SIZE"/> bytes of <paramref name="header"/>.
            /// </summary>
            /// <exception cref="ArgumentException">Thrown if fewer than <see cref="SIZE"/> bytes are supplied.</exception>
            public static DnsHeader FromArray(Byte[] header) =>
                header.Length < SIZE
                    ? throw new ArgumentException("Header length too small")
                    : header.ToStruct<DnsHeader>(0, SIZE);

            public Byte[] ToArray() => this.ToBytes();

            public override String ToString() => Json.SerializeExcluding(this, true, nameof(this.Size));
        }

        /// <summary>
        /// A domain name as a sequence of labels, with wire-format encoding and decoding.
        /// </summary>
        public class DnsDomain : IComparable<DnsDomain> {
            private readonly String[] _labels;

            /// <summary>
            /// Creates a domain from dotted text. A single trailing dot (fully-qualified form) is
            /// accepted; "" and "." denote the root domain (no labels).
            /// </summary>
            public DnsDomain(String domain)
                : this(SplitLabels(domain)) {
            }

            public DnsDomain(String[] labels) => this._labels = labels;

            // Encoded size: one length byte per label, the label bytes, and the terminating zero byte.
            public Int32 Size => this._labels.Sum(l => l.Length) + this._labels.Length + 1;

            public static DnsDomain FromArray(Byte[] message, Int32 offset) => FromArray(message, offset, out _);

            /// <summary>
            /// Decodes a (possibly compressed) domain name starting at <paramref name="offset"/>.
            /// </summary>
            /// <param name="message">The whole DNS message (pointers are offsets into it).</param>
            /// <param name="offset">Start of the encoded name.</param>
            /// <param name="endOffset">Offset just past the name as it appears at <paramref name="offset"/>
            /// (after the first pointer if one was followed, otherwise after the terminating zero).</param>
            /// <exception cref="ArgumentException">Thrown for malformed names: reserved length bits,
            /// reads past the end of the message, or a cycle of compression pointers.</exception>
            public static DnsDomain FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
                List<Byte[]> labels = new List<Byte[]>();
                Boolean endOffsetAssigned = false;
                Int32 jumps = 0;
                endOffset = 0;
                Byte lengthOrPointer;

                while((lengthOrPointer = ReadByte(message, offset++)) > 0) {
                    // Two highest bits are set (pointer)
                    if(lengthOrPointer.GetBitValueAt(6, 2) == 3) {
                        if(!endOffsetAssigned) {
                            endOffsetAssigned = true;
                            endOffset = offset + 1;
                        }

                        // A loop-free pointer chain visits each position at most once, so more jumps
                        // than bytes in the message means the pointers form a cycle.
                        if(++jumps > message.Length) {
                            throw new ArgumentException("Compression pointer loop in domain name");
                        }

                        UInt16 pointer = lengthOrPointer.GetBitValueAt(0, 6);
                        offset = (pointer << 8) | ReadByte(message, offset);

                        continue;
                    }

                    if(lengthOrPointer.GetBitValueAt(6, 2) != 0) {
                        throw new ArgumentException("Unexpected bit pattern in label length");
                    }

                    Byte length = lengthOrPointer;
                    if(length > message.Length - offset) {
                        throw new ArgumentException("Domain label extends past the end of the message");
                    }

                    Byte[] label = new Byte[length];
                    Array.Copy(message, offset, label, 0, length);

                    labels.Add(label);

                    offset += length;
                }

                if(!endOffsetAssigned) {
                    endOffset = offset;
                }

                return new DnsDomain(labels.Select(l => l.ToText(Encoding.ASCII)).ToArray());
            }

            public static DnsDomain PointerName(IPAddress ip) => new DnsDomain(FormatReverseIP(ip));

            /// <summary>
            /// Encodes the name as length-prefixed ASCII labels followed by a zero byte.
            /// </summary>
            public Byte[] ToArray() {
                Byte[] result = new Byte[this.Size];
                Int32 offset = 0;

                foreach(Byte[] l in this._labels.Select(label => Encoding.ASCII.GetBytes(label))) {
                    result[offset++] = (Byte)l.Length;
                    l.CopyTo(result, offset);

                    offset += l.Length;
                }

                result[offset] = 0;

                return result;
            }

            public override String ToString() => String.Join(".", this._labels);

            public Int32 CompareTo(DnsDomain? other) =>
                other is null ? 1 : String.Compare(this.ToString(), other.ToString(), StringComparison.Ordinal);

            public override Boolean Equals(Object? obj) => obj is DnsDomain domain && this.CompareTo(domain) == 0;

            public override Int32 GetHashCode() => this.ToString().GetHashCode();

            // Splits dotted text into labels, dropping one trailing dot so "example.com." does not
            // produce an empty label (which would encode as a premature terminating zero byte).
            private static String[] SplitLabels(String domain) {
                String trimmed = domain.EndsWith(".", StringComparison.Ordinal)
                    ? domain.Substring(0, domain.Length - 1)
                    : domain;

                return trimmed.Length == 0 ? Array.Empty<String>() : trimmed.Split('.');
            }

            // Reads one byte, reporting a read past the end as a malformed name rather than IndexOutOfRange.
            private static Byte ReadByte(Byte[] message, Int32 offset) =>
                offset >= 0 && offset < message.Length
                    ? message[offset]
                    : throw new ArgumentException("Unexpected end of message while reading domain name");

            // Builds the reverse-lookup name: reversed decimal octets + ".in-addr.arpa" for 4-byte
            // addresses, otherwise reversed hex nibbles + ".ip6.arpa".
            private static String FormatReverseIP(IPAddress ip) {
                Byte[] address = ip.GetAddressBytes();

                if(address.Length == 4) {
                    return String.Join(".", address.Reverse().Select(b => b.ToString(CultureInfo.InvariantCulture))) + ".in-addr.arpa";
                }

                Byte[] nibbles = new Byte[address.Length * 2];

                for(Int32 i = 0; i < address.Length; i++) {
                    Byte b = address[i];

                    nibbles[2 * i] = b.GetBitValueAt(4, 4);
                    nibbles[(2 * i) + 1] = b.GetBitValueAt(0, 4);
                }

                return String.Join(".", nibbles.Reverse().Select(b => b.ToString("x", CultureInfo.InvariantCulture))) + ".ip6.arpa";
            }
        }

        /// <summary>
        /// A question entry: a domain name plus record type and class.
        /// </summary>
        public class DnsQuestion : IDnsMessageEntry {
            public static IList<DnsQuestion> GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount) =>
                GetAllFromArray(message, offset, questionCount, out _);

            /// <summary>
            /// Parses <paramref name="questionCount"/> consecutive questions starting at <paramref name="offset"/>.
            /// </summary>
            public static IList<DnsQuestion> GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount, out Int32 endOffset) {
                IList<DnsQuestion> questions = new List<DnsQuestion>(questionCount);

                for(Int32 i = 0; i < questionCount; i++) {
                    questions.Add(FromArray(message, offset, out offset));
                }

                endOffset = offset;
                return questions;
            }

            public static DnsQuestion FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
                DnsDomain domain = DnsDomain.FromArray(message, offset, out offset);
                Tail tail = message.ToStruct<Tail>(offset, Tail.SIZE);

                endOffset = offset + Tail.SIZE;

                return new DnsQuestion(domain, tail.Type, tail.Class);
            }

            public DnsQuestion(DnsDomain domain, DnsRecordType type = DnsRecordType.A, DnsRecordClass klass = DnsRecordClass.IN) {
                this.Name = domain;
                this.Type = type;
                this.Class = klass;
            }

            public DnsDomain Name { get; }

            public DnsRecordType Type { get; }

            public DnsRecordClass Class { get; }

            public Int32 Size => this.Name.Size + Tail.SIZE;

            public Byte[] ToArray() {
                using MemoryStream result = new MemoryStream(this.Size);

                return result
                    .Append(this.Name.ToArray())
                    .Append(new Tail { Type = this.Type, Class = this.Class }.ToBytes())
                    .ToArray();
            }

            public override String ToString() => Json.SerializeOnly(this, true, nameof(this.Name), nameof(this.Type), nameof(this.Class));

            // Fixed part following the name: 16-bit type then 16-bit class, big-endian.
            [StructEndianness(Endianness.Big)]
            [StructLayout(LayoutKind.Sequential, Pack = 2)]
            private struct Tail {
                public const Int32 SIZE = 4;

                private UInt16 type;
                private UInt16 klass;

                public DnsRecordType Type {
                    get => (DnsRecordType)this.type;
                    set => this.type = (UInt16)value;
                }

                public DnsRecordClass Class {
                    get => (DnsRecordClass)this.klass;
                    set => this.klass = (UInt16)value;
                }
            }
        }
    }
}