RaspberryIO/Unosquare.Swan/Networking/DnsClient.Request.cs
2019-12-03 18:44:25 +01:00

582 lines
17 KiB
C#

using Unosquare.Swan.Formatters;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Text;
using Unosquare.Swan.Exceptions;
using Unosquare.Swan.Attributes;
namespace Unosquare.Swan.Networking {
/// <summary>
/// DnsClient Request inner class.
/// </summary>
internal partial class DnsClient {
public class DnsClientRequest : IDnsRequest {
    // Transport strategy used by Resolve() to obtain a response.
    private readonly IDnsRequestResolver _resolver;
    // Private copy of the request data; every IDnsRequest member forwards to it.
    private readonly IDnsRequest _request;

    /// <summary>
    /// Binds a DNS request to the server endpoint it targets and the strategy used to send it.
    /// </summary>
    /// <param name="dns">Endpoint of the DNS server to query.</param>
    /// <param name="request">Optional request to copy (questions, ID, opcode, recursion flag); a fresh query is created when null.</param>
    /// <param name="resolver">Optional transport strategy; a <see cref="DnsUdpRequestResolver"/> is used when null.</param>
    public DnsClientRequest(IPEndPoint dns, IDnsRequest request = null, IDnsRequestResolver resolver = null) {
        this.Dns = dns;

        IDnsRequest copy;
        if(request is null) {
            copy = new DnsRequest();
        } else {
            copy = new DnsRequest(request);
        }
        this._request = copy;

        if(resolver is null) {
            resolver = new DnsUdpRequestResolver();
        }
        this._resolver = resolver;
    }

    /// <summary>Transaction ID of the wrapped request.</summary>
    public Int32 Id {
        get { return this._request.Id; }
        set { this._request.Id = value; }
    }

    /// <summary>Operation code of the wrapped request.</summary>
    public DnsOperationCode OperationCode {
        get { return this._request.OperationCode; }
        set { this._request.OperationCode = value; }
    }

    /// <summary>Recursion-desired flag of the wrapped request.</summary>
    public Boolean RecursionDesired {
        get { return this._request.RecursionDesired; }
        set { this._request.RecursionDesired = value; }
    }

    /// <summary>Questions carried by the wrapped request.</summary>
    public IList<DnsQuestion> Questions {
        get { return this._request.Questions; }
    }

    /// <summary>Serialized size of the wrapped request, in bytes.</summary>
    public Int32 Size {
        get { return this._request.Size; }
    }

    /// <summary>Endpoint of the DNS server this request is sent to.</summary>
    public IPEndPoint Dns {
        get; set;
    }

    /// <summary>Serializes the wrapped request to its wire representation.</summary>
    public Byte[] ToArray() {
        return this._request.ToArray();
    }

    /// <summary>Delegates to the wrapped request's textual representation.</summary>
    public override String ToString() {
        return this._request.ToString();
    }

    /// <summary>
    /// Sends this request through the configured resolver strategy and validates the answer.
    /// </summary>
    /// <exception cref="DnsQueryException">Thrown if the response ID does not match the request ID, the response code is not NoError, or the resolver reports an ArgumentException (malformed response).</exception>
    /// <exception cref="IOException">Thrown if an IO error occurs.</exception>
    /// <exception cref="SocketException">Thrown if reading from or writing to the socket fails.</exception>
    /// <returns>The response received from the server.</returns>
    public DnsClientResponse Resolve() {
        try {
            DnsClientResponse answer = this._resolver.Request(this);

            Boolean idMatches = answer.Id == this.Id;
            if(!idMatches) {
                throw new DnsQueryException(answer, "Mismatching request/response IDs");
            }

            Boolean succeeded = answer.ResponseCode == DnsResponseCode.NoError;
            return succeeded ? answer : throw new DnsQueryException(answer);
        } catch(ArgumentException error) {
            // Parsing failures surface as ArgumentException; report them as a query failure.
            throw new DnsQueryException("Invalid response", error);
        }
    }
}
public class DnsRequest : IDnsRequest {
    // Shared source of transaction IDs. System.Random is not thread-safe: unsynchronized
    // concurrent calls can corrupt its internal state (after which it may keep returning 0),
    // so every access goes through RandomLock.
    // NOTE(review): System.Random is also predictable; RFC 5452 recommends unpredictable
    // transaction IDs to resist response spoofing — consider a cryptographic RNG.
    private static readonly Random Random = new Random();
    private static readonly Object RandomLock = new Object();
    private DnsHeader header;

    /// <summary>
    /// Creates an empty standard query (opcode Query, QR = 0) with a random transaction ID.
    /// </summary>
    public DnsRequest() {
        this.Questions = new List<DnsQuestion>();
        this.header = new DnsHeader {
            OperationCode = DnsOperationCode.Query,
            Response = false,
            Id = NextId(),
        };
    }

    /// <summary>
    /// Creates a copy of <paramref name="request"/>: its questions (shallow list copy), ID,
    /// operation code and recursion-desired flag. The copy is always marked as a query.
    /// </summary>
    /// <param name="request">The request to copy.</param>
    public DnsRequest(IDnsRequest request) {
        this.header = new DnsHeader();
        this.Questions = new List<DnsQuestion>(request.Questions);
        this.header.Response = false;
        this.Id = request.Id;
        this.OperationCode = request.OperationCode;
        this.RecursionDesired = request.RecursionDesired;
    }

    /// <summary>Questions included in this request.</summary>
    public IList<DnsQuestion> Questions {
        get;
    }

    /// <summary>Serialized size in bytes: the header plus every question.</summary>
    public Int32 Size => this.header.Size + this.Questions.Sum(q => q.Size);

    /// <summary>Transaction ID stored in the header.</summary>
    public Int32 Id {
        get => this.header.Id;
        set => this.header.Id = value;
    }

    /// <summary>Operation code stored in the header.</summary>
    public DnsOperationCode OperationCode {
        get => this.header.OperationCode;
        set => this.header.OperationCode = value;
    }

    /// <summary>Recursion-desired flag stored in the header.</summary>
    public Boolean RecursionDesired {
        get => this.header.RecursionDesired;
        set => this.header.RecursionDesired = value;
    }

    /// <summary>
    /// Serializes the header followed by each question, after syncing the header's question count.
    /// </summary>
    public Byte[] ToArray() {
        this.UpdateHeader();
        using(MemoryStream result = new MemoryStream(this.Size)) {
            _ = result
                .Append(this.header.ToArray())
                .Append(this.Questions.Select(q => q.ToArray()));
            return result.ToArray();
        }
    }

    /// <summary>JSON representation of the request (question count synced first).</summary>
    public override String ToString() {
        this.UpdateHeader();
        return Json.Serialize(this, true);
    }

    // Keeps the header's QDCOUNT in step with the Questions list before serialization.
    private void UpdateHeader() => this.header.QuestionCount = this.Questions.Count;

    // Draws a transaction ID over the full 16-bit range; Next's upper bound is exclusive,
    // hence the + 1 so that 0xFFFF can also be produced.
    private static Int32 NextId() {
        lock(RandomLock) {
            return Random.Next(UInt16.MaxValue + 1);
        }
    }
}
public class DnsTcpRequestResolver : IDnsRequestResolver {
    // Send/receive timeout, kept in line with DnsUdpRequestResolver so that an unresponsive
    // server cannot block the caller indefinitely.
    private const Int32 TimeoutMilliseconds = 7000;

    /// <summary>
    /// Sends the request over TCP, framed with a two-byte big-endian length prefix, and reads
    /// back one length-prefixed response.
    /// </summary>
    /// <param name="request">The request to send; its <c>Dns</c> endpoint is the server contacted.</param>
    /// <returns>The parsed response together with the request and the raw response bytes.</returns>
    /// <exception cref="IOException">Thrown if the stream ends early or a read/write times out.</exception>
    /// <exception cref="SocketException">Thrown if connecting to the server fails.</exception>
    public DnsClientResponse Request(DnsClientRequest request) {
        IPEndPoint dns = request.Dns;
        // The parameterless TcpClient constructor creates an IPv4-only socket on .NET Framework,
        // so match the server's address family to make IPv6 servers reachable. A null endpoint
        // still reaches Connect below, which rejects it with ArgumentNullException as before.
        using(TcpClient tcp = new TcpClient(dns?.AddressFamily ?? AddressFamily.InterNetwork)) {
            tcp.Client.SendTimeout = TimeoutMilliseconds;
            tcp.Client.ReceiveTimeout = TimeoutMilliseconds;
            tcp.Client.Connect(dns);
            NetworkStream stream = tcp.GetStream();

            Byte[] buffer = request.ToArray();
            // Length prefix in network (big-endian) byte order.
            Byte[] length = BitConverter.GetBytes((UInt16)buffer.Length);
            if(BitConverter.IsLittleEndian) {
                Array.Reverse(length);
            }
            stream.Write(length, 0, length.Length);
            stream.Write(buffer, 0, buffer.Length);

            // Read the response's two-byte big-endian length, then exactly that many bytes.
            buffer = new Byte[2];
            Read(stream, buffer);
            if(BitConverter.IsLittleEndian) {
                Array.Reverse(buffer);
            }
            buffer = new Byte[BitConverter.ToUInt16(buffer, 0)];
            Read(stream, buffer);

            DnsResponse response = DnsResponse.FromArray(buffer);
            return new DnsClientResponse(request, response, buffer);
        }
    }

    // Fills buffer completely from stream, looping over short reads.
    // Throws IOException if the stream ends before the buffer is full.
    private static void Read(Stream stream, Byte[] buffer) {
        Int32 length = buffer.Length;
        Int32 offset = 0;
        Int32 size;
        while(length > 0 && (size = stream.Read(buffer, offset, length)) > 0) {
            offset += size;
            length -= size;
        }
        if(length > 0) {
            throw new IOException("Unexpected end of stream");
        }
    }
}
public class DnsUdpRequestResolver : IDnsRequestResolver {
    // Send/receive timeout for the UDP exchange.
    private const Int32 TimeoutMilliseconds = 7000;
    // Large enough for any UDP payload, so a response datagram is never cut short.
    private const Int32 MaxDatagramSize = 65535;
    // Strategy used when the UDP response comes back with the TC (truncated) flag set.
    private readonly IDnsRequestResolver _fallback;

    /// <summary>Creates a UDP resolver that retries truncated responses with <paramref name="fallback"/>.</summary>
    public DnsUdpRequestResolver(IDnsRequestResolver fallback) => this._fallback = fallback;

    /// <summary>Creates a UDP resolver whose fallback always fails (<see cref="DnsNullRequestResolver"/>).</summary>
    public DnsUdpRequestResolver() => this._fallback = new DnsNullRequestResolver();

    /// <summary>
    /// Sends the request as a single UDP datagram and parses the single datagram received back.
    /// Truncated responses are handed to the fallback resolver.
    /// </summary>
    /// <param name="request">The request to send; its <c>Dns</c> endpoint is the server contacted.</param>
    /// <returns>The parsed response, or whatever the fallback returns for a truncated one.</returns>
    /// <exception cref="SocketException">Thrown if sending or receiving fails or times out.</exception>
    public DnsClientResponse Request(DnsClientRequest request) {
        IPEndPoint dns = request.Dns;
        // The parameterless UdpClient constructor creates an IPv4 socket; match the server's
        // address family so IPv6 servers are reachable. A null endpoint still reaches Connect,
        // which rejects it with ArgumentNullException as before.
        using(UdpClient udp = new UdpClient(dns?.AddressFamily ?? AddressFamily.InterNetwork)) {
            udp.Client.SendTimeout = TimeoutMilliseconds;
            udp.Client.ReceiveTimeout = TimeoutMilliseconds;
            udp.Client.Connect(dns);
            udp.Client.Send(request.ToArray());

            // Each Receive on a datagram socket returns exactly one datagram, and a DNS response
            // is one datagram. Do not append further queued datagrams: concatenating them
            // corrupts the message. Empty datagrams are skipped.
            Byte[] datagram = new Byte[MaxDatagramSize];
            Int32 received;
            do {
                received = udp.Client.Receive(datagram);
            } while(received == 0);

            Byte[] buffer = new Byte[received];
            Array.Copy(datagram, buffer, received);

            DnsResponse response = DnsResponse.FromArray(buffer);
            return response.IsTruncated
                ? this._fallback.Request(request)
                : new DnsClientResponse(request, response, buffer);
        }
    }
}
public class DnsNullRequestResolver : IDnsRequestResolver {
    /// <summary>
    /// Resolver that never succeeds: every call fails with a <see cref="DnsQueryException"/>.
    /// Used as the default fallback when no other strategy is configured.
    /// </summary>
    /// <param name="request">Ignored.</param>
    /// <exception cref="DnsQueryException">Always thrown.</exception>
    public DnsClientResponse Request(DnsClientRequest request) {
        throw new DnsQueryException("Request failed");
    }
}
// 12-byte DNS message header (RFC 1035, section 4.1.1).
[StructEndianness(Endianness.Big)]
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct DnsHeader {
    /// <summary>Size of the header on the wire, in bytes.</summary>
    public const Int32 SIZE = 12;

    /// <summary>
    /// Parses a header from the first <see cref="SIZE"/> bytes of <paramref name="header"/>.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown if fewer than <see cref="SIZE"/> bytes are supplied.</exception>
    public static DnsHeader FromArray(Byte[] header) {
        if(header.Length < SIZE) {
            throw new ArgumentException("Header length too small");
        }
        return header.ToStruct<DnsHeader>(0, SIZE);
    }

    // The instance fields are marshaled sequentially, so their declaration order IS the wire
    // order: ID, two flag bytes, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT. Keep the flag bytes as
    // explicit fields right here: compiler-generated auto-property backing fields follow the
    // property's declaration position (after the counts), which would shift every field.
    private UInt16 id;
    // First flags byte: QR | Opcode (4 bits) | AA | TC | RD
    private Byte flag0;
    // Second flags byte: RA | Z (3 bits) | RCODE (4 bits)
    private Byte flag1;
    // Question count: number of questions in the Question section
    private UInt16 questionCount;
    // Answer record count: number of records in the Answer section
    private UInt16 answerCount;
    // Authority record count: number of records in the Authority section
    private UInt16 authorityCount;
    // Additional record count: number of records in the Additional section
    private UInt16 additionalCount;

    /// <summary>Transaction ID (truncated to 16 bits when set).</summary>
    public Int32 Id {
        get => this.id;
        set => this.id = (UInt16)value;
    }

    /// <summary>Number of entries in the Question section (truncated to 16 bits when set).</summary>
    public Int32 QuestionCount {
        get => this.questionCount;
        set => this.questionCount = (UInt16)value;
    }

    /// <summary>Number of records in the Answer section (truncated to 16 bits when set).</summary>
    public Int32 AnswerRecordCount {
        get => this.answerCount;
        set => this.answerCount = (UInt16)value;
    }

    /// <summary>Number of records in the Authority section (truncated to 16 bits when set).</summary>
    public Int32 AuthorityRecordCount {
        get => this.authorityCount;
        set => this.authorityCount = (UInt16)value;
    }

    /// <summary>Number of records in the Additional section (truncated to 16 bits when set).</summary>
    public Int32 AdditionalRecordCount {
        get => this.additionalCount;
        set => this.additionalCount = (UInt16)value;
    }

    /// <summary>QR flag: true for a response, false for a query.</summary>
    public Boolean Response {
        get => this.Qr == 1;
        set => this.Qr = Convert.ToByte(value);
    }

    /// <summary>Opcode field.</summary>
    public DnsOperationCode OperationCode {
        get => (DnsOperationCode)this.Opcode;
        set => this.Opcode = (Byte)value;
    }

    /// <summary>AA flag (authoritative answer).</summary>
    public Boolean AuthorativeServer {
        get => this.Aa == 1;
        set => this.Aa = Convert.ToByte(value);
    }

    /// <summary>TC flag (message was truncated).</summary>
    public Boolean Truncated {
        get => this.Tc == 1;
        set => this.Tc = Convert.ToByte(value);
    }

    /// <summary>RD flag (recursion desired).</summary>
    public Boolean RecursionDesired {
        get => this.Rd == 1;
        set => this.Rd = Convert.ToByte(value);
    }

    /// <summary>RA flag (recursion available).</summary>
    public Boolean RecursionAvailable {
        get => this.Ra == 1;
        set => this.Ra = Convert.ToByte(value);
    }

    /// <summary>RCODE field.</summary>
    public DnsResponseCode ResponseCode {
        get => (DnsResponseCode)this.RCode;
        set => this.RCode = (Byte)value;
    }

    /// <summary>Size of the header on the wire (always <see cref="SIZE"/>).</summary>
    public Int32 Size => SIZE;

    // Query/Response Flag (bit 7 of the first flags byte)
    private Byte Qr {
        get => this.Flag0.GetBitValueAt(7);
        set => this.Flag0 = this.Flag0.SetBitValueAt(7, 1, value);
    }

    // Operation Code (bits 3-6 of the first flags byte)
    private Byte Opcode {
        get => this.Flag0.GetBitValueAt(3, 4);
        set => this.Flag0 = this.Flag0.SetBitValueAt(3, 4, value);
    }

    // Authoritative Answer Flag (bit 2 of the first flags byte)
    private Byte Aa {
        get => this.Flag0.GetBitValueAt(2);
        set => this.Flag0 = this.Flag0.SetBitValueAt(2, 1, value);
    }

    // Truncation Flag (bit 1 of the first flags byte)
    private Byte Tc {
        get => this.Flag0.GetBitValueAt(1);
        set => this.Flag0 = this.Flag0.SetBitValueAt(1, 1, value);
    }

    // Recursion Desired (bit 0 of the first flags byte)
    private Byte Rd {
        get => this.Flag0.GetBitValueAt(0);
        set => this.Flag0 = this.Flag0.SetBitValueAt(0, 1, value);
    }

    // Recursion Available (bit 7 of the second flags byte)
    private Byte Ra {
        get => this.Flag1.GetBitValueAt(7);
        set => this.Flag1 = this.Flag1.SetBitValueAt(7, 1, value);
    }

    // Zero (Reserved), bits 4-6 of the second flags byte; writes are intentionally ignored
    private Byte Z {
        get => this.Flag1.GetBitValueAt(4, 3);
        set {
        }
    }

    // Response Code (bits 0-3 of the second flags byte)
    private Byte RCode {
        get => this.Flag1.GetBitValueAt(0, 4);
        set => this.Flag1 = this.Flag1.SetBitValueAt(0, 4, value);
    }

    // Accessors over the explicit flag fields (see the layout note above).
    private Byte Flag0 {
        get => this.flag0;
        set => this.flag0 = value;
    }

    private Byte Flag1 {
        get => this.flag1;
        set => this.flag1 = value;
    }

    /// <summary>Serializes the header to its 12-byte big-endian wire form.</summary>
    public Byte[] ToArray() => this.ToBytes();

    /// <summary>JSON representation of the header (excluding <see cref="Size"/>).</summary>
    public override String ToString()
        => Json.SerializeExcluding(this, true, nameof(this.Size));
}
public class DnsDomain : IComparable<DnsDomain> {
    // The top two bits of a label's first byte select its kind (RFC 1035, section 4.1.4):
    // 00 = plain label length, 11 = compression pointer, 01/10 = reserved.
    private const Int32 LabelTypeMask = 0xC0;
    private const Int32 PointerTag = 0xC0;
    // Low six bits of a pointer's first byte: high part of the 14-bit target offset.
    private const Int32 PointerHighBitsMask = 0x3F;

    private readonly String[] _labels;

    /// <summary>Creates a domain from its dotted text form, e.g. "www.example.com".</summary>
    public DnsDomain(String domain)
        : this(domain.Split('.')) {
    }

    /// <summary>Creates a domain from its individual labels.</summary>
    public DnsDomain(String[] labels) => this._labels = labels;

    /// <summary>
    /// Number of bytes produced by <see cref="ToArray"/>: one length byte plus the ASCII bytes
    /// of each non-empty label, plus the terminating zero byte.
    /// </summary>
    public Int32 Size => this.EncodeLabels().Sum(l => l.Length + 1) + 1;

    /// <summary>Reads a (possibly compressed) domain name starting at <paramref name="offset"/>.</summary>
    /// <exception cref="ArgumentException">Thrown if the name is malformed (see the other overload).</exception>
    public static DnsDomain FromArray(Byte[] message, Int32 offset)
        => FromArray(message, offset, out _);

    /// <summary>
    /// Reads a (possibly compressed) domain name starting at <paramref name="offset"/>.
    /// </summary>
    /// <param name="message">The whole DNS message (compression pointers are offsets into it).</param>
    /// <param name="offset">Position of the name's first byte.</param>
    /// <param name="endOffset">Position just after the name as stored at <paramref name="offset"/>:
    /// after the terminating zero byte, or after the first compression pointer if one was followed.</param>
    /// <exception cref="ArgumentException">Thrown if the name runs past the end of the message,
    /// uses a reserved label type, or contains a compression pointer loop.</exception>
    public static DnsDomain FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
        List<Byte[]> labels = new List<Byte[]>();
        Boolean endOffsetAssigned = false;
        endOffset = 0;
        // Parsing is deterministic in the offset, so a name that terminates never visits the
        // same pointer twice; more jumps than there are bytes in the message means a loop.
        Int32 jumps = 0;

        while(true) {
            if(offset < 0 || offset >= message.Length) {
                throw new ArgumentException("Domain name runs past the end of the message");
            }
            Int32 lengthOrPointer = message[offset++];
            if(lengthOrPointer == 0) {
                break;
            }

            Int32 labelType = lengthOrPointer & LabelTypeMask;
            if(labelType == PointerTag) {
                if(offset >= message.Length) {
                    throw new ArgumentException("Truncated compression pointer");
                }
                if(++jumps > message.Length) {
                    throw new ArgumentException("Compression pointer loop detected");
                }
                if(!endOffsetAssigned) {
                    // The name as stored here ends after the two pointer bytes.
                    endOffsetAssigned = true;
                    endOffset = offset + 1;
                }
                offset = ((lengthOrPointer & PointerHighBitsMask) << 8) | message[offset];
                continue;
            }

            if(labelType != 0) {
                throw new ArgumentException("Unexpected bit pattern in label length");
            }

            Int32 length = lengthOrPointer;
            if(offset + length > message.Length) {
                throw new ArgumentException("Domain label runs past the end of the message");
            }
            Byte[] label = new Byte[length];
            Array.Copy(message, offset, label, 0, length);
            labels.Add(label);
            offset += length;
        }

        if(!endOffsetAssigned) {
            endOffset = offset;
        }
        return new DnsDomain(labels.Select(l => Encoding.ASCII.GetString(l)).ToArray());
    }

    /// <summary>
    /// Builds the reverse-lookup name for <paramref name="ip"/>: "d.c.b.a.in-addr.arpa" for IPv4,
    /// reversed hex nibbles under "ip6.arpa" otherwise.
    /// </summary>
    public static DnsDomain PointerName(IPAddress ip)
        => new DnsDomain(FormatReverseIP(ip));

    /// <summary>
    /// Encodes the name in uncompressed wire format: each non-empty label as a length byte plus
    /// its ASCII bytes, followed by a terminating zero byte.
    /// </summary>
    public Byte[] ToArray() {
        List<Byte[]> encoded = this.EncodeLabels();
        Byte[] result = new Byte[encoded.Sum(l => l.Length + 1) + 1];
        Int32 offset = 0;
        foreach(Byte[] label in encoded) {
            result[offset++] = (Byte)label.Length;
            label.CopyTo(result, offset);
            offset += label.Length;
        }
        result[offset] = 0;
        return result;
    }

    /// <summary>The labels joined with dots.</summary>
    public override String ToString()
        => String.Join(".", this._labels);

    /// <summary>Ordinal comparison of the dotted text forms; any instance sorts after null.</summary>
    public Int32 CompareTo(DnsDomain other)
        => other is null ? 1 : String.Compare(this.ToString(), other.ToString(), StringComparison.Ordinal);

    /// <summary>Two domains are equal when their dotted text forms are ordinally equal.</summary>
    public override Boolean Equals(Object obj)
        => obj is DnsDomain domain && this.CompareTo(domain) == 0;

    /// <summary>Hash of the dotted text form, consistent with <see cref="Equals(Object)"/>.</summary>
    public override Int32 GetHashCode() => this.ToString().GetHashCode();

    // ASCII-encodes the labels, skipping empty ones: a zero length byte is the name terminator
    // on the wire, so encoding an empty label (e.g. from the trailing dot in "example.com.")
    // would end the name early and corrupt whatever follows it in the message.
    private List<Byte[]> EncodeLabels()
        => this._labels.Where(l => l.Length > 0).Select(l => Encoding.ASCII.GetBytes(l)).ToList();

    // Formats the address for reverse lookup (without building a DnsDomain).
    private static String FormatReverseIP(IPAddress ip) {
        Byte[] address = ip.GetAddressBytes();
        if(address.Length == 4) {
            return String.Join(".", address.Reverse().Select(b => b.ToString())) + ".in-addr.arpa";
        }
        // Split every byte into its high and low nibble, then emit them in reverse order.
        Byte[] nibbles = new Byte[address.Length * 2];
        for(Int32 i = 0, j = 0; i < address.Length; i++, j = 2 * i) {
            Byte b = address[i];
            nibbles[j] = (Byte)(b >> 4);
            nibbles[j + 1] = (Byte)(b & 0x0F);
        }
        return String.Join(".", nibbles.Reverse().Select(b => b.ToString("x"))) + ".ip6.arpa";
    }
}
public class DnsQuestion : IDnsMessageEntry {
    /// <summary>
    /// Parses <paramref name="questionCount"/> consecutive questions starting at <paramref name="offset"/>.
    /// </summary>
    public static IList<DnsQuestion> GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount) {
        return GetAllFromArray(message, offset, questionCount, out _);
    }

    /// <summary>
    /// Parses <paramref name="questionCount"/> consecutive questions starting at <paramref name="offset"/>
    /// and reports, through <paramref name="endOffset"/>, the position just after the last one.
    /// </summary>
    public static IList<DnsQuestion> GetAllFromArray(
        Byte[] message,
        Int32 offset,
        Int32 questionCount,
        out Int32 endOffset) {
        List<DnsQuestion> parsed = new List<DnsQuestion>(questionCount);
        Int32 cursor = offset;
        for(Int32 remaining = questionCount; remaining > 0; remaining--) {
            parsed.Add(FromArray(message, cursor, out cursor));
        }
        endOffset = cursor;
        return parsed;
    }

    /// <summary>
    /// Parses one question: a domain name followed by a 4-byte tail holding QTYPE and QCLASS.
    /// </summary>
    /// <param name="endOffset">Position just after the parsed question.</param>
    public static DnsQuestion FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
        DnsDomain name = DnsDomain.FromArray(message, offset, out Int32 tailOffset);
        Tail tail = message.ToStruct<Tail>(tailOffset, Tail.SIZE);
        endOffset = tailOffset + Tail.SIZE;
        return new DnsQuestion(name, tail.Type, tail.Class);
    }

    /// <summary>Creates a question for <paramref name="domain"/> (defaults: type A, class IN).</summary>
    public DnsQuestion(
        DnsDomain domain,
        DnsRecordType type = DnsRecordType.A,
        DnsRecordClass klass = DnsRecordClass.IN) {
        this.Name = domain;
        this.Type = type;
        this.Class = klass;
    }

    /// <summary>The queried domain name.</summary>
    public DnsDomain Name {
        get;
    }

    /// <summary>The queried record type (QTYPE).</summary>
    public DnsRecordType Type {
        get;
    }

    /// <summary>The queried record class (QCLASS).</summary>
    public DnsRecordClass Class {
        get;
    }

    /// <summary>Encoded size in bytes: the name plus the 4-byte tail.</summary>
    public Int32 Size {
        get { return this.Name.Size + Tail.SIZE; }
    }

    /// <summary>Encodes the question: the uncompressed name followed by the type/class tail.</summary>
    public Byte[] ToArray() {
        Tail tail = new Tail { Type = this.Type, Class = this.Class };
        MemoryStream stream = new MemoryStream(this.Size);
        return stream
            .Append(this.Name.ToArray())
            .Append(tail.ToBytes())
            .ToArray();
    }

    /// <summary>JSON representation limited to Name, Type and Class.</summary>
    public override String ToString() {
        return Json.SerializeOnly(this, true, nameof(this.Name), nameof(this.Type), nameof(this.Class));
    }

    // Fixed-size trailer of a question: QTYPE then QCLASS, both 16-bit big-endian.
    // The field declaration order defines the marshaled layout — keep it.
    [StructEndianness(Endianness.Big)]
    [StructLayout(LayoutKind.Sequential, Pack = 2)]
    private struct Tail {
        public const Int32 SIZE = 4;

        private UInt16 type;
        private UInt16 klass;

        public DnsRecordType Type {
            get { return (DnsRecordType)this.type; }
            set { this.type = (UInt16)value; }
        }

        public DnsRecordClass Class {
            get { return (DnsRecordClass)this.klass; }
            set { this.klass = (UInt16)value; }
        }
    }
}
}
}