Codingstyle nullable

This commit is contained in:
BlubbFish 2019-12-08 21:23:54 +01:00
parent aa9fcd4a36
commit d0b26111dd
50 changed files with 8669 additions and 9749 deletions

View File

@ -1,58 +1,54 @@
namespace Swan.DependencyInjection using System;
{
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// The concrete implementation of a simple IoC container /// The concrete implementation of a simple IoC container
/// based largely on TinyIoC (https://github.com/grumpydev/TinyIoC). /// based largely on TinyIoC (https://github.com/grumpydev/TinyIoC).
/// </summary> /// </summary>
/// <seealso cref="System.IDisposable" /> /// <seealso cref="System.IDisposable" />
public partial class DependencyContainer : IDisposable public partial class DependencyContainer : IDisposable {
{ private readonly Object _autoRegisterLock = new Object();
private readonly object _autoRegisterLock = new object();
private bool _disposed; private Boolean _disposed;
static DependencyContainer() static DependencyContainer() {
{
} }
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="DependencyContainer"/> class. /// Initializes a new instance of the <see cref="DependencyContainer"/> class.
/// </summary> /// </summary>
public DependencyContainer() public DependencyContainer() {
{ this.RegisteredTypes = new TypesConcurrentDictionary(this);
RegisteredTypes = new TypesConcurrentDictionary(this); _ = this.Register(this);
Register(this);
} }
private DependencyContainer(DependencyContainer parent) private DependencyContainer(DependencyContainer parent) : this() => this.Parent = parent;
: this()
{
Parent = parent;
}
/// <summary> /// <summary>
/// Lazy created Singleton instance of the container for simple scenarios. /// Lazy created Singleton instance of the container for simple scenarios.
/// </summary> /// </summary>
public static DependencyContainer Current { get; } = new DependencyContainer(); public static DependencyContainer Current { get; } = new DependencyContainer();
internal DependencyContainer Parent { get; } internal DependencyContainer Parent {
get;
}
internal TypesConcurrentDictionary RegisteredTypes { get; } internal TypesConcurrentDictionary RegisteredTypes {
get;
}
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ if(this._disposed) {
if (_disposed) return; return;
}
_disposed = true; this._disposed = true;
foreach (var disposable in RegisteredTypes.Values.Select(item => item as IDisposable)) foreach(IDisposable disposable in this.RegisteredTypes.Values.Select(item => item as IDisposable)) {
{
disposable?.Dispose(); disposable?.Dispose();
} }
@ -73,16 +69,7 @@
/// </summary> /// </summary>
/// <param name="duplicateAction">What action to take when encountering duplicate implementations of an interface/base class.</param> /// <param name="duplicateAction">What action to take when encountering duplicate implementations of an interface/base class.</param>
/// <param name="registrationPredicate">Predicate to determine if a particular type should be registered.</param> /// <param name="registrationPredicate">Predicate to determine if a particular type should be registered.</param>
public void AutoRegister( public void AutoRegister(DependencyContainerDuplicateImplementationAction duplicateAction = DependencyContainerDuplicateImplementationAction.RegisterSingle, Func<Type, Boolean> registrationPredicate = null) => this.AutoRegister(AppDomain.CurrentDomain.GetAssemblies().Where(a => !IsIgnoredAssembly(a)), duplicateAction, registrationPredicate);
DependencyContainerDuplicateImplementationAction duplicateAction =
DependencyContainerDuplicateImplementationAction.RegisterSingle,
Func<Type, bool> registrationPredicate = null)
{
AutoRegister(
AppDomain.CurrentDomain.GetAssemblies().Where(a => !IsIgnoredAssembly(a)),
duplicateAction,
registrationPredicate);
}
/// <summary> /// <summary>
/// Attempt to automatically register all non-generic classes and interfaces in the specified assemblies /// Attempt to automatically register all non-generic classes and interfaces in the specified assemblies
@ -91,69 +78,45 @@
/// <param name="assemblies">Assemblies to process.</param> /// <param name="assemblies">Assemblies to process.</param>
/// <param name="duplicateAction">What action to take when encountering duplicate implementations of an interface/base class.</param> /// <param name="duplicateAction">What action to take when encountering duplicate implementations of an interface/base class.</param>
/// <param name="registrationPredicate">Predicate to determine if a particular type should be registered.</param> /// <param name="registrationPredicate">Predicate to determine if a particular type should be registered.</param>
public void AutoRegister( public void AutoRegister(IEnumerable<Assembly> assemblies, DependencyContainerDuplicateImplementationAction duplicateAction = DependencyContainerDuplicateImplementationAction.RegisterSingle, Func<Type, Boolean> registrationPredicate = null) {
IEnumerable<Assembly> assemblies, lock(this._autoRegisterLock) {
DependencyContainerDuplicateImplementationAction duplicateAction = List<Type> types = assemblies.SelectMany(a => a.GetAllTypes()).Where(t => !IsIgnoredType(t, registrationPredicate)).ToList();
DependencyContainerDuplicateImplementationAction.RegisterSingle,
Func<Type, bool> registrationPredicate = null)
{
lock (_autoRegisterLock)
{
var types = assemblies
.SelectMany(a => a.GetAllTypes())
.Where(t => !IsIgnoredType(t, registrationPredicate))
.ToList();
var concreteTypes = types List<Type> concreteTypes = types.Where(type => type.IsClass && !type.IsAbstract && type != this.GetType() && type.DeclaringType != this.GetType() && !type.IsGenericTypeDefinition).ToList();
.Where(type =>
type.IsClass && !type.IsAbstract &&
(type != GetType() && (type.DeclaringType != GetType()) && !type.IsGenericTypeDefinition))
.ToList();
foreach (var type in concreteTypes) foreach(Type type in concreteTypes) {
{ try {
try _ = this.RegisteredTypes.Register(type, String.Empty, GetDefaultObjectFactory(type, type));
{ } catch(MethodAccessException) {
RegisteredTypes.Register(type, string.Empty, GetDefaultObjectFactory(type, type));
}
catch (MethodAccessException)
{
// Ignore methods we can't access - added for Silverlight // Ignore methods we can't access - added for Silverlight
} }
} }
var abstractInterfaceTypes = types.Where( IEnumerable<Type> abstractInterfaceTypes = types.Where(type => (type.IsInterface || type.IsAbstract) && type.DeclaringType != this.GetType() && !type.IsGenericTypeDefinition);
type =>
((type.IsInterface || type.IsAbstract) && (type.DeclaringType != GetType()) &&
(!type.IsGenericTypeDefinition)));
foreach (var type in abstractInterfaceTypes) foreach(Type type in abstractInterfaceTypes) {
{ Type localType = type;
var localType = type; List<Type> implementations = concreteTypes.Where(implementationType => localType.IsAssignableFrom(implementationType)).ToList();
var implementations = concreteTypes
.Where(implementationType => localType.IsAssignableFrom(implementationType)).ToList();
if (implementations.Skip(1).Any()) if(implementations.Skip(1).Any()) {
{ if(duplicateAction == DependencyContainerDuplicateImplementationAction.Fail) {
if (duplicateAction == DependencyContainerDuplicateImplementationAction.Fail)
throw new DependencyContainerRegistrationException(type, implementations); throw new DependencyContainerRegistrationException(type, implementations);
}
if (duplicateAction == DependencyContainerDuplicateImplementationAction.RegisterMultiple) if(duplicateAction == DependencyContainerDuplicateImplementationAction.RegisterMultiple) {
{ _ = this.RegisterMultiple(type, implementations);
RegisterMultiple(type, implementations);
} }
} }
var firstImplementation = implementations.FirstOrDefault(); Type firstImplementation = implementations.FirstOrDefault();
if (firstImplementation == null) continue; if(firstImplementation == null) {
continue;
try
{
RegisteredTypes.Register(type, string.Empty, GetDefaultObjectFactory(type, firstImplementation));
} }
catch (MethodAccessException)
{ try {
_ = this.RegisteredTypes.Register(type, String.Empty, GetDefaultObjectFactory(type, firstImplementation));
} catch(MethodAccessException) {
// Ignore methods we can't access - added for Silverlight // Ignore methods we can't access - added for Silverlight
} }
} }
@ -166,11 +129,7 @@
/// <param name="registerType">Type to register.</param> /// <param name="registerType">Type to register.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register(Type registerType, string name = "") public RegisterOptions Register(Type registerType, String name = "") => this.RegisteredTypes.Register(registerType, name, GetDefaultObjectFactory(registerType, registerType));
=> RegisteredTypes.Register(
registerType,
name,
GetDefaultObjectFactory(registerType, registerType));
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a given implementation and default options. /// Creates/replaces a named container class registration with a given implementation and default options.
@ -179,8 +138,7 @@
/// <param name="registerImplementation">Type to instantiate that implements RegisterType.</param> /// <param name="registerImplementation">Type to instantiate that implements RegisterType.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register(Type registerType, Type registerImplementation, string name = "") => public RegisterOptions Register(Type registerType, Type registerImplementation, String name = "") => this.RegisteredTypes.Register(registerType, name, GetDefaultObjectFactory(registerType, registerImplementation));
RegisteredTypes.Register(registerType, name, GetDefaultObjectFactory(registerType, registerImplementation));
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a specific, strong referenced, instance. /// Creates/replaces a named container class registration with a specific, strong referenced, instance.
@ -189,8 +147,7 @@
/// <param name="instance">Instance of RegisterType to register.</param> /// <param name="instance">Instance of RegisterType to register.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register(Type registerType, object instance, string name = "") => public RegisterOptions Register(Type registerType, Object instance, String name = "") => this.RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerType, instance));
RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerType, instance));
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a specific, strong referenced, instance. /// Creates/replaces a named container class registration with a specific, strong referenced, instance.
@ -200,12 +157,7 @@
/// <param name="instance">Instance of RegisterImplementation to register.</param> /// <param name="instance">Instance of RegisterImplementation to register.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register( public RegisterOptions Register(Type registerType, Type registerImplementation, Object instance, String name = "") => this.RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerImplementation, instance));
Type registerType,
Type registerImplementation,
object instance,
string name = "")
=> RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerImplementation, instance));
/// <summary> /// <summary>
/// Creates/replaces a container class registration with a user specified factory. /// Creates/replaces a container class registration with a user specified factory.
@ -214,11 +166,7 @@
/// <param name="factory">Factory/lambda that returns an instance of RegisterType.</param> /// <param name="factory">Factory/lambda that returns an instance of RegisterType.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register( public RegisterOptions Register(Type registerType, Func<DependencyContainer, Dictionary<String, Object>, Object> factory, String name = "") => this.RegisteredTypes.Register(registerType, name, new DelegateFactory(registerType, factory));
Type registerType,
Func<DependencyContainer, Dictionary<string, object>, object> factory,
string name = "")
=> RegisteredTypes.Register(registerType, name, new DelegateFactory(registerType, factory));
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with default options. /// Creates/replaces a named container class registration with default options.
@ -226,11 +174,7 @@
/// <typeparam name="TRegister">Type to register.</typeparam> /// <typeparam name="TRegister">Type to register.</typeparam>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register<TRegister>(string name = "") public RegisterOptions Register<TRegister>(String name = "") where TRegister : class => this.Register(typeof(TRegister), name);
where TRegister : class
{
return Register(typeof(TRegister), name);
}
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a given implementation and default options. /// Creates/replaces a named container class registration with a given implementation and default options.
@ -239,12 +183,7 @@
/// <typeparam name="TRegisterImplementation">Type to instantiate that implements RegisterType.</typeparam> /// <typeparam name="TRegisterImplementation">Type to instantiate that implements RegisterType.</typeparam>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register<TRegister, TRegisterImplementation>(string name = "") public RegisterOptions Register<TRegister, TRegisterImplementation>(String name = "") where TRegister : class where TRegisterImplementation : class, TRegister => this.Register(typeof(TRegister), typeof(TRegisterImplementation), name);
where TRegister : class
where TRegisterImplementation : class, TRegister
{
return Register(typeof(TRegister), typeof(TRegisterImplementation), name);
}
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a specific, strong referenced, instance. /// Creates/replaces a named container class registration with a specific, strong referenced, instance.
@ -253,11 +192,7 @@
/// <param name="instance">Instance of RegisterType to register.</param> /// <param name="instance">Instance of RegisterType to register.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register<TRegister>(TRegister instance, string name = "") public RegisterOptions Register<TRegister>(TRegister instance, String name = "") where TRegister : class => this.Register(typeof(TRegister), instance, name);
where TRegister : class
{
return Register(typeof(TRegister), instance, name);
}
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a specific, strong referenced, instance. /// Creates/replaces a named container class registration with a specific, strong referenced, instance.
@ -267,13 +202,7 @@
/// <param name="instance">Instance of RegisterImplementation to register.</param> /// <param name="instance">Instance of RegisterImplementation to register.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register<TRegister, TRegisterImplementation>(TRegisterImplementation instance, public RegisterOptions Register<TRegister, TRegisterImplementation>(TRegisterImplementation instance, String name = "") where TRegister : class where TRegisterImplementation : class, TRegister => this.Register(typeof(TRegister), typeof(TRegisterImplementation), instance, name);
string name = "")
where TRegister : class
where TRegisterImplementation : class, TRegister
{
return Register(typeof(TRegister), typeof(TRegisterImplementation), instance, name);
}
/// <summary> /// <summary>
/// Creates/replaces a named container class registration with a user specified factory. /// Creates/replaces a named container class registration with a user specified factory.
@ -282,14 +211,12 @@
/// <param name="factory">Factory/lambda that returns an instance of RegisterType.</param> /// <param name="factory">Factory/lambda that returns an instance of RegisterType.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns>RegisterOptions for fluent API.</returns> /// <returns>RegisterOptions for fluent API.</returns>
public RegisterOptions Register<TRegister>( public RegisterOptions Register<TRegister>(Func<DependencyContainer, Dictionary<String, Object>, TRegister> factory, String name = "") where TRegister : class {
Func<DependencyContainer, Dictionary<string, object>, TRegister> factory, string name = "") if(factory == null) {
where TRegister : class
{
if (factory == null)
throw new ArgumentNullException(nameof(factory)); throw new ArgumentNullException(nameof(factory));
}
return Register(typeof(TRegister), factory, name); return this.Register(typeof(TRegister), factory, name);
} }
/// <summary> /// <summary>
@ -300,8 +227,7 @@
/// <typeparam name="TRegister">Type that each implementation implements.</typeparam> /// <typeparam name="TRegister">Type that each implementation implements.</typeparam>
/// <param name="implementationTypes">Types that implement RegisterType.</param> /// <param name="implementationTypes">Types that implement RegisterType.</param>
/// <returns>MultiRegisterOptions for the fluent API.</returns> /// <returns>MultiRegisterOptions for the fluent API.</returns>
public MultiRegisterOptions RegisterMultiple<TRegister>(IEnumerable<Type> implementationTypes) => public MultiRegisterOptions RegisterMultiple<TRegister>(IEnumerable<Type> implementationTypes) => this.RegisterMultiple(typeof(TRegister), implementationTypes);
RegisterMultiple(typeof(TRegister), implementationTypes);
/// <summary> /// <summary>
/// Register multiple implementations of a type. /// Register multiple implementations of a type.
@ -311,33 +237,24 @@
/// <param name="registrationType">Type that each implementation implements.</param> /// <param name="registrationType">Type that each implementation implements.</param>
/// <param name="implementationTypes">Types that implement RegisterType.</param> /// <param name="implementationTypes">Types that implement RegisterType.</param>
/// <returns>MultiRegisterOptions for the fluent API.</returns> /// <returns>MultiRegisterOptions for the fluent API.</returns>
public MultiRegisterOptions RegisterMultiple(Type registrationType, IEnumerable<Type> implementationTypes) public MultiRegisterOptions RegisterMultiple(Type registrationType, IEnumerable<Type> implementationTypes) {
{ if(implementationTypes == null) {
if (implementationTypes == null)
throw new ArgumentNullException(nameof(implementationTypes), "types is null."); throw new ArgumentNullException(nameof(implementationTypes), "types is null.");
foreach (var type in implementationTypes.Where(type => !registrationType.IsAssignableFrom(type)))
{
throw new ArgumentException(
$"types: The type {registrationType.FullName} is not assignable from {type.FullName}");
} }
if (implementationTypes.Count() != implementationTypes.Distinct().Count()) foreach(Type type in implementationTypes.Where(type => !registrationType.IsAssignableFrom(type))) {
{ throw new ArgumentException($"types: The type {registrationType.FullName} is not assignable from {type.FullName}");
var queryForDuplicatedTypes = implementationTypes
.GroupBy(i => i)
.Where(j => j.Count() > 1)
.Select(j => j.Key.FullName);
var fullNamesOfDuplicatedTypes = string.Join(",\n", queryForDuplicatedTypes.ToArray());
throw new ArgumentException(
$"types: The same implementation type cannot be specified multiple times for {registrationType.FullName}\n\n{fullNamesOfDuplicatedTypes}");
} }
var registerOptions = implementationTypes if(implementationTypes.Count() != implementationTypes.Distinct().Count()) {
.Select(type => Register(registrationType, type, type.FullName)) IEnumerable<String> queryForDuplicatedTypes = implementationTypes.GroupBy(i => i).Where(j => j.Count() > 1).Select(j => j.Key.FullName);
.ToList();
String fullNamesOfDuplicatedTypes = String.Join(",\n", queryForDuplicatedTypes.ToArray());
throw new ArgumentException($"types: The same implementation type cannot be specified multiple times for {registrationType.FullName}\n\n{fullNamesOfDuplicatedTypes}");
}
List<RegisterOptions> registerOptions = implementationTypes.Select(type => this.Register(registrationType, type, type.FullName)).ToList();
return new MultiRegisterOptions(registerOptions); return new MultiRegisterOptions(registerOptions);
} }
@ -352,7 +269,7 @@
/// <typeparam name="TRegister">Type to unregister.</typeparam> /// <typeparam name="TRegister">Type to unregister.</typeparam>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns><c>true</c> if the registration is successfully found and removed; otherwise, <c>false</c>.</returns> /// <returns><c>true</c> if the registration is successfully found and removed; otherwise, <c>false</c>.</returns>
public bool Unregister<TRegister>(string name = "") => Unregister(typeof(TRegister), name); public Boolean Unregister<TRegister>(String name = "") => this.Unregister(typeof(TRegister), name);
/// <summary> /// <summary>
/// Remove a named container class registration. /// Remove a named container class registration.
@ -360,8 +277,7 @@
/// <param name="registerType">Type to unregister.</param> /// <param name="registerType">Type to unregister.</param>
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <returns><c>true</c> if the registration is successfully found and removed; otherwise, <c>false</c>.</returns> /// <returns><c>true</c> if the registration is successfully found and removed; otherwise, <c>false</c>.</returns>
public bool Unregister(Type registerType, string name = "") => public Boolean Unregister(Type registerType, String name = "") => this.RegisteredTypes.RemoveRegistration(new TypeRegistration(registerType, name));
RegisteredTypes.RemoveRegistration(new TypeRegistration(registerType, name));
#endregion #endregion
@ -378,11 +294,7 @@
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <returns>Instance of type.</returns> /// <returns>Instance of type.</returns>
/// <exception cref="DependencyContainerResolutionException">Unable to resolve the type.</exception> /// <exception cref="DependencyContainerResolutionException">Unable to resolve the type.</exception>
public object Resolve( public Object Resolve(Type resolveType, String name = null, DependencyContainerResolveOptions options = null) => this.RegisteredTypes.ResolveInternal(new TypeRegistration(resolveType, name), options ?? DependencyContainerResolveOptions.Default);
Type resolveType,
string name = null,
DependencyContainerResolveOptions options = null)
=> RegisteredTypes.ResolveInternal(new TypeRegistration(resolveType, name), options ?? DependencyContainerResolveOptions.Default);
/// <summary> /// <summary>
/// Attempts to resolve a named type using specified options and the supplied constructor parameters. /// Attempts to resolve a named type using specified options and the supplied constructor parameters.
@ -395,13 +307,7 @@
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <returns>Instance of type.</returns> /// <returns>Instance of type.</returns>
/// <exception cref="DependencyContainerResolutionException">Unable to resolve the type.</exception> /// <exception cref="DependencyContainerResolutionException">Unable to resolve the type.</exception>
public TResolveType Resolve<TResolveType>( public TResolveType Resolve<TResolveType>(String name = null, DependencyContainerResolveOptions options = null) where TResolveType : class => (TResolveType)this.Resolve(typeof(TResolveType), name, options);
string name = null,
DependencyContainerResolveOptions options = null)
where TResolveType : class
{
return (TResolveType)Resolve(typeof(TResolveType), name, options);
}
/// <summary> /// <summary>
/// Attempts to predict whether a given type can be resolved with the supplied constructor parameters options. /// Attempts to predict whether a given type can be resolved with the supplied constructor parameters options.
@ -415,11 +321,7 @@
/// <returns> /// <returns>
/// Bool indicating whether the type can be resolved. /// Bool indicating whether the type can be resolved.
/// </returns> /// </returns>
public bool CanResolve( public Boolean CanResolve(Type resolveType, String name = null, DependencyContainerResolveOptions options = null) => this.RegisteredTypes.CanResolve(new TypeRegistration(resolveType, name), options);
Type resolveType,
string name = null,
DependencyContainerResolveOptions options = null) =>
RegisteredTypes.CanResolve(new TypeRegistration(resolveType, name), options);
/// <summary> /// <summary>
/// Attempts to predict whether a given named type can be resolved with the supplied constructor parameters options. /// Attempts to predict whether a given named type can be resolved with the supplied constructor parameters options.
@ -433,13 +335,7 @@
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <returns>Bool indicating whether the type can be resolved.</returns> /// <returns>Bool indicating whether the type can be resolved.</returns>
public bool CanResolve<TResolveType>( public Boolean CanResolve<TResolveType>(String name = null, DependencyContainerResolveOptions options = null) where TResolveType : class => this.CanResolve(typeof(TResolveType), name, options);
string name = null,
DependencyContainerResolveOptions options = null)
where TResolveType : class
{
return CanResolve(typeof(TResolveType), name, options);
}
/// <summary> /// <summary>
/// Attempts to resolve a type using the default options. /// Attempts to resolve a type using the default options.
@ -447,15 +343,11 @@
/// <param name="resolveType">Type to resolve.</param> /// <param name="resolveType">Type to resolve.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve(Type resolveType, out object resolvedType) public Boolean TryResolve(Type resolveType, out Object resolvedType) {
{ try {
try resolvedType = this.Resolve(resolveType);
{
resolvedType = Resolve(resolveType);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = null; resolvedType = null;
return false; return false;
} }
@ -468,15 +360,11 @@
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve(Type resolveType, DependencyContainerResolveOptions options, out object resolvedType) public Boolean TryResolve(Type resolveType, DependencyContainerResolveOptions options, out Object resolvedType) {
{ try {
try resolvedType = this.Resolve(resolveType, options: options);
{
resolvedType = Resolve(resolveType, options: options);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = null; resolvedType = null;
return false; return false;
} }
@ -489,15 +377,11 @@
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve(Type resolveType, string name, out object resolvedType) public Boolean TryResolve(Type resolveType, String name, out Object resolvedType) {
{ try {
try resolvedType = this.Resolve(resolveType, name);
{
resolvedType = Resolve(resolveType, name);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = null; resolvedType = null;
return false; return false;
} }
@ -511,19 +395,11 @@
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve( public Boolean TryResolve(Type resolveType, String name, DependencyContainerResolveOptions options, out Object resolvedType) {
Type resolveType, try {
string name, resolvedType = this.Resolve(resolveType, name, options);
DependencyContainerResolveOptions options,
out object resolvedType)
{
try
{
resolvedType = Resolve(resolveType, name, options);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = null; resolvedType = null;
return false; return false;
} }
@ -535,16 +411,11 @@
/// <typeparam name="TResolveType">Type to resolve.</typeparam> /// <typeparam name="TResolveType">Type to resolve.</typeparam>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve<TResolveType>(out TResolveType resolvedType) public Boolean TryResolve<TResolveType>(out TResolveType resolvedType) where TResolveType : class {
where TResolveType : class try {
{ resolvedType = this.Resolve<TResolveType>();
try
{
resolvedType = Resolve<TResolveType>();
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = default; resolvedType = default;
return false; return false;
} }
@ -557,16 +428,11 @@
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve<TResolveType>(DependencyContainerResolveOptions options, out TResolveType resolvedType) public Boolean TryResolve<TResolveType>(DependencyContainerResolveOptions options, out TResolveType resolvedType) where TResolveType : class {
where TResolveType : class try {
{ resolvedType = this.Resolve<TResolveType>(options: options);
try
{
resolvedType = Resolve<TResolveType>(options: options);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = default; resolvedType = default;
return false; return false;
} }
@ -579,16 +445,11 @@
/// <param name="name">Name of registration.</param> /// <param name="name">Name of registration.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve<TResolveType>(string name, out TResolveType resolvedType) public Boolean TryResolve<TResolveType>(String name, out TResolveType resolvedType) where TResolveType : class {
where TResolveType : class try {
{ resolvedType = this.Resolve<TResolveType>(name);
try
{
resolvedType = Resolve<TResolveType>(name);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = default; resolvedType = default;
return false; return false;
} }
@ -602,19 +463,11 @@
/// <param name="options">Resolution options.</param> /// <param name="options">Resolution options.</param>
/// <param name="resolvedType">Resolved type or default if resolve fails.</param> /// <param name="resolvedType">Resolved type or default if resolve fails.</param>
/// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns> /// <returns><c>true</c> if resolved successfully, <c>false</c> otherwise.</returns>
public bool TryResolve<TResolveType>( public Boolean TryResolve<TResolveType>(String name, DependencyContainerResolveOptions options, out TResolveType resolvedType) where TResolveType : class {
string name, try {
DependencyContainerResolveOptions options, resolvedType = this.Resolve<TResolveType>(name, options);
out TResolveType resolvedType)
where TResolveType : class
{
try
{
resolvedType = Resolve<TResolveType>(name, options);
return true; return true;
} } catch(DependencyContainerResolutionException) {
catch (DependencyContainerResolutionException)
{
resolvedType = default; resolvedType = default;
return false; return false;
} }
@ -626,8 +479,7 @@
/// <param name="resolveType">Type to resolveAll.</param> /// <param name="resolveType">Type to resolveAll.</param>
/// <param name="includeUnnamed">Whether to include un-named (default) registrations.</param> /// <param name="includeUnnamed">Whether to include un-named (default) registrations.</param>
/// <returns>IEnumerable.</returns> /// <returns>IEnumerable.</returns>
public IEnumerable<object> ResolveAll(Type resolveType, bool includeUnnamed = false) public IEnumerable<Object> ResolveAll(Type resolveType, Boolean includeUnnamed = false) => this.RegisteredTypes.Resolve(resolveType, includeUnnamed);
=> RegisteredTypes.Resolve(resolveType, includeUnnamed);
/// <summary> /// <summary>
/// Returns all registrations of a type. /// Returns all registrations of a type.
@ -635,38 +487,24 @@
/// <typeparam name="TResolveType">Type to resolveAll.</typeparam> /// <typeparam name="TResolveType">Type to resolveAll.</typeparam>
/// <param name="includeUnnamed">Whether to include un-named (default) registrations.</param> /// <param name="includeUnnamed">Whether to include un-named (default) registrations.</param>
/// <returns>IEnumerable.</returns> /// <returns>IEnumerable.</returns>
public IEnumerable<TResolveType> ResolveAll<TResolveType>(bool includeUnnamed = true) public IEnumerable<TResolveType> ResolveAll<TResolveType>(Boolean includeUnnamed = true) where TResolveType : class => this.ResolveAll(typeof(TResolveType), includeUnnamed).Cast<TResolveType>();
where TResolveType : class
{
return ResolveAll(typeof(TResolveType), includeUnnamed).Cast<TResolveType>();
}
/// <summary> /// <summary>
/// Attempts to resolve all public property dependencies on the given object using the given resolve options. /// Attempts to resolve all public property dependencies on the given object using the given resolve options.
/// </summary> /// </summary>
/// <param name="input">Object to "build up".</param> /// <param name="input">Object to "build up".</param>
/// <param name="resolveOptions">Resolve options to use.</param> /// <param name="resolveOptions">Resolve options to use.</param>
public void BuildUp(object input, DependencyContainerResolveOptions resolveOptions = null) public void BuildUp(Object input, DependencyContainerResolveOptions resolveOptions = null) {
{ if(resolveOptions == null) {
if (resolveOptions == null)
resolveOptions = DependencyContainerResolveOptions.Default; resolveOptions = DependencyContainerResolveOptions.Default;
var properties = input.GetType()
.GetProperties()
.Where(property => property.GetCacheGetMethod() != null && property.GetCacheSetMethod() != null &&
!property.PropertyType.IsValueType);
foreach (var property in properties.Where(property => property.GetValue(input, null) == null))
{
try
{
property.SetValue(
input,
RegisteredTypes.ResolveInternal(new TypeRegistration(property.PropertyType), resolveOptions),
null);
} }
catch (DependencyContainerResolutionException)
{ IEnumerable<PropertyInfo> properties = input.GetType().GetProperties().Where(property => property.GetCacheGetMethod() != null && property.GetCacheSetMethod() != null && !property.PropertyType.IsValueType);
foreach(PropertyInfo property in properties.Where(property => property.GetValue(input, null) == null)) {
try {
property.SetValue(input, this.RegisteredTypes.ResolveInternal(new TypeRegistration(property.PropertyType), resolveOptions), null);
} catch(DependencyContainerResolutionException) {
// Catch any resolution errors and ignore them // Catch any resolution errors and ignore them
} }
} }
@ -676,29 +514,27 @@
#region Internal Methods #region Internal Methods
internal static bool IsValidAssignment(Type registerType, Type registerImplementation) internal static Boolean IsValidAssignment(Type registerType, Type registerImplementation) {
{ if(!registerType.IsGenericTypeDefinition) {
if (!registerType.IsGenericTypeDefinition) if(!registerType.IsAssignableFrom(registerImplementation)) {
{
if (!registerType.IsAssignableFrom(registerImplementation))
return false; return false;
} }
else } else {
{ if(registerType.IsInterface && registerImplementation.GetInterfaces().All(t => t.Name != registerType.Name)) {
if (registerType.IsInterface && registerImplementation.GetInterfaces().All(t => t.Name != registerType.Name))
return false; return false;
}
if (registerType.IsAbstract && registerImplementation.BaseType != registerType) if(registerType.IsAbstract && registerImplementation.BaseType != registerType) {
return false; return false;
} }
}
return true; return true;
} }
private static bool IsIgnoredAssembly(Assembly assembly) private static Boolean IsIgnoredAssembly(Assembly assembly) {
{
// TODO - find a better way to remove "system" assemblies from the auto registration // TODO - find a better way to remove "system" assemblies from the auto registration
var ignoreChecks = new List<Func<Assembly, bool>> List<Func<Assembly, Boolean>> ignoreChecks = new List<Func<Assembly, Boolean>>
{ {
asm => asm.FullName.StartsWith("Microsoft.", StringComparison.Ordinal), asm => asm.FullName.StartsWith("Microsoft.", StringComparison.Ordinal),
asm => asm.FullName.StartsWith("System.", StringComparison.Ordinal), asm => asm.FullName.StartsWith("System.", StringComparison.Ordinal),
@ -713,30 +549,26 @@
return ignoreChecks.Any(check => check(assembly)); return ignoreChecks.Any(check => check(assembly));
} }
private static bool IsIgnoredType(Type type, Func<Type, bool> registrationPredicate) private static Boolean IsIgnoredType(Type type, Func<Type, Boolean> registrationPredicate) {
{
// TODO - find a better way to remove "system" types from the auto registration // TODO - find a better way to remove "system" types from the auto registration
var ignoreChecks = new List<Func<Type, bool>>() List<Func<Type, Boolean>> ignoreChecks = new List<Func<Type, Boolean>>()
{ {
t => t.FullName?.StartsWith("System.", StringComparison.Ordinal) ?? false, t => t.FullName?.StartsWith("System.", StringComparison.Ordinal) ?? false,
t => t.FullName?.StartsWith("Microsoft.", StringComparison.Ordinal) ?? false, t => t.FullName?.StartsWith("Microsoft.", StringComparison.Ordinal) ?? false,
t => t.IsPrimitive, t => t.IsPrimitive,
t => t.IsGenericTypeDefinition, t => t.IsGenericTypeDefinition,
t => (t.GetConstructors(BindingFlags.Instance | BindingFlags.Public).Length == 0) && t => t.GetConstructors(BindingFlags.Instance | BindingFlags.Public).Length == 0 &&
!(t.IsInterface || t.IsAbstract), !(t.IsInterface || t.IsAbstract),
}; };
if (registrationPredicate != null) if(registrationPredicate != null) {
{
ignoreChecks.Add(t => !registrationPredicate(t)); ignoreChecks.Add(t => !registrationPredicate(t));
} }
return ignoreChecks.Any(check => check(type)); return ignoreChecks.Any(check => check(type));
} }
private static ObjectFactoryBase GetDefaultObjectFactory(Type registerType, Type registerImplementation) => registerType.IsInterface || registerType.IsAbstract private static ObjectFactoryBase GetDefaultObjectFactory(Type registerType, Type registerImplementation) => registerType.IsInterface || registerType.IsAbstract ? (ObjectFactoryBase)new SingletonFactory(registerType, registerImplementation) : new MultiInstanceFactory(registerType, registerImplementation);
? (ObjectFactoryBase)new SingletonFactory(registerType, registerImplementation)
: new MultiInstanceFactory(registerType, registerImplementation);
#endregion #endregion
} }

View File

@ -1,28 +1,23 @@
namespace Swan.DependencyInjection using System;
{
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// Generic Constraint Registration Exception. /// Generic Constraint Registration Exception.
/// </summary> /// </summary>
/// <seealso cref="Exception" /> /// <seealso cref="Exception" />
public class DependencyContainerRegistrationException : Exception public class DependencyContainerRegistrationException : Exception {
{ private const String ConvertErrorText = "Cannot convert current registration of {0} to {1}";
private const string ConvertErrorText = "Cannot convert current registration of {0} to {1}"; private const String RegisterErrorText = "Cannot register type {0} - abstract classes or interfaces are not valid implementation types for {1}.";
private const string RegisterErrorText = private const String ErrorText = "Duplicate implementation of type {0} found ({1}).";
"Cannot register type {0} - abstract classes or interfaces are not valid implementation types for {1}.";
private const string ErrorText = "Duplicate implementation of type {0} found ({1}).";
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="DependencyContainerRegistrationException"/> class. /// Initializes a new instance of the <see cref="DependencyContainerRegistrationException"/> class.
/// </summary> /// </summary>
/// <param name="registerType">Type of the register.</param> /// <param name="registerType">Type of the register.</param>
/// <param name="types">The types.</param> /// <param name="types">The types.</param>
public DependencyContainerRegistrationException(Type registerType, IEnumerable<Type> types) public DependencyContainerRegistrationException(Type registerType, IEnumerable<Type> types) : base(String.Format(ErrorText, registerType, GetTypesString(types))) {
: base(string.Format(ErrorText, registerType, GetTypesString(types)))
{
} }
/// <summary> /// <summary>
@ -31,16 +26,9 @@
/// <param name="type">The type.</param> /// <param name="type">The type.</param>
/// <param name="method">The method.</param> /// <param name="method">The method.</param>
/// <param name="isTypeFactory">if set to <c>true</c> [is type factory].</param> /// <param name="isTypeFactory">if set to <c>true</c> [is type factory].</param>
public DependencyContainerRegistrationException(Type type, string method, bool isTypeFactory = false) public DependencyContainerRegistrationException(Type type, String method, Boolean isTypeFactory = false) : base(isTypeFactory ? String.Format(RegisterErrorText, type.FullName, method) : String.Format(ConvertErrorText, type.FullName, method)) {
: base(isTypeFactory
? string.Format(RegisterErrorText, type.FullName, method)
: string.Format(ConvertErrorText, type.FullName, method))
{
} }
private static string GetTypesString(IEnumerable<Type> types) private static String GetTypesString(IEnumerable<Type> types) => String.Join(",", types.Select(type => type.FullName));
{
return string.Join(",", types.Select(type => type.FullName));
}
} }
} }

View File

@ -1,21 +1,17 @@
namespace Swan.DependencyInjection using System;
{
using System;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// An exception for dependency resolutions. /// An exception for dependency resolutions.
/// </summary> /// </summary>
/// <seealso cref="System.Exception" /> /// <seealso cref="System.Exception" />
[Serializable] [Serializable]
public class DependencyContainerResolutionException : Exception public class DependencyContainerResolutionException : Exception {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="DependencyContainerResolutionException"/> class. /// Initializes a new instance of the <see cref="DependencyContainerResolutionException"/> class.
/// </summary> /// </summary>
/// <param name="type">The type.</param> /// <param name="type">The type.</param>
public DependencyContainerResolutionException(Type type) public DependencyContainerResolutionException(Type type) : base($"Unable to resolve type: {type.FullName}") {
: base($"Unable to resolve type: {type.FullName}")
{
} }
/// <summary> /// <summary>
@ -23,9 +19,7 @@
/// </summary> /// </summary>
/// <param name="type">The type.</param> /// <param name="type">The type.</param>
/// <param name="innerException">The inner exception.</param> /// <param name="innerException">The inner exception.</param>
public DependencyContainerResolutionException(Type type, Exception innerException) public DependencyContainerResolutionException(Type type, Exception innerException) : base($"Unable to resolve type: {type.FullName}", innerException) {
: base($"Unable to resolve type: {type.FullName}", innerException)
{
} }
} }
} }

View File

@ -1,12 +1,10 @@
namespace Swan.DependencyInjection using System.Collections.Generic;
{
using System.Collections.Generic;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// Resolution settings. /// Resolution settings.
/// </summary> /// </summary>
public class DependencyContainerResolveOptions public class DependencyContainerResolveOptions {
{
/// <summary> /// <summary>
/// Gets the default options (attempt resolution of unregistered types, fail on named resolution if name not found). /// Gets the default options (attempt resolution of unregistered types, fail on named resolution if name not found).
/// </summary> /// </summary>
@ -18,8 +16,7 @@
/// <value> /// <value>
/// The unregistered resolution action. /// The unregistered resolution action.
/// </value> /// </value>
public DependencyContainerUnregisteredResolutionAction UnregisteredResolutionAction { get; set; } = public DependencyContainerUnregisteredResolutionAction UnregisteredResolutionAction { get; set; } = DependencyContainerUnregisteredResolutionAction.AttemptResolve;
DependencyContainerUnregisteredResolutionAction.AttemptResolve;
/// <summary> /// <summary>
/// Gets or sets the named resolution failure action. /// Gets or sets the named resolution failure action.
@ -27,8 +24,7 @@
/// <value> /// <value>
/// The named resolution failure action. /// The named resolution failure action.
/// </value> /// </value>
public DependencyContainerNamedResolutionFailureAction NamedResolutionFailureAction { get; set; } = public DependencyContainerNamedResolutionFailureAction NamedResolutionFailureAction { get; set; } = DependencyContainerNamedResolutionFailureAction.Fail;
DependencyContainerNamedResolutionFailureAction.Fail;
/// <summary> /// <summary>
/// Gets the constructor parameters. /// Gets the constructor parameters.
@ -36,14 +32,13 @@
/// <value> /// <value>
/// The constructor parameters. /// The constructor parameters.
/// </value> /// </value>
public Dictionary<string, object> ConstructorParameters { get; } = new Dictionary<string, object>(); public Dictionary<System.String, System.Object> ConstructorParameters { get; } = new Dictionary<System.String, System.Object>();
/// <summary> /// <summary>
/// Clones this instance. /// Clones this instance.
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public DependencyContainerResolveOptions Clone() => new DependencyContainerResolveOptions public DependencyContainerResolveOptions Clone() => new DependencyContainerResolveOptions {
{
NamedResolutionFailureAction = NamedResolutionFailureAction, NamedResolutionFailureAction = NamedResolutionFailureAction,
UnregisteredResolutionAction = UnregisteredResolutionAction, UnregisteredResolutionAction = UnregisteredResolutionAction,
}; };
@ -52,8 +47,7 @@
/// <summary> /// <summary>
/// Defines Resolution actions. /// Defines Resolution actions.
/// </summary> /// </summary>
public enum DependencyContainerUnregisteredResolutionAction public enum DependencyContainerUnregisteredResolutionAction {
{
/// <summary> /// <summary>
/// Attempt to resolve type, even if the type isn't registered. /// Attempt to resolve type, even if the type isn't registered.
/// ///
@ -78,8 +72,7 @@
/// <summary> /// <summary>
/// Enumerates failure actions. /// Enumerates failure actions.
/// </summary> /// </summary>
public enum DependencyContainerNamedResolutionFailureAction public enum DependencyContainerNamedResolutionFailureAction {
{
/// <summary> /// <summary>
/// The attempt unnamed resolution /// The attempt unnamed resolution
/// </summary> /// </summary>
@ -94,8 +87,7 @@
/// <summary> /// <summary>
/// Enumerates duplicate definition actions. /// Enumerates duplicate definition actions.
/// </summary> /// </summary>
public enum DependencyContainerDuplicateImplementationAction public enum DependencyContainerDuplicateImplementationAction {
{
/// <summary> /// <summary>
/// The register single /// The register single
/// </summary> /// </summary>

View File

@ -1,22 +1,18 @@
namespace Swan.DependencyInjection using System;
{
using System;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// Weak Reference Exception. /// Weak Reference Exception.
/// </summary> /// </summary>
/// <seealso cref="System.Exception" /> /// <seealso cref="System.Exception" />
public class DependencyContainerWeakReferenceException : Exception public class DependencyContainerWeakReferenceException : Exception {
{ private const String ErrorText = "Unable to instantiate {0} - referenced object has been reclaimed";
private const string ErrorText = "Unable to instantiate {0} - referenced object has been reclaimed";
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="DependencyContainerWeakReferenceException"/> class. /// Initializes a new instance of the <see cref="DependencyContainerWeakReferenceException"/> class.
/// </summary> /// </summary>
/// <param name="type">The type.</param> /// <param name="type">The type.</param>
public DependencyContainerWeakReferenceException(Type type) public DependencyContainerWeakReferenceException(Type type) : base(String.Format(ErrorText, type.FullName)) {
: base(string.Format(ErrorText, type.FullName))
{
} }
} }
} }

View File

@ -1,31 +1,33 @@
namespace Swan.DependencyInjection using System;
{
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Reflection; using System.Reflection;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// Represents an abstract class for Object Factory. /// Represents an abstract class for Object Factory.
/// </summary> /// </summary>
public abstract class ObjectFactoryBase public abstract class ObjectFactoryBase {
{
/// <summary> /// <summary>
/// Whether to assume this factory successfully constructs its objects /// Whether to assume this factory successfully constructs its objects
/// ///
/// Generally set to true for delegate style factories as CanResolve cannot delve /// Generally set to true for delegate style factories as CanResolve cannot delve
/// into the delegates they contain. /// into the delegates they contain.
/// </summary> /// </summary>
public virtual bool AssumeConstruction => false; public virtual Boolean AssumeConstruction => false;
/// <summary> /// <summary>
/// The type the factory instantiates. /// The type the factory instantiates.
/// </summary> /// </summary>
public abstract Type CreatesType { get; } public abstract Type CreatesType {
get;
}
/// <summary> /// <summary>
/// Constructor to use, if specified. /// Constructor to use, if specified.
/// </summary> /// </summary>
public ConstructorInfo Constructor { get; private set; } public ConstructorInfo Constructor {
get; private set;
}
/// <summary> /// <summary>
/// Gets the singleton variant. /// Gets the singleton variant.
@ -34,8 +36,7 @@
/// The singleton variant. /// The singleton variant.
/// </value> /// </value>
/// <exception cref="DependencyContainerRegistrationException">singleton.</exception> /// <exception cref="DependencyContainerRegistrationException">singleton.</exception>
public virtual ObjectFactoryBase SingletonVariant => public virtual ObjectFactoryBase SingletonVariant => throw new DependencyContainerRegistrationException(this.GetType(), "singleton");
throw new DependencyContainerRegistrationException(GetType(), "singleton");
/// <summary> /// <summary>
/// Gets the multi instance variant. /// Gets the multi instance variant.
@ -44,8 +45,7 @@
/// The multi instance variant. /// The multi instance variant.
/// </value> /// </value>
/// <exception cref="DependencyContainerRegistrationException">multi-instance.</exception> /// <exception cref="DependencyContainerRegistrationException">multi-instance.</exception>
public virtual ObjectFactoryBase MultiInstanceVariant => public virtual ObjectFactoryBase MultiInstanceVariant => throw new DependencyContainerRegistrationException(this.GetType(), "multi-instance");
throw new DependencyContainerRegistrationException(GetType(), "multi-instance");
/// <summary> /// <summary>
/// Gets the strong reference variant. /// Gets the strong reference variant.
@ -54,8 +54,7 @@
/// The strong reference variant. /// The strong reference variant.
/// </value> /// </value>
/// <exception cref="DependencyContainerRegistrationException">strong reference.</exception> /// <exception cref="DependencyContainerRegistrationException">strong reference.</exception>
public virtual ObjectFactoryBase StrongReferenceVariant => public virtual ObjectFactoryBase StrongReferenceVariant => throw new DependencyContainerRegistrationException(this.GetType(), "strong reference");
throw new DependencyContainerRegistrationException(GetType(), "strong reference");
/// <summary> /// <summary>
/// Gets the weak reference variant. /// Gets the weak reference variant.
@ -64,8 +63,7 @@
/// The weak reference variant. /// The weak reference variant.
/// </value> /// </value>
/// <exception cref="DependencyContainerRegistrationException">weak reference.</exception> /// <exception cref="DependencyContainerRegistrationException">weak reference.</exception>
public virtual ObjectFactoryBase WeakReferenceVariant => public virtual ObjectFactoryBase WeakReferenceVariant => throw new DependencyContainerRegistrationException(this.GetType(), "weak reference");
throw new DependencyContainerRegistrationException(GetType(), "weak reference");
/// <summary> /// <summary>
/// Create the type. /// Create the type.
@ -74,10 +72,7 @@
/// <param name="container">Container that requested the creation.</param> /// <param name="container">Container that requested the creation.</param>
/// <param name="options">The options.</param> /// <param name="options">The options.</param>
/// <returns> Instance of type. </returns> /// <returns> Instance of type. </returns>
public abstract object GetObject( public abstract Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options);
Type requestedType,
DependencyContainer container,
DependencyContainerResolveOptions options);
/// <summary> /// <summary>
/// Gets the factory for child container. /// Gets the factory for child container.
@ -86,63 +81,42 @@
/// <param name="parent">The parent.</param> /// <param name="parent">The parent.</param>
/// <param name="child">The child.</param> /// <param name="child">The child.</param>
/// <returns></returns> /// <returns></returns>
public virtual ObjectFactoryBase GetFactoryForChildContainer( public virtual ObjectFactoryBase GetFactoryForChildContainer(Type type, DependencyContainer parent, DependencyContainer child) => this;
Type type,
DependencyContainer parent,
DependencyContainer child)
{
return this;
}
} }
/// <inheritdoc /> /// <inheritdoc />
/// <summary> /// <summary>
/// IObjectFactory that creates new instances of types for each resolution. /// IObjectFactory that creates new instances of types for each resolution.
/// </summary> /// </summary>
internal class MultiInstanceFactory : ObjectFactoryBase internal class MultiInstanceFactory : ObjectFactoryBase {
{
private readonly Type _registerType; private readonly Type _registerType;
private readonly Type _registerImplementation; private readonly Type _registerImplementation;
public MultiInstanceFactory(Type registerType, Type registerImplementation) public MultiInstanceFactory(Type registerType, Type registerImplementation) {
{ if(registerImplementation.IsAbstract || registerImplementation.IsInterface) {
if (registerImplementation.IsAbstract || registerImplementation.IsInterface) throw new DependencyContainerRegistrationException(registerImplementation, "MultiInstanceFactory", true);
{
throw new DependencyContainerRegistrationException(registerImplementation,
"MultiInstanceFactory",
true);
} }
if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) {
{ throw new DependencyContainerRegistrationException(registerImplementation, "MultiInstanceFactory", true);
throw new DependencyContainerRegistrationException(registerImplementation,
"MultiInstanceFactory",
true);
} }
_registerType = registerType; this._registerType = registerType;
_registerImplementation = registerImplementation; this._registerImplementation = registerImplementation;
} }
public override Type CreatesType => _registerImplementation; public override Type CreatesType => this._registerImplementation;
public override ObjectFactoryBase SingletonVariant => public override ObjectFactoryBase SingletonVariant =>
new SingletonFactory(_registerType, _registerImplementation); new SingletonFactory(this._registerType, this._registerImplementation);
public override ObjectFactoryBase MultiInstanceVariant => this; public override ObjectFactoryBase MultiInstanceVariant => this;
public override object GetObject( public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) {
Type requestedType, try {
DependencyContainer container, return container.RegisteredTypes.ConstructType(this._registerImplementation, this.Constructor, options);
DependencyContainerResolveOptions options) } catch(DependencyContainerResolutionException ex) {
{ throw new DependencyContainerResolutionException(this._registerType, ex);
try
{
return container.RegisteredTypes.ConstructType(_registerImplementation, Constructor, options);
}
catch (DependencyContainerResolutionException ex)
{
throw new DependencyContainerResolutionException(_registerType, ex);
} }
} }
} }
@ -151,41 +125,32 @@
/// <summary> /// <summary>
/// IObjectFactory that invokes a specified delegate to construct the object. /// IObjectFactory that invokes a specified delegate to construct the object.
/// </summary> /// </summary>
internal class DelegateFactory : ObjectFactoryBase internal class DelegateFactory : ObjectFactoryBase {
{
private readonly Type _registerType; private readonly Type _registerType;
private readonly Func<DependencyContainer, Dictionary<string, object>, object> _factory; private readonly Func<DependencyContainer, Dictionary<String, Object>, Object> _factory;
public DelegateFactory( public DelegateFactory(
Type registerType, Type registerType,
Func<DependencyContainer, Dictionary<string, object>, object> factory) Func<DependencyContainer, Dictionary<String, Object>, Object> factory) {
{ this._factory = factory ?? throw new ArgumentNullException(nameof(factory));
_factory = factory ?? throw new ArgumentNullException(nameof(factory));
_registerType = registerType; this._registerType = registerType;
} }
public override bool AssumeConstruction => true; public override Boolean AssumeConstruction => true;
public override Type CreatesType => _registerType; public override Type CreatesType => this._registerType;
public override ObjectFactoryBase WeakReferenceVariant => new WeakDelegateFactory(_registerType, _factory); public override ObjectFactoryBase WeakReferenceVariant => new WeakDelegateFactory(this._registerType, this._factory);
public override ObjectFactoryBase StrongReferenceVariant => this; public override ObjectFactoryBase StrongReferenceVariant => this;
public override object GetObject( public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) {
Type requestedType, try {
DependencyContainer container, return this._factory.Invoke(container, options.ConstructorParameters);
DependencyContainerResolveOptions options) } catch(Exception ex) {
{ throw new DependencyContainerResolutionException(this._registerType, ex);
try
{
return _factory.Invoke(container, options.ConstructorParameters);
}
catch (Exception ex)
{
throw new DependencyContainerResolutionException(_registerType, ex);
} }
} }
} }
@ -195,56 +160,46 @@
/// IObjectFactory that invokes a specified delegate to construct the object /// IObjectFactory that invokes a specified delegate to construct the object
/// Holds the delegate using a weak reference. /// Holds the delegate using a weak reference.
/// </summary> /// </summary>
internal class WeakDelegateFactory : ObjectFactoryBase internal class WeakDelegateFactory : ObjectFactoryBase {
{
private readonly Type _registerType; private readonly Type _registerType;
private readonly WeakReference _factory; private readonly WeakReference _factory;
public WeakDelegateFactory( public WeakDelegateFactory(Type registerType, Func<DependencyContainer, Dictionary<String, Object>, Object> factory) {
Type registerType, if(factory == null) {
Func<DependencyContainer, Dictionary<string, object>, object> factory)
{
if (factory == null)
throw new ArgumentNullException(nameof(factory)); throw new ArgumentNullException(nameof(factory));
_factory = new WeakReference(factory);
_registerType = registerType;
} }
public override bool AssumeConstruction => true; this._factory = new WeakReference(factory);
public override Type CreatesType => _registerType; this._registerType = registerType;
}
public override ObjectFactoryBase StrongReferenceVariant public override Boolean AssumeConstruction => true;
{
get
{
if (!(_factory.Target is Func<DependencyContainer, Dictionary<string, object>, object> factory))
throw new DependencyContainerWeakReferenceException(_registerType);
return new DelegateFactory(_registerType, factory); public override Type CreatesType => this._registerType;
public override ObjectFactoryBase StrongReferenceVariant {
get {
if(!(this._factory.Target is Func<DependencyContainer, Dictionary<global::System.String, global::System.Object>, global::System.Object> factory)) {
throw new DependencyContainerWeakReferenceException(this._registerType);
}
return new DelegateFactory(this._registerType, factory);
} }
} }
public override ObjectFactoryBase WeakReferenceVariant => this; public override ObjectFactoryBase WeakReferenceVariant => this;
public override object GetObject( public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) {
Type requestedType, if(!(this._factory.Target is Func<DependencyContainer, Dictionary<global::System.String, global::System.Object>, global::System.Object> factory)) {
DependencyContainer container, throw new DependencyContainerWeakReferenceException(this._registerType);
DependencyContainerResolveOptions options)
{
if (!(_factory.Target is Func<DependencyContainer, Dictionary<string, object>, object> factory))
throw new DependencyContainerWeakReferenceException(_registerType);
try
{
return factory.Invoke(container, options.ConstructorParameters);
} }
catch (Exception ex)
{ try {
throw new DependencyContainerResolutionException(_registerType, ex); return factory.Invoke(container, options.ConstructorParameters);
} catch(Exception ex) {
throw new DependencyContainerResolutionException(this._registerType, ex);
} }
} }
} }
@ -252,45 +207,35 @@
/// <summary> /// <summary>
/// Stores an particular instance to return for a type. /// Stores an particular instance to return for a type.
/// </summary> /// </summary>
internal class InstanceFactory : ObjectFactoryBase, IDisposable internal class InstanceFactory : ObjectFactoryBase, IDisposable {
{
private readonly Type _registerType; private readonly Type _registerType;
private readonly Type _registerImplementation; private readonly Type _registerImplementation;
private readonly object _instance; private readonly Object _instance;
public InstanceFactory(Type registerType, Type registerImplementation, object instance) public InstanceFactory(Type registerType, Type registerImplementation, Object instance) {
{ if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) {
if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation))
throw new DependencyContainerRegistrationException(registerImplementation, "InstanceFactory", true); throw new DependencyContainerRegistrationException(registerImplementation, "InstanceFactory", true);
_registerType = registerType;
_registerImplementation = registerImplementation;
_instance = instance;
} }
public override bool AssumeConstruction => true; this._registerType = registerType;
this._registerImplementation = registerImplementation;
this._instance = instance;
}
public override Type CreatesType => _registerImplementation; public override Boolean AssumeConstruction => true;
public override ObjectFactoryBase MultiInstanceVariant => public override Type CreatesType => this._registerImplementation;
new MultiInstanceFactory(_registerType, _registerImplementation);
public override ObjectFactoryBase WeakReferenceVariant => public override ObjectFactoryBase MultiInstanceVariant => new MultiInstanceFactory(this._registerType, this._registerImplementation);
new WeakInstanceFactory(_registerType, _registerImplementation, _instance);
public override ObjectFactoryBase WeakReferenceVariant => new WeakInstanceFactory(this._registerType, this._registerImplementation, this._instance);
public override ObjectFactoryBase StrongReferenceVariant => this; public override ObjectFactoryBase StrongReferenceVariant => this;
public override object GetObject( public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) => this._instance;
Type requestedType,
DependencyContainer container,
DependencyContainerResolveOptions options)
{
return _instance;
}
public void Dispose() public void Dispose() {
{ IDisposable disposable = this._instance as IDisposable;
var disposable = _instance as IDisposable;
disposable?.Dispose(); disposable?.Dispose();
} }
@ -299,125 +244,109 @@
/// <summary> /// <summary>
/// Stores the instance with a weak reference. /// Stores the instance with a weak reference.
/// </summary> /// </summary>
internal class WeakInstanceFactory : ObjectFactoryBase, IDisposable internal class WeakInstanceFactory : ObjectFactoryBase, IDisposable {
{
private readonly Type _registerType; private readonly Type _registerType;
private readonly Type _registerImplementation; private readonly Type _registerImplementation;
private readonly WeakReference _instance; private readonly WeakReference _instance;
public WeakInstanceFactory(Type registerType, Type registerImplementation, object instance) public WeakInstanceFactory(Type registerType, Type registerImplementation, Object instance) {
{ if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) {
if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) throw new DependencyContainerRegistrationException(registerImplementation, "WeakInstanceFactory", true);
{
throw new DependencyContainerRegistrationException(
registerImplementation,
"WeakInstanceFactory",
true);
} }
_registerType = registerType; this._registerType = registerType;
_registerImplementation = registerImplementation; this._registerImplementation = registerImplementation;
_instance = new WeakReference(instance); this._instance = new WeakReference(instance);
} }
public override Type CreatesType => _registerImplementation; public override Type CreatesType => this._registerImplementation;
public override ObjectFactoryBase MultiInstanceVariant => public override ObjectFactoryBase MultiInstanceVariant => new MultiInstanceFactory(this._registerType, this._registerImplementation);
new MultiInstanceFactory(_registerType, _registerImplementation);
public override ObjectFactoryBase WeakReferenceVariant => this; public override ObjectFactoryBase WeakReferenceVariant => this;
public override ObjectFactoryBase StrongReferenceVariant public override ObjectFactoryBase StrongReferenceVariant {
{ get {
get Object instance = this._instance.Target;
{
var instance = _instance.Target;
if (instance == null) if(instance == null) {
throw new DependencyContainerWeakReferenceException(_registerType); throw new DependencyContainerWeakReferenceException(this._registerType);
}
return new InstanceFactory(_registerType, _registerImplementation, instance); return new InstanceFactory(this._registerType, this._registerImplementation, instance);
} }
} }
public override object GetObject( public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) {
Type requestedType, Object instance = this._instance.Target;
DependencyContainer container,
DependencyContainerResolveOptions options)
{
var instance = _instance.Target;
if (instance == null) if(instance == null) {
throw new DependencyContainerWeakReferenceException(_registerType); throw new DependencyContainerWeakReferenceException(this._registerType);
}
return instance; return instance;
} }
public void Dispose() => (_instance.Target as IDisposable)?.Dispose(); public void Dispose() => (this._instance.Target as IDisposable)?.Dispose();
} }
/// <summary> /// <summary>
/// A factory that lazy instantiates a type and always returns the same instance. /// A factory that lazy instantiates a type and always returns the same instance.
/// </summary> /// </summary>
internal class SingletonFactory : ObjectFactoryBase, IDisposable internal class SingletonFactory : ObjectFactoryBase, IDisposable {
{
private readonly Type _registerType; private readonly Type _registerType;
private readonly Type _registerImplementation; private readonly Type _registerImplementation;
private readonly object _singletonLock = new object(); private readonly Object _singletonLock = new Object();
private object _current; private Object _current;
public SingletonFactory(Type registerType, Type registerImplementation) public SingletonFactory(Type registerType, Type registerImplementation) {
{ if(registerImplementation.IsAbstract || registerImplementation.IsInterface) {
if (registerImplementation.IsAbstract || registerImplementation.IsInterface)
{
throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true); throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true);
} }
if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) {
{
throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true); throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true);
} }
_registerType = registerType; this._registerType = registerType;
_registerImplementation = registerImplementation; this._registerImplementation = registerImplementation;
} }
public override Type CreatesType => _registerImplementation; public override Type CreatesType => this._registerImplementation;
public override ObjectFactoryBase SingletonVariant => this; public override ObjectFactoryBase SingletonVariant => this;
public override ObjectFactoryBase MultiInstanceVariant => public override ObjectFactoryBase MultiInstanceVariant =>
new MultiInstanceFactory(_registerType, _registerImplementation); new MultiInstanceFactory(this._registerType, this._registerImplementation);
public override object GetObject( public override Object GetObject(
Type requestedType, Type requestedType,
DependencyContainer container, DependencyContainer container,
DependencyContainerResolveOptions options) DependencyContainerResolveOptions options) {
{ if(options.ConstructorParameters.Count != 0) {
if (options.ConstructorParameters.Count != 0)
throw new ArgumentException("Cannot specify parameters for singleton types"); throw new ArgumentException("Cannot specify parameters for singleton types");
lock (_singletonLock)
{
if (_current == null)
_current = container.RegisteredTypes.ConstructType(_registerImplementation, Constructor, options);
} }
return _current; lock(this._singletonLock) {
if(this._current == null) {
this._current = container.RegisteredTypes.ConstructType(this._registerImplementation, this.Constructor, options);
}
}
return this._current;
} }
public override ObjectFactoryBase GetFactoryForChildContainer( public override ObjectFactoryBase GetFactoryForChildContainer(
Type type, Type type,
DependencyContainer parent, DependencyContainer parent,
DependencyContainer child) DependencyContainer child) {
{
// We make sure that the singleton is constructed before the child container takes the factory. // We make sure that the singleton is constructed before the child container takes the factory.
// Otherwise the results would vary depending on whether or not the parent container had resolved // Otherwise the results would vary depending on whether or not the parent container had resolved
// the type before the child container does. // the type before the child container does.
GetObject(type, parent, DependencyContainerResolveOptions.Default); _ = this.GetObject(type, parent, DependencyContainerResolveOptions.Default);
return this; return this;
} }
public void Dispose() => (_current as IDisposable)?.Dispose(); public void Dispose() => (this._current as IDisposable)?.Dispose();
} }
} }

View File

@ -1,14 +1,12 @@
namespace Swan.DependencyInjection using System;
{
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// Registration options for "fluent" API. /// Registration options for "fluent" API.
/// </summary> /// </summary>
public sealed class RegisterOptions public sealed class RegisterOptions {
{
private readonly TypesConcurrentDictionary _registeredTypes; private readonly TypesConcurrentDictionary _registeredTypes;
private readonly DependencyContainer.TypeRegistration _registration; private readonly DependencyContainer.TypeRegistration _registration;
@ -17,10 +15,9 @@
/// </summary> /// </summary>
/// <param name="registeredTypes">The registered types.</param> /// <param name="registeredTypes">The registered types.</param>
/// <param name="registration">The registration.</param> /// <param name="registration">The registration.</param>
public RegisterOptions(TypesConcurrentDictionary registeredTypes, DependencyContainer.TypeRegistration registration) public RegisterOptions(TypesConcurrentDictionary registeredTypes, DependencyContainer.TypeRegistration registration) {
{ this._registeredTypes = registeredTypes;
_registeredTypes = registeredTypes; this._registration = registration;
_registration = registration;
} }
/// <summary> /// <summary>
@ -28,14 +25,14 @@
/// </summary> /// </summary>
/// <returns>A registration options for fluent API.</returns> /// <returns>A registration options for fluent API.</returns>
/// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception> /// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception>
public RegisterOptions AsSingleton() public RegisterOptions AsSingleton() {
{ ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration);
var currentFactory = _registeredTypes.GetCurrentFactory(_registration);
if (currentFactory == null) if(currentFactory == null) {
throw new DependencyContainerRegistrationException(_registration.Type, "singleton"); throw new DependencyContainerRegistrationException(this._registration.Type, "singleton");
}
return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.SingletonVariant); return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.SingletonVariant);
} }
/// <summary> /// <summary>
@ -43,14 +40,14 @@
/// </summary> /// </summary>
/// <returns>A registration options for fluent API.</returns> /// <returns>A registration options for fluent API.</returns>
/// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception> /// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception>
public RegisterOptions AsMultiInstance() public RegisterOptions AsMultiInstance() {
{ ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration);
var currentFactory = _registeredTypes.GetCurrentFactory(_registration);
if (currentFactory == null) if(currentFactory == null) {
throw new DependencyContainerRegistrationException(_registration.Type, "multi-instance"); throw new DependencyContainerRegistrationException(this._registration.Type, "multi-instance");
}
return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.MultiInstanceVariant); return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.MultiInstanceVariant);
} }
/// <summary> /// <summary>
@ -58,14 +55,14 @@
/// </summary> /// </summary>
/// <returns>A registration options for fluent API.</returns> /// <returns>A registration options for fluent API.</returns>
/// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception> /// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception>
public RegisterOptions WithWeakReference() public RegisterOptions WithWeakReference() {
{ ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration);
var currentFactory = _registeredTypes.GetCurrentFactory(_registration);
if (currentFactory == null) if(currentFactory == null) {
throw new DependencyContainerRegistrationException(_registration.Type, "weak reference"); throw new DependencyContainerRegistrationException(this._registration.Type, "weak reference");
}
return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.WeakReferenceVariant); return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.WeakReferenceVariant);
} }
/// <summary> /// <summary>
@ -73,41 +70,36 @@
/// </summary> /// </summary>
/// <returns>A registration options for fluent API.</returns> /// <returns>A registration options for fluent API.</returns>
/// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception> /// <exception cref="DependencyContainerRegistrationException">Generic constraint registration exception.</exception>
public RegisterOptions WithStrongReference() public RegisterOptions WithStrongReference() {
{ ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration);
var currentFactory = _registeredTypes.GetCurrentFactory(_registration);
if (currentFactory == null) if(currentFactory == null) {
throw new DependencyContainerRegistrationException(_registration.Type, "strong reference"); throw new DependencyContainerRegistrationException(this._registration.Type, "strong reference");
}
return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.StrongReferenceVariant); return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.StrongReferenceVariant);
} }
} }
/// <summary> /// <summary>
/// Registration options for "fluent" API when registering multiple implementations. /// Registration options for "fluent" API when registering multiple implementations.
/// </summary> /// </summary>
public sealed class MultiRegisterOptions public sealed class MultiRegisterOptions {
{
private IEnumerable<RegisterOptions> _registerOptions; private IEnumerable<RegisterOptions> _registerOptions;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="MultiRegisterOptions"/> class. /// Initializes a new instance of the <see cref="MultiRegisterOptions"/> class.
/// </summary> /// </summary>
/// <param name="registerOptions">The register options.</param> /// <param name="registerOptions">The register options.</param>
public MultiRegisterOptions(IEnumerable<RegisterOptions> registerOptions) public MultiRegisterOptions(IEnumerable<RegisterOptions> registerOptions) => this._registerOptions = registerOptions;
{
_registerOptions = registerOptions;
}
/// <summary> /// <summary>
/// Make registration a singleton (single instance) if possible. /// Make registration a singleton (single instance) if possible.
/// </summary> /// </summary>
/// <returns>A registration multi-instance for fluent API.</returns> /// <returns>A registration multi-instance for fluent API.</returns>
/// <exception cref="DependencyContainerRegistrationException">Generic Constraint Registration Exception.</exception> /// <exception cref="DependencyContainerRegistrationException">Generic Constraint Registration Exception.</exception>
public MultiRegisterOptions AsSingleton() public MultiRegisterOptions AsSingleton() {
{ this._registerOptions = this.ExecuteOnAllRegisterOptions(ro => ro.AsSingleton());
_registerOptions = ExecuteOnAllRegisterOptions(ro => ro.AsSingleton());
return this; return this;
} }
@ -116,16 +108,12 @@
/// </summary> /// </summary>
/// <returns>A registration multi-instance for fluent API.</returns> /// <returns>A registration multi-instance for fluent API.</returns>
/// <exception cref="DependencyContainerRegistrationException">Generic Constraint Registration Exception.</exception> /// <exception cref="DependencyContainerRegistrationException">Generic Constraint Registration Exception.</exception>
public MultiRegisterOptions AsMultiInstance() public MultiRegisterOptions AsMultiInstance() {
{ this._registerOptions = this.ExecuteOnAllRegisterOptions(ro => ro.AsMultiInstance());
_registerOptions = ExecuteOnAllRegisterOptions(ro => ro.AsMultiInstance());
return this; return this;
} }
private IEnumerable<RegisterOptions> ExecuteOnAllRegisterOptions( private IEnumerable<RegisterOptions> ExecuteOnAllRegisterOptions(
Func<RegisterOptions, RegisterOptions> action) Func<RegisterOptions, RegisterOptions> action) => this._registerOptions.Select(action).ToList();
{
return _registerOptions.Select(action).ToList();
}
} }
} }

View File

@ -1,27 +1,23 @@
namespace Swan.DependencyInjection using System;
{
using System;
public partial class DependencyContainer namespace Swan.DependencyInjection {
{ public partial class DependencyContainer {
/// <summary> /// <summary>
/// Represents a Type Registration within the IoC Container. /// Represents a Type Registration within the IoC Container.
/// </summary> /// </summary>
public sealed class TypeRegistration public sealed class TypeRegistration {
{ private readonly Int32 _hashCode;
private readonly int _hashCode;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="TypeRegistration"/> class. /// Initializes a new instance of the <see cref="TypeRegistration"/> class.
/// </summary> /// </summary>
/// <param name="type">The type.</param> /// <param name="type">The type.</param>
/// <param name="name">The name.</param> /// <param name="name">The name.</param>
public TypeRegistration(Type type, string name = null) public TypeRegistration(Type type, String name = null) {
{ this.Type = type;
Type = type; this.Name = name ?? String.Empty;
Name = name ?? string.Empty;
_hashCode = string.Concat(Type.FullName, "|", Name).GetHashCode(); this._hashCode = String.Concat(this.Type.FullName, "|", this.Name).GetHashCode();
} }
/// <summary> /// <summary>
@ -30,7 +26,9 @@
/// <value> /// <value>
/// The type. /// The type.
/// </value> /// </value>
public Type Type { get; } public Type Type {
get;
}
/// <summary> /// <summary>
/// Gets the name. /// Gets the name.
@ -38,7 +36,9 @@
/// <value> /// <value>
/// The name. /// The name.
/// </value> /// </value>
public string Name { get; } public String Name {
get;
}
/// <summary> /// <summary>
/// Determines whether the specified <see cref="System.Object" />, is equal to this instance. /// Determines whether the specified <see cref="System.Object" />, is equal to this instance.
@ -47,13 +47,7 @@
/// <returns> /// <returns>
/// <c>true</c> if the specified <see cref="System.Object" /> is equal to this instance; otherwise, <c>false</c>. /// <c>true</c> if the specified <see cref="System.Object" /> is equal to this instance; otherwise, <c>false</c>.
/// </returns> /// </returns>
public override bool Equals(object obj) public override Boolean Equals(Object obj) => !(obj is TypeRegistration typeRegistration) || typeRegistration.Type != this.Type ? false : String.Compare(this.Name, typeRegistration.Name, StringComparison.Ordinal) == 0;
{
if (!(obj is TypeRegistration typeRegistration) || typeRegistration.Type != Type)
return false;
return string.Compare(Name, typeRegistration.Name, StringComparison.Ordinal) == 0;
}
/// <summary> /// <summary>
/// Returns a hash code for this instance. /// Returns a hash code for this instance.
@ -61,7 +55,7 @@
/// <returns> /// <returns>
/// A hash code for this instance, suitable for use in hashing algorithms and data structures like a hash table. /// A hash code for this instance, suitable for use in hashing algorithms and data structures like a hash table.
/// </returns> /// </returns>
public override int GetHashCode() => _hashCode; public override Int32 GetHashCode() => this._hashCode;
} }
} }
} }

View File

@ -1,5 +1,4 @@
namespace Swan.DependencyInjection #nullable enable
{
using System; using System;
using System.Linq.Expressions; using System.Linq.Expressions;
using System.Reflection; using System.Reflection;
@ -7,275 +6,204 @@
using System.Linq; using System.Linq;
using System.Collections.Concurrent; using System.Collections.Concurrent;
namespace Swan.DependencyInjection {
/// <summary> /// <summary>
/// Represents a Concurrent Dictionary for TypeRegistration. /// Represents a Concurrent Dictionary for TypeRegistration.
/// </summary> /// </summary>
public class TypesConcurrentDictionary : ConcurrentDictionary<DependencyContainer.TypeRegistration, ObjectFactoryBase> public class TypesConcurrentDictionary : ConcurrentDictionary<DependencyContainer.TypeRegistration, ObjectFactoryBase> {
{ private static readonly ConcurrentDictionary<ConstructorInfo, ObjectConstructor> ObjectConstructorCache = new ConcurrentDictionary<ConstructorInfo, ObjectConstructor>();
private static readonly ConcurrentDictionary<ConstructorInfo, ObjectConstructor> ObjectConstructorCache =
new ConcurrentDictionary<ConstructorInfo, ObjectConstructor>();
private readonly DependencyContainer _dependencyContainer; private readonly DependencyContainer _dependencyContainer;
internal TypesConcurrentDictionary(DependencyContainer dependencyContainer) internal TypesConcurrentDictionary(DependencyContainer dependencyContainer) => this._dependencyContainer = dependencyContainer;
{
_dependencyContainer = dependencyContainer;
}
/// <summary> /// <summary>
/// Represents a delegate to build an object with the parameters. /// Represents a delegate to build an object with the parameters.
/// </summary> /// </summary>
/// <param name="parameters">The parameters.</param> /// <param name="parameters">The parameters.</param>
/// <returns>The built object.</returns> /// <returns>The built object.</returns>
public delegate object ObjectConstructor(params object[] parameters); public delegate Object ObjectConstructor(params Object?[] parameters);
internal IEnumerable<object> Resolve(Type resolveType, bool includeUnnamed) internal IEnumerable<Object> Resolve(Type resolveType, Boolean includeUnnamed) {
{ IEnumerable<DependencyContainer.TypeRegistration> registrations = this.Keys.Where(tr => tr.Type == resolveType).Concat(this.GetParentRegistrationsForType(resolveType)).Distinct();
var registrations = Keys.Where(tr => tr.Type == resolveType)
.Concat(GetParentRegistrationsForType(resolveType)).Distinct();
if (!includeUnnamed) if(!includeUnnamed) {
registrations = registrations.Where(tr => !string.IsNullOrEmpty(tr.Name)); registrations = registrations.Where(tr => !String.IsNullOrEmpty(tr.Name));
return registrations.Select(registration =>
ResolveInternal(registration, DependencyContainerResolveOptions.Default));
} }
internal ObjectFactoryBase GetCurrentFactory(DependencyContainer.TypeRegistration registration) return registrations.Select(registration => this.ResolveInternal(registration, DependencyContainerResolveOptions.Default));
{
TryGetValue(registration, out var current);
return current;
} }
internal RegisterOptions Register(Type registerType, string name, ObjectFactoryBase factory) internal ObjectFactoryBase GetCurrentFactory(DependencyContainer.TypeRegistration registration) {
=> AddUpdateRegistration(new DependencyContainer.TypeRegistration(registerType, name), factory); _ = this.TryGetValue(registration, out ObjectFactoryBase? current);
internal RegisterOptions AddUpdateRegistration(DependencyContainer.TypeRegistration typeRegistration, ObjectFactoryBase factory) return current!;
{ }
internal RegisterOptions Register(Type registerType, String name, ObjectFactoryBase factory) => this.AddUpdateRegistration(new DependencyContainer.TypeRegistration(registerType, name), factory);
internal RegisterOptions AddUpdateRegistration(DependencyContainer.TypeRegistration typeRegistration, ObjectFactoryBase factory) {
this[typeRegistration] = factory; this[typeRegistration] = factory;
return new RegisterOptions(this, typeRegistration); return new RegisterOptions(this, typeRegistration);
} }
internal bool RemoveRegistration(DependencyContainer.TypeRegistration typeRegistration) internal Boolean RemoveRegistration(DependencyContainer.TypeRegistration typeRegistration) => this.TryRemove(typeRegistration, out _);
=> TryRemove(typeRegistration, out _);
internal object ResolveInternal( internal Object ResolveInternal(DependencyContainer.TypeRegistration registration, DependencyContainerResolveOptions? options = null) {
DependencyContainer.TypeRegistration registration, if(options == null) {
DependencyContainerResolveOptions? options = null)
{
if (options == null)
options = DependencyContainerResolveOptions.Default; options = DependencyContainerResolveOptions.Default;
}
// Attempt container resolution // Attempt container resolution
if (TryGetValue(registration, out var factory)) if(this.TryGetValue(registration, out ObjectFactoryBase? factory)) {
{ try {
try return factory.GetObject(registration.Type, this._dependencyContainer, options);
{ } catch(DependencyContainerResolutionException) {
return factory.GetObject(registration.Type, _dependencyContainer, options);
}
catch (DependencyContainerResolutionException)
{
throw; throw;
} } catch(Exception ex) {
catch (Exception ex)
{
throw new DependencyContainerResolutionException(registration.Type, ex); throw new DependencyContainerResolutionException(registration.Type, ex);
} }
} }
// Attempt to get a factory from parent if we can // Attempt to get a factory from parent if we can
var bubbledObjectFactory = GetParentObjectFactory(registration); ObjectFactoryBase? bubbledObjectFactory = this.GetParentObjectFactory(registration);
if (bubbledObjectFactory != null) if(bubbledObjectFactory != null) {
{ try {
try return bubbledObjectFactory.GetObject(registration.Type, this._dependencyContainer, options);
{ } catch(DependencyContainerResolutionException) {
return bubbledObjectFactory.GetObject(registration.Type, _dependencyContainer, options);
}
catch (DependencyContainerResolutionException)
{
throw; throw;
} } catch(Exception ex) {
catch (Exception ex)
{
throw new DependencyContainerResolutionException(registration.Type, ex); throw new DependencyContainerResolutionException(registration.Type, ex);
} }
} }
// Fail if requesting named resolution and settings set to fail if unresolved // Fail if requesting named resolution and settings set to fail if unresolved
if (!string.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == if(!String.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.Fail) {
DependencyContainerNamedResolutionFailureAction.Fail)
throw new DependencyContainerResolutionException(registration.Type); throw new DependencyContainerResolutionException(registration.Type);
}
// Attempted unnamed fallback container resolution if relevant and requested // Attempted unnamed fallback container resolution if relevant and requested
if (!string.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == if(!String.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) {
DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) if(this.TryGetValue(new DependencyContainer.TypeRegistration(registration.Type, String.Empty), out factory)) {
{ try {
if (TryGetValue(new DependencyContainer.TypeRegistration(registration.Type, string.Empty), out factory)) return factory.GetObject(registration.Type, this._dependencyContainer, options);
{ } catch(DependencyContainerResolutionException) {
try
{
return factory.GetObject(registration.Type, _dependencyContainer, options);
}
catch (DependencyContainerResolutionException)
{
throw; throw;
} } catch(Exception ex) {
catch (Exception ex)
{
throw new DependencyContainerResolutionException(registration.Type, ex); throw new DependencyContainerResolutionException(registration.Type, ex);
} }
} }
} }
// Attempt unregistered construction if possible and requested // Attempt unregistered construction if possible and requested
var isValid = (options.UnregisteredResolutionAction == Boolean isValid = options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.AttemptResolve || registration.Type.IsGenericType && options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.GenericsOnly;
DependencyContainerUnregisteredResolutionAction.AttemptResolve) ||
(registration.Type.IsGenericType && options.UnregisteredResolutionAction ==
DependencyContainerUnregisteredResolutionAction.GenericsOnly);
return isValid && !registration.Type.IsAbstract && !registration.Type.IsInterface return isValid && !registration.Type.IsAbstract && !registration.Type.IsInterface ? this.ConstructType(registration.Type, null, options) : throw new DependencyContainerResolutionException(registration.Type);
? ConstructType(registration.Type, null, options)
: throw new DependencyContainerResolutionException(registration.Type);
} }
internal bool CanResolve( internal Boolean CanResolve(DependencyContainer.TypeRegistration registration, DependencyContainerResolveOptions? options = null) {
DependencyContainer.TypeRegistration registration, if(options == null) {
DependencyContainerResolveOptions? options = null)
{
if (options == null)
options = DependencyContainerResolveOptions.Default; options = DependencyContainerResolveOptions.Default;
}
var checkType = registration.Type; Type checkType = registration.Type;
var name = registration.Name; String name = registration.Name;
if (TryGetValue(new DependencyContainer.TypeRegistration(checkType, name), out var factory)) if(this.TryGetValue(new DependencyContainer.TypeRegistration(checkType, name), out ObjectFactoryBase? factory)) {
{ return factory.AssumeConstruction ? true : factory.Constructor == null ? this.GetBestConstructor(factory.CreatesType, options) != null : this.CanConstruct(factory.Constructor, options);
if (factory.AssumeConstruction)
return true;
if (factory.Constructor == null)
return GetBestConstructor(factory.CreatesType, options) != null;
return CanConstruct(factory.Constructor, options);
} }
// Fail if requesting named resolution and settings set to fail if unresolved // Fail if requesting named resolution and settings set to fail if unresolved
// Or bubble up if we have a parent // Or bubble up if we have a parent
if (!string.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == if(!String.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.Fail) {
DependencyContainerNamedResolutionFailureAction.Fail) return this._dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false;
return _dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false; }
// Attempted unnamed fallback container resolution if relevant and requested // Attempted unnamed fallback container resolution if relevant and requested
if (!string.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == if(!String.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) {
DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) if(this.TryGetValue(new DependencyContainer.TypeRegistration(checkType), out factory)) {
{ return factory.AssumeConstruction ? true : this.GetBestConstructor(factory.CreatesType, options) != null;
if (TryGetValue(new DependencyContainer.TypeRegistration(checkType), out factory))
{
if (factory.AssumeConstruction)
return true;
return GetBestConstructor(factory.CreatesType, options) != null;
} }
} }
// Check if type is an automatic lazy factory request or an IEnumerable<ResolveType> // Check if type is an automatic lazy factory request or an IEnumerable<ResolveType>
if (IsAutomaticLazyFactoryRequest(checkType) || registration.Type.IsIEnumerable()) if(IsAutomaticLazyFactoryRequest(checkType) || registration.Type.IsIEnumerable()) {
return true; return true;
}
// Attempt unregistered construction if possible and requested // Attempt unregistered construction if possible and requested
// If we cant', bubble if we have a parent // If we cant', bubble if we have a parent
if ((options.UnregisteredResolutionAction == if(options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.AttemptResolve || checkType.IsGenericType && options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.GenericsOnly) {
DependencyContainerUnregisteredResolutionAction.AttemptResolve) || return this.GetBestConstructor(checkType, options) != null || (this._dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false);
(checkType.IsGenericType && options.UnregisteredResolutionAction ==
DependencyContainerUnregisteredResolutionAction.GenericsOnly))
{
return (GetBestConstructor(checkType, options) != null) ||
(_dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false);
} }
// Bubble resolution up the container tree if we have a parent // Bubble resolution up the container tree if we have a parent
return _dependencyContainer.Parent != null && _dependencyContainer.Parent.RegisteredTypes.CanResolve(registration, options.Clone()); return this._dependencyContainer.Parent != null && this._dependencyContainer.Parent.RegisteredTypes.CanResolve(registration, options.Clone());
} }
internal object ConstructType( internal Object ConstructType(Type implementationType, ConstructorInfo? constructor, DependencyContainerResolveOptions? options = null) {
Type implementationType, Type typeToConstruct = implementationType;
ConstructorInfo constructor,
DependencyContainerResolveOptions? options = null)
{
var typeToConstruct = implementationType;
if (constructor == null) if(constructor == null) {
{
// Try and get the best constructor that we can construct // Try and get the best constructor that we can construct
// if we can't construct any then get the constructor // if we can't construct any then get the constructor
// with the least number of parameters so we can throw a meaningful // with the least number of parameters so we can throw a meaningful
// resolve exception // resolve exception
constructor = GetBestConstructor(typeToConstruct, options) ?? constructor = this.GetBestConstructor(typeToConstruct, options) ?? GetTypeConstructors(typeToConstruct).LastOrDefault();
GetTypeConstructors(typeToConstruct).LastOrDefault();
} }
if (constructor == null) if(constructor == null) {
throw new DependencyContainerResolutionException(typeToConstruct); throw new DependencyContainerResolutionException(typeToConstruct);
var ctorParams = constructor.GetParameters();
var args = new object?[ctorParams.Length];
for (var parameterIndex = 0; parameterIndex < ctorParams.Length; parameterIndex++)
{
var currentParam = ctorParams[parameterIndex];
try
{
args[parameterIndex] = options?.ConstructorParameters.GetValueOrDefault(currentParam.Name, ResolveInternal(new DependencyContainer.TypeRegistration(currentParam.ParameterType), options.Clone()));
} }
catch (DependencyContainerResolutionException ex)
{ ParameterInfo[] ctorParams = constructor.GetParameters();
Object?[] args = new Object?[ctorParams.Length];
for(Int32 parameterIndex = 0; parameterIndex < ctorParams.Length; parameterIndex++) {
ParameterInfo currentParam = ctorParams[parameterIndex];
try {
args[parameterIndex] = options?.ConstructorParameters.GetValueOrDefault(currentParam.Name, this.ResolveInternal(new DependencyContainer.TypeRegistration(currentParam.ParameterType), options.Clone()));
} catch(DependencyContainerResolutionException ex) {
// If a constructor parameter can't be resolved // If a constructor parameter can't be resolved
// it will throw, so wrap it and throw that this can't // it will throw, so wrap it and throw that this can't
// be resolved. // be resolved.
throw new DependencyContainerResolutionException(typeToConstruct, ex); throw new DependencyContainerResolutionException(typeToConstruct, ex);
} } catch(Exception ex) {
catch (Exception ex)
{
throw new DependencyContainerResolutionException(typeToConstruct, ex); throw new DependencyContainerResolutionException(typeToConstruct, ex);
} }
} }
try try {
{
return CreateObjectConstructionDelegateWithCache(constructor).Invoke(args); return CreateObjectConstructionDelegateWithCache(constructor).Invoke(args);
} } catch(Exception ex) {
catch (Exception ex)
{
throw new DependencyContainerResolutionException(typeToConstruct, ex); throw new DependencyContainerResolutionException(typeToConstruct, ex);
} }
} }
private static ObjectConstructor CreateObjectConstructionDelegateWithCache(ConstructorInfo constructor) private static ObjectConstructor CreateObjectConstructionDelegateWithCache(ConstructorInfo constructor) {
{ if(ObjectConstructorCache.TryGetValue(constructor, out ObjectConstructor? objectConstructor)) {
if (ObjectConstructorCache.TryGetValue(constructor, out var objectConstructor))
return objectConstructor; return objectConstructor;
}
// We could lock the cache here, but there's no real side // We could lock the cache here, but there's no real side
// effect to two threads creating the same ObjectConstructor // effect to two threads creating the same ObjectConstructor
// at the same time, compared to the cost of a lock for // at the same time, compared to the cost of a lock for
// every creation. // every creation.
var constructorParams = constructor.GetParameters(); ParameterInfo[] constructorParams = constructor.GetParameters();
var lambdaParams = Expression.Parameter(typeof(object[]), "parameters"); ParameterExpression lambdaParams = Expression.Parameter(typeof(Object[]), "parameters");
var newParams = new Expression[constructorParams.Length]; Expression[] newParams = new Expression[constructorParams.Length];
for (var i = 0; i < constructorParams.Length; i++) for(Int32 i = 0; i < constructorParams.Length; i++) {
{ BinaryExpression paramsParameter = Expression.ArrayIndex(lambdaParams, Expression.Constant(i));
var paramsParameter = Expression.ArrayIndex(lambdaParams, Expression.Constant(i));
newParams[i] = Expression.Convert(paramsParameter, constructorParams[i].ParameterType); newParams[i] = Expression.Convert(paramsParameter, constructorParams[i].ParameterType);
} }
var newExpression = Expression.New(constructor, newParams); NewExpression newExpression = Expression.New(constructor, newParams);
var constructionLambda = Expression.Lambda(typeof(ObjectConstructor), newExpression, lambdaParams); LambdaExpression constructionLambda = Expression.Lambda(typeof(ObjectConstructor), newExpression, lambdaParams);
objectConstructor = (ObjectConstructor)constructionLambda.Compile(); objectConstructor = (ObjectConstructor)constructionLambda.Compile();
@ -283,69 +211,55 @@
return objectConstructor; return objectConstructor;
} }
private static IEnumerable<ConstructorInfo> GetTypeConstructors(Type type) private static IEnumerable<ConstructorInfo> GetTypeConstructors(Type type) => type.GetConstructors().OrderByDescending(ctor => ctor.GetParameters().Length);
=> type.GetConstructors().OrderByDescending(ctor => ctor.GetParameters().Length);
private static bool IsAutomaticLazyFactoryRequest(Type type) private static Boolean IsAutomaticLazyFactoryRequest(Type type) {
{ if(!type.IsGenericType) {
if (!type.IsGenericType)
return false; return false;
}
var genericType = type.GetGenericTypeDefinition(); Type genericType = type.GetGenericTypeDefinition();
// Just a func // Just a func
if (genericType == typeof(Func<>)) if(genericType == typeof(Func<>)) {
return true; return true;
}
// 2 parameter func with string as first parameter (name) // 2 parameter func with string as first parameter (name)
if (genericType == typeof(Func<,>) && type.GetGenericArguments()[0] == typeof(string)) if(genericType == typeof(Func<,>) && type.GetGenericArguments()[0] == typeof(String)) {
return true; return true;
}
// 3 parameter func with string as first parameter (name) and IDictionary<string, object> as second (parameters) // 3 parameter func with string as first parameter (name) and IDictionary<string, object> as second (parameters)
return genericType == typeof(Func<,,>) && type.GetGenericArguments()[0] == typeof(string) && return genericType == typeof(Func<,,>) && type.GetGenericArguments()[0] == typeof(String) && type.GetGenericArguments()[1] == typeof(IDictionary<String, Object>);
type.GetGenericArguments()[1] == typeof(IDictionary<string, object>);
} }
private ObjectFactoryBase? GetParentObjectFactory(DependencyContainer.TypeRegistration registration) private ObjectFactoryBase? GetParentObjectFactory(DependencyContainer.TypeRegistration registration) => this._dependencyContainer.Parent == null
{ ? null
if (_dependencyContainer.Parent == null) : this._dependencyContainer.Parent.RegisteredTypes.TryGetValue(registration, out ObjectFactoryBase? factory) ? factory.GetFactoryForChildContainer(registration.Type, this._dependencyContainer.Parent, this._dependencyContainer) : this._dependencyContainer.Parent.RegisteredTypes.GetParentObjectFactory(registration);
return null;
return _dependencyContainer.Parent.RegisteredTypes.TryGetValue(registration, out var factory) private ConstructorInfo? GetBestConstructor(Type type, DependencyContainerResolveOptions? options) => type.IsValueType ? null : GetTypeConstructors(type).FirstOrDefault(ctor => this.CanConstruct(ctor, options));
? factory.GetFactoryForChildContainer(registration.Type, _dependencyContainer.Parent, _dependencyContainer)
: _dependencyContainer.Parent.RegisteredTypes.GetParentObjectFactory(registration); private Boolean CanConstruct(MethodBase ctor, DependencyContainerResolveOptions? options) {
foreach(ParameterInfo parameter in ctor.GetParameters()) {
if(String.IsNullOrEmpty(parameter.Name)) {
return false;
} }
private ConstructorInfo? GetBestConstructor( Boolean isParameterOverload = options!.ConstructorParameters.ContainsKey(parameter.Name);
Type type,
DependencyContainerResolveOptions options)
=> type.IsValueType ? null : GetTypeConstructors(type).FirstOrDefault(ctor => CanConstruct(ctor, options));
private bool CanConstruct( if(parameter.ParameterType.IsPrimitive && !isParameterOverload) {
MethodBase ctor,
DependencyContainerResolveOptions? options)
{
foreach (var parameter in ctor.GetParameters())
{
if (string.IsNullOrEmpty(parameter.Name))
return false; return false;
}
var isParameterOverload = options.ConstructorParameters.ContainsKey(parameter.Name); if(!isParameterOverload && !this.CanResolve(new DependencyContainer.TypeRegistration(parameter.ParameterType), options.Clone())) {
if (parameter.ParameterType.IsPrimitive && !isParameterOverload)
return false;
if (!isParameterOverload &&
!CanResolve(new DependencyContainer.TypeRegistration(parameter.ParameterType), options.Clone()))
return false; return false;
} }
}
return true; return true;
} }
private IEnumerable<DependencyContainer.TypeRegistration> GetParentRegistrationsForType(Type resolveType) private IEnumerable<DependencyContainer.TypeRegistration> GetParentRegistrationsForType(Type resolveType) => this._dependencyContainer.Parent == null ? Array.Empty<DependencyContainer.TypeRegistration>() : this._dependencyContainer.Parent.RegisteredTypes.Keys.Where(tr => tr.Type == resolveType).Concat(this._dependencyContainer.Parent.RegisteredTypes.GetParentRegistrationsForType(resolveType));
=> _dependencyContainer.Parent == null
? Array.Empty<DependencyContainer.TypeRegistration>()
: _dependencyContainer.Parent.RegisteredTypes.Keys.Where(tr => tr.Type == resolveType).Concat(_dependencyContainer.Parent.RegisteredTypes.GetParentRegistrationsForType(resolveType));
} }
} }

View File

@ -1,40 +1,32 @@
namespace Swan.Diagnostics #nullable enable
{
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using Threading; using Swan.Threading;
namespace Swan.Diagnostics {
/// <summary> /// <summary>
/// A time measurement artifact. /// A time measurement artifact.
/// </summary> /// </summary>
internal sealed class RealTimeClock : IDisposable internal sealed class RealTimeClock : IDisposable {
{
private readonly Stopwatch _chrono = new Stopwatch(); private readonly Stopwatch _chrono = new Stopwatch();
private ISyncLocker? _locker = SyncLockerFactory.Create(useSlim: true); private ISyncLocker? _locker = SyncLockerFactory.Create(useSlim: true);
private long _offsetTicks; private Int64 _offsetTicks;
private double _speedRatio = 1.0d; private Double _speedRatio = 1.0d;
private bool _isDisposed; private Boolean _isDisposed;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="RealTimeClock"/> class. /// Initializes a new instance of the <see cref="RealTimeClock"/> class.
/// The clock starts paused and at the 0 position. /// The clock starts paused and at the 0 position.
/// </summary> /// </summary>
public RealTimeClock() public RealTimeClock() => this.Reset();
{
Reset();
}
/// <summary> /// <summary>
/// Gets or sets the clock position. /// Gets or sets the clock position.
/// </summary> /// </summary>
public TimeSpan Position public TimeSpan Position {
{ get {
get using(this._locker?.AcquireReaderLock()) {
{ return TimeSpan.FromTicks(this._offsetTicks + Convert.ToInt64(this._chrono.Elapsed.Ticks * this.SpeedRatio));
using (_locker?.AcquireReaderLock())
{
return TimeSpan.FromTicks(
_offsetTicks + Convert.ToInt64(_chrono.Elapsed.Ticks * SpeedRatio));
} }
} }
} }
@ -42,13 +34,10 @@
/// <summary> /// <summary>
/// Gets a value indicating whether the clock is running. /// Gets a value indicating whether the clock is running.
/// </summary> /// </summary>
public bool IsRunning public Boolean IsRunning {
{ get {
get using(this._locker?.AcquireReaderLock()) {
{ return this._chrono.IsRunning;
using (_locker?.AcquireReaderLock())
{
return _chrono.IsRunning;
} }
} }
} }
@ -56,26 +45,23 @@
/// <summary> /// <summary>
/// Gets or sets the speed ratio at which the clock runs. /// Gets or sets the speed ratio at which the clock runs.
/// </summary> /// </summary>
public double SpeedRatio public Double SpeedRatio {
{ get {
get using(this._locker?.AcquireReaderLock()) {
{ return this._speedRatio;
using (_locker?.AcquireReaderLock())
{
return _speedRatio;
} }
} }
set set {
{ using(this._locker?.AcquireWriterLock()) {
using (_locker?.AcquireWriterLock()) if(value < 0d) {
{ value = 0d;
if (value < 0d) value = 0d; }
// Capture the initial position se we set it even after the Speed Ratio has changed // Capture the initial position se we set it even after the Speed Ratio has changed
// this ensures a smooth position transition // this ensures a smooth position transition
var initialPosition = Position; TimeSpan initialPosition = this.Position;
_speedRatio = value; this._speedRatio = value;
Update(initialPosition); this.Update(initialPosition);
} }
} }
} }
@ -84,37 +70,36 @@
/// Sets a new position value atomically. /// Sets a new position value atomically.
/// </summary> /// </summary>
/// <param name="value">The new value that the position property will hold.</param> /// <param name="value">The new value that the position property will hold.</param>
public void Update(TimeSpan value) public void Update(TimeSpan value) {
{ using(this._locker?.AcquireWriterLock()) {
using (_locker?.AcquireWriterLock()) Boolean resume = this._chrono.IsRunning;
{ this._chrono.Reset();
var resume = _chrono.IsRunning; this._offsetTicks = value.Ticks;
_chrono.Reset(); if(resume) {
_offsetTicks = value.Ticks; this._chrono.Start();
if (resume) _chrono.Start(); }
} }
} }
/// <summary> /// <summary>
/// Starts or resumes the clock. /// Starts or resumes the clock.
/// </summary> /// </summary>
public void Play() public void Play() {
{ using(this._locker?.AcquireWriterLock()) {
using (_locker?.AcquireWriterLock()) if(this._chrono.IsRunning) {
{ return;
if (_chrono.IsRunning) return; }
_chrono.Start();
this._chrono.Start();
} }
} }
/// <summary> /// <summary>
/// Pauses the clock. /// Pauses the clock.
/// </summary> /// </summary>
public void Pause() public void Pause() {
{ using(this._locker?.AcquireWriterLock()) {
using (_locker?.AcquireWriterLock()) this._chrono.Stop();
{
_chrono.Stop();
} }
} }
@ -122,22 +107,22 @@
/// Sets the clock position to 0 and stops it. /// Sets the clock position to 0 and stops it.
/// The speed ratio is not modified. /// The speed ratio is not modified.
/// </summary> /// </summary>
public void Reset() public void Reset() {
{ using(this._locker?.AcquireWriterLock()) {
using (_locker?.AcquireWriterLock()) this._offsetTicks = 0;
{ this._chrono.Reset();
_offsetTicks = 0;
_chrono.Reset();
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ if(this._isDisposed) {
if (_isDisposed) return; return;
_isDisposed = true; }
_locker?.Dispose();
_locker = null; this._isDisposed = true;
this._locker?.Dispose();
this._locker = null;
} }
} }
} }

View File

@ -1,15 +1,13 @@
namespace Swan using System;
{
using System;
using System.IO; using System.IO;
using System.Net.Mail; using System.Net.Mail;
using System.Reflection; using System.Reflection;
namespace Swan {
/// <summary> /// <summary>
/// Extension methods. /// Extension methods.
/// </summary> /// </summary>
public static class SmtpExtensions public static class SmtpExtensions {
{
private static readonly BindingFlags PrivateInstanceFlags = BindingFlags.Instance | BindingFlags.NonPublic; private static readonly BindingFlags PrivateInstanceFlags = BindingFlags.Instance | BindingFlags.NonPublic;
/// <summary> /// <summary>
@ -17,40 +15,29 @@
/// </summary> /// </summary>
/// <param name="this">The caller.</param> /// <param name="this">The caller.</param>
/// <returns>A MemoryStream with the raw contents of this MailMessage.</returns> /// <returns>A MemoryStream with the raw contents of this MailMessage.</returns>
public static MemoryStream ToMimeMessage(this MailMessage @this) public static MemoryStream ToMimeMessage(this MailMessage @this) {
{ if(@this == null) {
if (@this == null)
throw new ArgumentNullException(nameof(@this)); throw new ArgumentNullException(nameof(@this));
}
var result = new MemoryStream(); MemoryStream result = new MemoryStream();
var mailWriter = MimeMessageConstants.MailWriterConstructor.Invoke(new object[] { result }); Object mailWriter = MimeMessageConstants.MailWriterConstructor.Invoke(new Object[] { result });
MimeMessageConstants.SendMethod.Invoke( _ = MimeMessageConstants.SendMethod.Invoke(@this, PrivateInstanceFlags, null, MimeMessageConstants.IsRunningInDotNetFourPointFive ? new[] { mailWriter, true, true } : new[] { mailWriter, true }, null);
@this,
PrivateInstanceFlags,
null,
MimeMessageConstants.IsRunningInDotNetFourPointFive ? new[] { mailWriter, true, true } : new[] { mailWriter, true },
null);
result = new MemoryStream(result.ToArray()); result = new MemoryStream(result.ToArray());
MimeMessageConstants.CloseMethod.Invoke( _ = MimeMessageConstants.CloseMethod.Invoke(mailWriter, PrivateInstanceFlags, null, Array.Empty<Object>(), null);
mailWriter,
PrivateInstanceFlags,
null,
Array.Empty<object>(),
null);
result.Position = 0; result.Position = 0;
return result; return result;
} }
internal static class MimeMessageConstants internal static class MimeMessageConstants {
{
#pragma warning disable DE0005 // API is deprecated #pragma warning disable DE0005 // API is deprecated
public static readonly Type MailWriter = typeof(SmtpClient).Assembly.GetType("System.Net.Mail.MailWriter"); public static readonly Type MailWriter = typeof(SmtpClient).Assembly.GetType("System.Net.Mail.MailWriter");
#pragma warning restore DE0005 // API is deprecated #pragma warning restore DE0005 // API is deprecated
public static readonly ConstructorInfo MailWriterConstructor = MailWriter.GetConstructor(PrivateInstanceFlags, null, new[] { typeof(Stream) }, null); public static readonly ConstructorInfo MailWriterConstructor = MailWriter.GetConstructor(PrivateInstanceFlags, null, new[] { typeof(Stream) }, null);
public static readonly MethodInfo CloseMethod = MailWriter.GetMethod("Close", PrivateInstanceFlags); public static readonly MethodInfo CloseMethod = MailWriter.GetMethod("Close", PrivateInstanceFlags);
public static readonly MethodInfo SendMethod = typeof(MailMessage).GetMethod("Send", PrivateInstanceFlags); public static readonly MethodInfo SendMethod = typeof(MailMessage).GetMethod("Send", PrivateInstanceFlags);
public static readonly bool IsRunningInDotNetFourPointFive = SendMethod.GetParameters().Length == 3; public static readonly Boolean IsRunningInDotNetFourPointFive = SendMethod.GetParameters().Length == 3;
} }
} }
} }

View File

@ -1,15 +1,13 @@
namespace Swan using System;
{
using System;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
namespace Swan {
/// <summary> /// <summary>
/// Provides various extension methods for networking-related tasks. /// Provides various extension methods for networking-related tasks.
/// </summary> /// </summary>
public static class NetworkExtensions public static class NetworkExtensions {
{
/// <summary> /// <summary>
/// Determines whether the IP address is private. /// Determines whether the IP address is private.
/// </summary> /// </summary>
@ -18,15 +16,15 @@
/// True if the IP Address is private; otherwise, false. /// True if the IP Address is private; otherwise, false.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">address.</exception> /// <exception cref="ArgumentNullException">address.</exception>
public static bool IsPrivateAddress(this IPAddress @this) public static Boolean IsPrivateAddress(this IPAddress @this) {
{ if(@this == null) {
if (@this == null)
throw new ArgumentNullException(nameof(@this)); throw new ArgumentNullException(nameof(@this));
}
var octets = @this.ToString().Split(new[] { "." }, StringSplitOptions.RemoveEmptyEntries).Select(byte.Parse).ToArray(); Byte[] octets = @this.ToString().Split(new[] { "." }, StringSplitOptions.RemoveEmptyEntries).Select(Byte.Parse).ToArray();
var is24Bit = octets[0] == 10; Boolean is24Bit = octets[0] == 10;
var is20Bit = octets[0] == 172 && (octets[1] >= 16 && octets[1] <= 31); Boolean is20Bit = octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31;
var is16Bit = octets[0] == 192 && octets[1] == 168; Boolean is16Bit = octets[0] == 192 && octets[1] == 168;
return is24Bit || is20Bit || is16Bit; return is24Bit || is20Bit || is16Bit;
} }
@ -40,17 +38,19 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">address.</exception> /// <exception cref="ArgumentNullException">address.</exception>
/// <exception cref="ArgumentException">InterNetwork - address.</exception> /// <exception cref="ArgumentException">InterNetwork - address.</exception>
public static uint ToUInt32(this IPAddress @this) public static UInt32 ToUInt32(this IPAddress @this) {
{ if(@this == null) {
if (@this == null)
throw new ArgumentNullException(nameof(@this)); throw new ArgumentNullException(nameof(@this));
}
if (@this.AddressFamily != AddressFamily.InterNetwork) if(@this.AddressFamily != AddressFamily.InterNetwork) {
throw new ArgumentException($"Address has to be of family '{nameof(AddressFamily.InterNetwork)}'", nameof(@this)); throw new ArgumentException($"Address has to be of family '{nameof(AddressFamily.InterNetwork)}'", nameof(@this));
}
var addressBytes = @this.GetAddressBytes(); Byte[] addressBytes = @this.GetAddressBytes();
if (BitConverter.IsLittleEndian) if(BitConverter.IsLittleEndian) {
Array.Reverse(addressBytes); Array.Reverse(addressBytes);
}
return BitConverter.ToUInt32(addressBytes, 0); return BitConverter.ToUInt32(addressBytes, 0);
} }

View File

@ -1,21 +1,16 @@
namespace Swan using Swan.Logging;
{
using Logging;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Reflection; using System.Reflection;
using System.Threading; using System.Threading;
#if NET461
using System.ServiceProcess;
#else
using Services;
#endif
using Swan.Services;
namespace Swan {
/// <summary> /// <summary>
/// Extension methods. /// Extension methods.
/// </summary> /// </summary>
public static class WindowsServicesExtensions public static class WindowsServicesExtensions {
{
/// <summary> /// <summary>
/// Runs a service in console mode. /// Runs a service in console mode.
/// </summary> /// </summary>
@ -23,10 +18,10 @@
/// <param name="loggerSource">The logger source.</param> /// <param name="loggerSource">The logger source.</param>
/// <exception cref="ArgumentNullException">this.</exception> /// <exception cref="ArgumentNullException">this.</exception>
[Obsolete("This extension method will be removed in version 3.0")] [Obsolete("This extension method will be removed in version 3.0")]
public static void RunInConsoleMode(this ServiceBase @this, string loggerSource = null) public static void RunInConsoleMode(this ServiceBase @this, String loggerSource = null) {
{ if(@this == null) {
if (@this == null)
throw new ArgumentNullException(nameof(@this)); throw new ArgumentNullException(nameof(@this));
}
RunInConsoleMode(new[] { @this }, loggerSource); RunInConsoleMode(new[] { @this }, loggerSource);
} }
@ -39,30 +34,27 @@
/// <exception cref="ArgumentNullException">this.</exception> /// <exception cref="ArgumentNullException">this.</exception>
/// <exception cref="InvalidOperationException">The ServiceBase class isn't available.</exception> /// <exception cref="InvalidOperationException">The ServiceBase class isn't available.</exception>
[Obsolete("This extension method will be removed in version 3.0")] [Obsolete("This extension method will be removed in version 3.0")]
public static void RunInConsoleMode(this ServiceBase[] @this, string loggerSource = null) public static void RunInConsoleMode(this ServiceBase[] @this, String loggerSource = null) {
{ if(@this == null) {
if (@this == null)
throw new ArgumentNullException(nameof(@this)); throw new ArgumentNullException(nameof(@this));
}
const string onStartMethodName = "OnStart"; const String onStartMethodName = "OnStart";
const string onStopMethodName = "OnStop"; const String onStopMethodName = "OnStop";
var onStartMethod = typeof(ServiceBase).GetMethod(onStartMethodName, MethodInfo onStartMethod = typeof(ServiceBase).GetMethod(onStartMethodName, BindingFlags.Instance | BindingFlags.NonPublic);
BindingFlags.Instance | BindingFlags.NonPublic); MethodInfo onStopMethod = typeof(ServiceBase).GetMethod(onStopMethodName, BindingFlags.Instance | BindingFlags.NonPublic);
var onStopMethod = typeof(ServiceBase).GetMethod(onStopMethodName,
BindingFlags.Instance | BindingFlags.NonPublic);
if (onStartMethod == null || onStopMethod == null) if(onStartMethod == null || onStopMethod == null) {
throw new InvalidOperationException("The ServiceBase class isn't available."); throw new InvalidOperationException("The ServiceBase class isn't available.");
}
var serviceThreads = new List<Thread>(); List<Thread> serviceThreads = new List<Thread>();
"Starting services . . .".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); "Starting services . . .".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name);
foreach (var service in @this) foreach(ServiceBase service in @this) {
{ Thread thread = new Thread(() => {
var thread = new Thread(() => _ = onStartMethod.Invoke(service, new Object[] { Array.Empty<String>() });
{
onStartMethod.Invoke(service, new object[] { Array.Empty<string>() });
$"Started service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name); $"Started service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name);
}); });
@ -71,17 +63,17 @@
} }
"Press any key to stop all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); "Press any key to stop all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name);
Terminal.ReadKey(true, true); _ = Terminal.ReadKey(true, true);
"Stopping services . . .".Info(SwanRuntime.EntryAssemblyName.Name); "Stopping services . . .".Info(SwanRuntime.EntryAssemblyName.Name);
foreach (var service in @this) foreach(ServiceBase service in @this) {
{ _ = onStopMethod.Invoke(service, null);
onStopMethod.Invoke(service, null);
$"Stopped service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name); $"Stopped service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name);
} }
foreach (var thread in serviceThreads) foreach(Thread thread in serviceThreads) {
thread.Join(); thread.Join();
}
"Stopped all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); "Stopped all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name);
} }

View File

@ -1,13 +1,15 @@
namespace Swan.Messaging using System;
{
namespace Swan.Messaging {
/// <summary> /// <summary>
/// A Message to be published/delivered by Messenger. /// A Message to be published/delivered by Messenger.
/// </summary> /// </summary>
public interface IMessageHubMessage public interface IMessageHubMessage {
{
/// <summary> /// <summary>
/// The sender of the message, or null if not supported by the message implementation. /// The sender of the message, or null if not supported by the message implementation.
/// </summary> /// </summary>
object Sender { get; } Object Sender {
get;
}
} }
} }

View File

@ -1,21 +1,23 @@
namespace Swan.Messaging using System;
{
namespace Swan.Messaging {
/// <summary> /// <summary>
/// Represents a message subscription. /// Represents a message subscription.
/// </summary> /// </summary>
public interface IMessageHubSubscription public interface IMessageHubSubscription {
{
/// <summary> /// <summary>
/// Token returned to the subscribed to reference this subscription. /// Token returned to the subscribed to reference this subscription.
/// </summary> /// </summary>
MessageHubSubscriptionToken SubscriptionToken { get; } MessageHubSubscriptionToken SubscriptionToken {
get;
}
/// <summary> /// <summary>
/// Whether delivery should be attempted. /// Whether delivery should be attempted.
/// </summary> /// </summary>
/// <param name="message">Message that may potentially be delivered.</param> /// <param name="message">Message that may potentially be delivered.</param>
/// <returns><c>true</c> - ok to send, <c>false</c> - should not attempt to send.</returns> /// <returns><c>true</c> - ok to send, <c>false</c> - should not attempt to send.</returns>
bool ShouldAttemptDelivery(IMessageHubMessage message); Boolean ShouldAttemptDelivery(IMessageHubMessage message);
/// <summary> /// <summary>
/// Deliver the message. /// Deliver the message.

View File

@ -11,14 +11,13 @@
// LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY AND // LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
// FITNESS FOR A PARTICULAR PURPOSE. // FITNESS FOR A PARTICULAR PURPOSE.
// =============================================================================== // ===============================================================================
#nullable enable
namespace Swan.Messaging
{
using System.Threading.Tasks; using System.Threading.Tasks;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
namespace Swan.Messaging {
#region Message Types / Interfaces #region Message Types / Interfaces
/// <summary> /// <summary>
@ -27,8 +26,7 @@ namespace Swan.Messaging
/// A message proxy can be used to intercept/alter messages and/or /// A message proxy can be used to intercept/alter messages and/or
/// marshal delivery actions onto a particular thread. /// marshal delivery actions onto a particular thread.
/// </summary> /// </summary>
public interface IMessageHubProxy public interface IMessageHubProxy {
{
/// <summary> /// <summary>
/// Delivers the specified message. /// Delivers the specified message.
/// </summary> /// </summary>
@ -42,10 +40,8 @@ namespace Swan.Messaging
/// ///
/// Does nothing other than deliver the message. /// Does nothing other than deliver the message.
/// </summary> /// </summary>
public sealed class MessageHubDefaultProxy : IMessageHubProxy public sealed class MessageHubDefaultProxy : IMessageHubProxy {
{ private MessageHubDefaultProxy() {
private MessageHubDefaultProxy()
{
// placeholder // placeholder
} }
@ -59,8 +55,7 @@ namespace Swan.Messaging
/// </summary> /// </summary>
/// <param name="message">The message.</param> /// <param name="message">The message.</param>
/// <param name="subscription">The subscription.</param> /// <param name="subscription">The subscription.</param>
public void Deliver(IMessageHubMessage message, IMessageHubSubscription subscription) public void Deliver(IMessageHubMessage message, IMessageHubSubscription subscription) => subscription.Deliver(message);
=> subscription.Deliver(message);
} }
#endregion #endregion
@ -70,8 +65,7 @@ namespace Swan.Messaging
/// <summary> /// <summary>
/// Messenger hub responsible for taking subscriptions/publications and delivering of messages. /// Messenger hub responsible for taking subscriptions/publications and delivering of messages.
/// </summary> /// </summary>
public interface IMessageHub public interface IMessageHub {
{
/// <summary> /// <summary>
/// Subscribe to a message type with the given destination and delivery action. /// Subscribe to a message type with the given destination and delivery action.
/// Messages will be delivered via the specified proxy. /// Messages will be delivered via the specified proxy.
@ -83,11 +77,7 @@ namespace Swan.Messaging
/// <param name="useStrongReferences">Use strong references to destination and deliveryAction.</param> /// <param name="useStrongReferences">Use strong references to destination and deliveryAction.</param>
/// <param name="proxy">Proxy to use when delivering the messages.</param> /// <param name="proxy">Proxy to use when delivering the messages.</param>
/// <returns>MessageSubscription used to unsubscribing.</returns> /// <returns>MessageSubscription used to unsubscribing.</returns>
MessageHubSubscriptionToken Subscribe<TMessage>( MessageHubSubscriptionToken Subscribe<TMessage>(Action<TMessage> deliveryAction, Boolean useStrongReferences, IMessageHubProxy proxy) where TMessage : class, IMessageHubMessage;
Action<TMessage> deliveryAction,
bool useStrongReferences,
IMessageHubProxy proxy)
where TMessage : class, IMessageHubMessage;
/// <summary> /// <summary>
/// Subscribe to a message type with the given destination and delivery action with the given filter. /// Subscribe to a message type with the given destination and delivery action with the given filter.
@ -103,12 +93,7 @@ namespace Swan.Messaging
/// <returns> /// <returns>
/// MessageSubscription used to unsubscribing. /// MessageSubscription used to unsubscribing.
/// </returns> /// </returns>
MessageHubSubscriptionToken Subscribe<TMessage>( MessageHubSubscriptionToken Subscribe<TMessage>(Action<TMessage> deliveryAction, Func<TMessage, Boolean> messageFilter, Boolean useStrongReferences, IMessageHubProxy proxy) where TMessage : class, IMessageHubMessage;
Action<TMessage> deliveryAction,
Func<TMessage, bool> messageFilter,
bool useStrongReferences,
IMessageHubProxy proxy)
where TMessage : class, IMessageHubMessage;
/// <summary> /// <summary>
/// Unsubscribe from a particular message type. /// Unsubscribe from a particular message type.
@ -117,16 +102,14 @@ namespace Swan.Messaging
/// </summary> /// </summary>
/// <typeparam name="TMessage">Type of message.</typeparam> /// <typeparam name="TMessage">Type of message.</typeparam>
/// <param name="subscriptionToken">Subscription token received from Subscribe.</param> /// <param name="subscriptionToken">Subscription token received from Subscribe.</param>
void Unsubscribe<TMessage>(MessageHubSubscriptionToken subscriptionToken) void Unsubscribe<TMessage>(MessageHubSubscriptionToken subscriptionToken) where TMessage : class, IMessageHubMessage;
where TMessage : class, IMessageHubMessage;
/// <summary> /// <summary>
/// Publish a message to any subscribers. /// Publish a message to any subscribers.
/// </summary> /// </summary>
/// <typeparam name="TMessage">Type of message.</typeparam> /// <typeparam name="TMessage">Type of message.</typeparam>
/// <param name="message">Message to deliver.</param> /// <param name="message">Message to deliver.</param>
void Publish<TMessage>(TMessage message) void Publish<TMessage>(TMessage message) where TMessage : class, IMessageHubMessage;
where TMessage : class, IMessageHubMessage;
/// <summary> /// <summary>
/// Publish a message to any subscribers asynchronously. /// Publish a message to any subscribers asynchronously.
@ -134,8 +117,7 @@ namespace Swan.Messaging
/// <typeparam name="TMessage">Type of message.</typeparam> /// <typeparam name="TMessage">Type of message.</typeparam>
/// <param name="message">Message to deliver.</param> /// <param name="message">Message to deliver.</param>
/// <returns>A task from Publish action.</returns> /// <returns>A task from Publish action.</returns>
Task PublishAsync<TMessage>(TMessage message) Task PublishAsync<TMessage>(TMessage message) where TMessage : class, IMessageHubMessage;
where TMessage : class, IMessageHubMessage;
} }
#endregion #endregion
@ -182,18 +164,14 @@ namespace Swan.Messaging
/// } /// }
/// </code> /// </code>
/// </example> /// </example>
public sealed class MessageHub : IMessageHub public sealed class MessageHub : IMessageHub {
{
#region Private Types and Interfaces #region Private Types and Interfaces
private readonly object _subscriptionsPadlock = new object(); private readonly Object _subscriptionsPadlock = new Object();
private readonly Dictionary<Type, List<SubscriptionItem>> _subscriptions = private readonly Dictionary<Type, List<SubscriptionItem>> _subscriptions = new Dictionary<Type, List<SubscriptionItem>>();
new Dictionary<Type, List<SubscriptionItem>>();
private class WeakMessageSubscription<TMessage> : IMessageHubSubscription private class WeakMessageSubscription<TMessage> : IMessageHubSubscription where TMessage : class, IMessageHubMessage {
where TMessage : class, IMessageHubMessage
{
private readonly WeakReference _deliveryAction; private readonly WeakReference _deliveryAction;
private readonly WeakReference _messageFilter; private readonly WeakReference _messageFilter;
@ -208,38 +186,28 @@ namespace Swan.Messaging
/// deliveryAction /// deliveryAction
/// or /// or
/// messageFilter.</exception> /// messageFilter.</exception>
public WeakMessageSubscription( public WeakMessageSubscription(MessageHubSubscriptionToken subscriptionToken, Action<TMessage> deliveryAction, Func<TMessage, Boolean> messageFilter) {
MessageHubSubscriptionToken subscriptionToken, this.SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken));
Action<TMessage> deliveryAction, this._deliveryAction = new WeakReference(deliveryAction);
Func<TMessage, bool> messageFilter) this._messageFilter = new WeakReference(messageFilter);
{
SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken));
_deliveryAction = new WeakReference(deliveryAction);
_messageFilter = new WeakReference(messageFilter);
} }
public MessageHubSubscriptionToken SubscriptionToken { get; } public MessageHubSubscriptionToken SubscriptionToken {
get;
public bool ShouldAttemptDelivery(IMessageHubMessage message)
{
return _deliveryAction.IsAlive && _messageFilter.IsAlive &&
((Func<TMessage, bool>) _messageFilter.Target).Invoke((TMessage) message);
} }
public void Deliver(IMessageHubMessage message) public Boolean ShouldAttemptDelivery(IMessageHubMessage message) => this._deliveryAction.IsAlive && this._messageFilter.IsAlive && ((Func<TMessage, Boolean>)this._messageFilter.Target!).Invoke((TMessage)message);
{
if (_deliveryAction.IsAlive) public void Deliver(IMessageHubMessage message) {
{ if(this._deliveryAction.IsAlive) {
((Action<TMessage>) _deliveryAction.Target).Invoke((TMessage) message); ((Action<TMessage>)this._deliveryAction.Target!).Invoke((TMessage)message);
} }
} }
} }
private class StrongMessageSubscription<TMessage> : IMessageHubSubscription private class StrongMessageSubscription<TMessage> : IMessageHubSubscription where TMessage : class, IMessageHubMessage {
where TMessage : class, IMessageHubMessage
{
private readonly Action<TMessage> _deliveryAction; private readonly Action<TMessage> _deliveryAction;
private readonly Func<TMessage, bool> _messageFilter; private readonly Func<TMessage, Boolean> _messageFilter;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="StrongMessageSubscription{TMessage}" /> class. /// Initializes a new instance of the <see cref="StrongMessageSubscription{TMessage}" /> class.
@ -252,37 +220,37 @@ namespace Swan.Messaging
/// deliveryAction /// deliveryAction
/// or /// or
/// messageFilter.</exception> /// messageFilter.</exception>
public StrongMessageSubscription( public StrongMessageSubscription(MessageHubSubscriptionToken subscriptionToken, Action<TMessage> deliveryAction, Func<TMessage, Boolean> messageFilter) {
MessageHubSubscriptionToken subscriptionToken, this.SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken));
Action<TMessage> deliveryAction, this._deliveryAction = deliveryAction;
Func<TMessage, bool> messageFilter) this._messageFilter = messageFilter;
{
SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken));
_deliveryAction = deliveryAction;
_messageFilter = messageFilter;
} }
public MessageHubSubscriptionToken SubscriptionToken { get; } public MessageHubSubscriptionToken SubscriptionToken {
get;
}
public bool ShouldAttemptDelivery(IMessageHubMessage message) => _messageFilter.Invoke((TMessage) message); public Boolean ShouldAttemptDelivery(IMessageHubMessage message) => this._messageFilter.Invoke((TMessage)message);
public void Deliver(IMessageHubMessage message) => _deliveryAction.Invoke((TMessage) message); public void Deliver(IMessageHubMessage message) => this._deliveryAction.Invoke((TMessage)message);
} }
#endregion #endregion
#region Subscription dictionary #region Subscription dictionary
private class SubscriptionItem private class SubscriptionItem {
{ public SubscriptionItem(IMessageHubProxy proxy, IMessageHubSubscription subscription) {
public SubscriptionItem(IMessageHubProxy proxy, IMessageHubSubscription subscription) this.Proxy = proxy;
{ this.Subscription = subscription;
Proxy = proxy;
Subscription = subscription;
} }
public IMessageHubProxy Proxy { get; } public IMessageHubProxy Proxy {
public IMessageHubSubscription Subscription { get; } get;
}
public IMessageHubSubscription Subscription {
get;
}
} }
#endregion #endregion
@ -300,14 +268,8 @@ namespace Swan.Messaging
/// <param name="useStrongReferences">Use strong references to destination and deliveryAction. </param> /// <param name="useStrongReferences">Use strong references to destination and deliveryAction. </param>
/// <param name="proxy">Proxy to use when delivering the messages.</param> /// <param name="proxy">Proxy to use when delivering the messages.</param>
/// <returns>MessageSubscription used to unsubscribing.</returns> /// <returns>MessageSubscription used to unsubscribing.</returns>
public MessageHubSubscriptionToken Subscribe<TMessage>( public MessageHubSubscriptionToken Subscribe<TMessage>(Action<TMessage> deliveryAction, Boolean useStrongReferences = true, IMessageHubProxy? proxy = null) where TMessage : class, IMessageHubMessage => this.Subscribe(deliveryAction, m => true, useStrongReferences, proxy);
Action<TMessage> deliveryAction,
bool useStrongReferences = true,
IMessageHubProxy? proxy = null)
where TMessage : class, IMessageHubMessage
{
return Subscribe(deliveryAction, m => true, useStrongReferences, proxy);
}
/// <summary> /// <summary>
/// Subscribe to a message type with the given destination and delivery action with the given filter. /// Subscribe to a message type with the given destination and delivery action with the given filter.
@ -323,44 +285,25 @@ namespace Swan.Messaging
/// <returns> /// <returns>
/// MessageSubscription used to unsubscribing. /// MessageSubscription used to unsubscribing.
/// </returns> /// </returns>
public MessageHubSubscriptionToken Subscribe<TMessage>( [System.Diagnostics.CodeAnalysis.SuppressMessage("Codequalität", "IDE0068:Empfohlenes Dispose-Muster verwenden", Justification = "<Ausstehend>")]
Action<TMessage> deliveryAction, public MessageHubSubscriptionToken Subscribe<TMessage>(Action<TMessage> deliveryAction, Func<TMessage, Boolean> messageFilter, Boolean useStrongReferences = true, IMessageHubProxy? proxy = null) where TMessage : class, IMessageHubMessage {
Func<TMessage, bool> messageFilter, if(deliveryAction == null) {
bool useStrongReferences = true,
IMessageHubProxy? proxy = null)
where TMessage : class, IMessageHubMessage
{
if (deliveryAction == null)
throw new ArgumentNullException(nameof(deliveryAction)); throw new ArgumentNullException(nameof(deliveryAction));
}
if (messageFilter == null) if(messageFilter == null) {
throw new ArgumentNullException(nameof(messageFilter)); throw new ArgumentNullException(nameof(messageFilter));
}
lock (_subscriptionsPadlock) lock(this._subscriptionsPadlock) {
{ if(!this._subscriptions.TryGetValue(typeof(TMessage), out List<SubscriptionItem>? currentSubscriptions)) {
if (!_subscriptions.TryGetValue(typeof(TMessage), out var currentSubscriptions))
{
currentSubscriptions = new List<SubscriptionItem>(); currentSubscriptions = new List<SubscriptionItem>();
_subscriptions[typeof(TMessage)] = currentSubscriptions; this._subscriptions[typeof(TMessage)] = currentSubscriptions;
} }
var subscriptionToken = new MessageHubSubscriptionToken(this, typeof(TMessage)); MessageHubSubscriptionToken subscriptionToken = new MessageHubSubscriptionToken(this, typeof(TMessage));
IMessageHubSubscription subscription; IMessageHubSubscription subscription = useStrongReferences ? new StrongMessageSubscription<TMessage>(subscriptionToken, deliveryAction, messageFilter) : (IMessageHubSubscription)new WeakMessageSubscription<TMessage>(subscriptionToken, deliveryAction, messageFilter);
if (useStrongReferences)
{
subscription = new StrongMessageSubscription<TMessage>(
subscriptionToken,
deliveryAction,
messageFilter);
}
else
{
subscription = new WeakMessageSubscription<TMessage>(
subscriptionToken,
deliveryAction,
messageFilter);
}
currentSubscriptions.Add(new SubscriptionItem(proxy ?? MessageHubDefaultProxy.Instance, subscription)); currentSubscriptions.Add(new SubscriptionItem(proxy ?? MessageHubDefaultProxy.Instance, subscription));
@ -369,20 +312,17 @@ namespace Swan.Messaging
} }
/// <inheritdoc /> /// <inheritdoc />
public void Unsubscribe<TMessage>(MessageHubSubscriptionToken subscriptionToken) public void Unsubscribe<TMessage>(MessageHubSubscriptionToken subscriptionToken) where TMessage : class, IMessageHubMessage {
where TMessage : class, IMessageHubMessage if(subscriptionToken == null) {
{
if (subscriptionToken == null)
throw new ArgumentNullException(nameof(subscriptionToken)); throw new ArgumentNullException(nameof(subscriptionToken));
}
lock (_subscriptionsPadlock) lock(this._subscriptionsPadlock) {
{ if(!this._subscriptions.TryGetValue(typeof(TMessage), out List<SubscriptionItem>? currentSubscriptions)) {
if (!_subscriptions.TryGetValue(typeof(TMessage), out var currentSubscriptions))
return; return;
}
var currentlySubscribed = currentSubscriptions List<SubscriptionItem> currentlySubscribed = currentSubscriptions.Where(sub => ReferenceEquals(sub.Subscription.SubscriptionToken, subscriptionToken)).ToList();
.Where(sub => ReferenceEquals(sub.Subscription.SubscriptionToken, subscriptionToken))
.ToList();
currentlySubscribed.ForEach(sub => currentSubscriptions.Remove(sub)); currentlySubscribed.ForEach(sub => currentSubscriptions.Remove(sub));
} }
@ -393,31 +333,24 @@ namespace Swan.Messaging
/// </summary> /// </summary>
/// <typeparam name="TMessage">Type of message.</typeparam> /// <typeparam name="TMessage">Type of message.</typeparam>
/// <param name="message">Message to deliver.</param> /// <param name="message">Message to deliver.</param>
public void Publish<TMessage>(TMessage message) public void Publish<TMessage>(TMessage message) where TMessage : class, IMessageHubMessage {
where TMessage : class, IMessageHubMessage if(message == null) {
{
if (message == null)
throw new ArgumentNullException(nameof(message)); throw new ArgumentNullException(nameof(message));
}
List<SubscriptionItem> currentlySubscribed; List<SubscriptionItem> currentlySubscribed;
lock (_subscriptionsPadlock) lock(this._subscriptionsPadlock) {
{ if(!this._subscriptions.TryGetValue(typeof(TMessage), out List<SubscriptionItem>? currentSubscriptions)) {
if (!_subscriptions.TryGetValue(typeof(TMessage), out var currentSubscriptions))
return; return;
currentlySubscribed = currentSubscriptions
.Where(sub => sub.Subscription.ShouldAttemptDelivery(message))
.ToList();
} }
currentlySubscribed.ForEach(sub => currentlySubscribed = currentSubscriptions.Where(sub => sub.Subscription.ShouldAttemptDelivery(message)).ToList();
{ }
try
{ currentlySubscribed.ForEach(sub => {
try {
sub.Proxy.Deliver(message, sub.Subscription); sub.Proxy.Deliver(message, sub.Subscription);
} } catch {
catch
{
// Ignore any errors and carry on // Ignore any errors and carry on
} }
}); });
@ -429,11 +362,7 @@ namespace Swan.Messaging
/// <typeparam name="TMessage">Type of message.</typeparam> /// <typeparam name="TMessage">Type of message.</typeparam>
/// <param name="message">Message to deliver.</param> /// <param name="message">Message to deliver.</param>
/// <returns>A task with the publish.</returns> /// <returns>A task with the publish.</returns>
public Task PublishAsync<TMessage>(TMessage message) public Task PublishAsync<TMessage>(TMessage message) where TMessage : class, IMessageHubMessage => Task.Run(() => this.Publish(message));
where TMessage : class, IMessageHubMessage
{
return Task.Run(() => Publish(message));
}
#endregion #endregion
} }

View File

@ -1,13 +1,10 @@
namespace Swan.Messaging using System;
{
using System;
namespace Swan.Messaging {
/// <summary> /// <summary>
/// Base class for messages that provides weak reference storage of the sender. /// Base class for messages that provides weak reference storage of the sender.
/// </summary> /// </summary>
public abstract class MessageHubMessageBase public abstract class MessageHubMessageBase : IMessageHubMessage {
: IMessageHubMessage
{
/// <summary> /// <summary>
/// Store a WeakReference to the sender just in case anyone is daft enough to /// Store a WeakReference to the sender just in case anyone is daft enough to
/// keep the message around and prevent the sender from being collected. /// keep the message around and prevent the sender from being collected.
@ -19,39 +16,35 @@
/// </summary> /// </summary>
/// <param name="sender">The sender.</param> /// <param name="sender">The sender.</param>
/// <exception cref="System.ArgumentNullException">sender.</exception> /// <exception cref="System.ArgumentNullException">sender.</exception>
protected MessageHubMessageBase(object sender) protected MessageHubMessageBase(Object sender) {
{ if(sender == null) {
if (sender == null)
throw new ArgumentNullException(nameof(sender)); throw new ArgumentNullException(nameof(sender));
}
_sender = new WeakReference(sender); this._sender = new WeakReference(sender);
} }
/// <inheritdoc /> /// <inheritdoc />
public object Sender => _sender.Target; public Object Sender => this._sender.Target;
} }
/// <summary> /// <summary>
/// Generic message with user specified content. /// Generic message with user specified content.
/// </summary> /// </summary>
/// <typeparam name="TContent">Content type to store.</typeparam> /// <typeparam name="TContent">Content type to store.</typeparam>
public class MessageHubGenericMessage<TContent> public class MessageHubGenericMessage<TContent> : MessageHubMessageBase {
: MessageHubMessageBase
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="MessageHubGenericMessage{TContent}"/> class. /// Initializes a new instance of the <see cref="MessageHubGenericMessage{TContent}"/> class.
/// </summary> /// </summary>
/// <param name="sender">The sender.</param> /// <param name="sender">The sender.</param>
/// <param name="content">The content.</param> /// <param name="content">The content.</param>
public MessageHubGenericMessage(object sender, TContent content) public MessageHubGenericMessage(Object sender, TContent content) : base(sender) => this.Content = content;
: base(sender)
{
Content = content;
}
/// <summary> /// <summary>
/// Contents of the message. /// Contents of the message.
/// </summary> /// </summary>
public TContent Content { get; protected set; } public TContent Content {
get; protected set;
}
} }
} }

View File

@ -1,13 +1,11 @@
namespace Swan.Messaging using System;
{ using System.Reflection;
using System;
namespace Swan.Messaging {
/// <summary> /// <summary>
/// Represents an active subscription to a message. /// Represents an active subscription to a message.
/// </summary> /// </summary>
public sealed class MessageHubSubscriptionToken public sealed class MessageHubSubscriptionToken : IDisposable {
: IDisposable
{
private readonly WeakReference _hub; private readonly WeakReference _hub;
private readonly Type _messageType; private readonly Type _messageType;
@ -18,31 +16,25 @@
/// <param name="messageType">Type of the message.</param> /// <param name="messageType">Type of the message.</param>
/// <exception cref="System.ArgumentNullException">hub.</exception> /// <exception cref="System.ArgumentNullException">hub.</exception>
/// <exception cref="System.ArgumentOutOfRangeException">messageType.</exception> /// <exception cref="System.ArgumentOutOfRangeException">messageType.</exception>
public MessageHubSubscriptionToken(IMessageHub hub, Type messageType) public MessageHubSubscriptionToken(IMessageHub hub, Type messageType) {
{ if(hub == null) {
if (hub == null)
{
throw new ArgumentNullException(nameof(hub)); throw new ArgumentNullException(nameof(hub));
} }
if (!typeof(IMessageHubMessage).IsAssignableFrom(messageType)) if(!typeof(IMessageHubMessage).IsAssignableFrom(messageType)) {
{
throw new ArgumentOutOfRangeException(nameof(messageType)); throw new ArgumentOutOfRangeException(nameof(messageType));
} }
_hub = new WeakReference(hub); this._hub = new WeakReference(hub);
_messageType = messageType; this._messageType = messageType;
} }
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ if(this._hub.IsAlive && this._hub.Target is IMessageHub hub) {
if (_hub.IsAlive && _hub.Target is IMessageHub hub) MethodInfo unsubscribeMethod = typeof(IMessageHub).GetMethod(nameof(IMessageHub.Unsubscribe), new[] { typeof(MessageHubSubscriptionToken) });
{ unsubscribeMethod = unsubscribeMethod.MakeGenericMethod(this._messageType);
var unsubscribeMethod = typeof(IMessageHub).GetMethod(nameof(IMessageHub.Unsubscribe), _ = unsubscribeMethod.Invoke(hub, new Object[] { this });
new[] {typeof(MessageHubSubscriptionToken)});
unsubscribeMethod = unsubscribeMethod.MakeGenericMethod(_messageType);
unsubscribeMethod.Invoke(hub, new object[] {this});
} }
GC.SuppressFinalize(this); GC.SuppressFinalize(this);

View File

@ -1,6 +1,5 @@
namespace Swan.Net #nullable enable
{ using Swan.Logging;
using Logging;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
@ -13,6 +12,7 @@
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Net {
/// <summary> /// <summary>
/// Represents a network connection either on the server or on the client. It wraps a TcpClient /// Represents a network connection either on the server or on the client. It wraps a TcpClient
/// and its corresponding network streams. It is capable of working in 2 modes. Typically on the server side /// and its corresponding network streams. It is capable of working in 2 modes. Typically on the server side
@ -78,31 +78,30 @@
/// } /// }
/// </code> /// </code>
/// </example> /// </example>
public sealed class Connection : IDisposable public sealed class Connection : IDisposable {
{
// New Line definitions for reading. This applies to both, events and read methods // New Line definitions for reading. This applies to both, events and read methods
private readonly string _newLineSequence; private readonly String _newLineSequence;
private readonly byte[] _newLineSequenceBytes; private readonly Byte[] _newLineSequenceBytes;
private readonly char[] _newLineSequenceChars; private readonly Char[] _newLineSequenceChars;
private readonly string[] _newLineSequenceLineSplitter; private readonly String[] _newLineSequenceLineSplitter;
private readonly byte[] _receiveBuffer; private readonly Byte[] _receiveBuffer;
private readonly TimeSpan _continuousReadingInterval = TimeSpan.FromMilliseconds(5); private readonly TimeSpan _continuousReadingInterval = TimeSpan.FromMilliseconds(5);
private readonly Queue<string> _readLineBuffer = new Queue<string>(); private readonly Queue<String> _readLineBuffer = new Queue<String>();
private readonly ManualResetEvent _writeDone = new ManualResetEvent(true); private readonly ManualResetEvent _writeDone = new ManualResetEvent(true);
// Disconnect and Dispose // Disconnect and Dispose
private bool _hasDisposed; private Boolean _hasDisposed;
private int _disconnectCalls; private Int32 _disconnectCalls;
// Continuous Reading // Continuous Reading
private Thread _continuousReadingThread; private Thread? _continuousReadingThread;
private int _receiveBufferPointer; private Int32 _receiveBufferPointer;
// Reading and writing // Reading and writing
private Task<int> _readTask; private Task<Int32>? _readTask;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="Connection"/> class. /// Initializes a new instance of the <see cref="Connection"/> class.
@ -112,57 +111,51 @@
/// <param name="newLineSequence">The new line sequence used for read and write operations.</param> /// <param name="newLineSequence">The new line sequence used for read and write operations.</param>
/// <param name="disableContinuousReading">if set to <c>true</c> [disable continuous reading].</param> /// <param name="disableContinuousReading">if set to <c>true</c> [disable continuous reading].</param>
/// <param name="blockSize">Size of the block. -- set to 0 or less to disable.</param> /// <param name="blockSize">Size of the block. -- set to 0 or less to disable.</param>
public Connection( public Connection(TcpClient client, Encoding textEncoding, String newLineSequence, Boolean disableContinuousReading, Int32 blockSize) {
TcpClient client,
Encoding textEncoding,
string newLineSequence,
bool disableContinuousReading,
int blockSize)
{
// Setup basic properties // Setup basic properties
Id = Guid.NewGuid(); this.Id = Guid.NewGuid();
TextEncoding = textEncoding; this.TextEncoding = textEncoding;
// Setup new line sequence // Setup new line sequence
if (string.IsNullOrEmpty(newLineSequence)) if(String.IsNullOrEmpty(newLineSequence)) {
throw new ArgumentException("Argument cannot be null", nameof(newLineSequence)); throw new ArgumentException("Argument cannot be null", nameof(newLineSequence));
}
_newLineSequence = newLineSequence; this._newLineSequence = newLineSequence;
_newLineSequenceBytes = TextEncoding.GetBytes(_newLineSequence); this._newLineSequenceBytes = this.TextEncoding.GetBytes(this._newLineSequence);
_newLineSequenceChars = _newLineSequence.ToCharArray(); this._newLineSequenceChars = this._newLineSequence.ToCharArray();
_newLineSequenceLineSplitter = new[] { _newLineSequence }; this._newLineSequenceLineSplitter = new[] { this._newLineSequence };
// Setup Connection timers // Setup Connection timers
ConnectionStartTimeUtc = DateTime.UtcNow; this.ConnectionStartTimeUtc = DateTime.UtcNow;
DataReceivedLastTimeUtc = ConnectionStartTimeUtc; this.DataReceivedLastTimeUtc = this.ConnectionStartTimeUtc;
DataSentLastTimeUtc = ConnectionStartTimeUtc; this.DataSentLastTimeUtc = this.ConnectionStartTimeUtc;
// Setup connection properties // Setup connection properties
RemoteClient = client; this.RemoteClient = client;
LocalEndPoint = client.Client.LocalEndPoint as IPEndPoint; this.LocalEndPoint = client.Client.LocalEndPoint as IPEndPoint;
NetworkStream = RemoteClient.GetStream(); this.NetworkStream = this.RemoteClient.GetStream();
RemoteEndPoint = RemoteClient.Client.RemoteEndPoint as IPEndPoint; this.RemoteEndPoint = this.RemoteClient.Client.RemoteEndPoint as IPEndPoint;
// Setup buffers // Setup buffers
_receiveBuffer = new byte[RemoteClient.ReceiveBufferSize * 2]; this._receiveBuffer = new Byte[this.RemoteClient.ReceiveBufferSize * 2];
ProtocolBlockSize = blockSize; this.ProtocolBlockSize = blockSize;
_receiveBufferPointer = 0; this._receiveBufferPointer = 0;
// Setup continuous reading mode if enabled // Setup continuous reading mode if enabled
if (disableContinuousReading) return; if(disableContinuousReading) {
return;
ThreadPool.GetAvailableThreads(out var availableWorkerThreads, out _);
ThreadPool.GetMaxThreads(out var maxWorkerThreads, out _);
var activeThreadPoolTreads = maxWorkerThreads - availableWorkerThreads;
if (activeThreadPoolTreads < Environment.ProcessorCount / 4)
{
ThreadPool.QueueUserWorkItem(PerformContinuousReading, this);
} }
else
{ ThreadPool.GetAvailableThreads(out Int32 availableWorkerThreads, out _);
new Thread(PerformContinuousReading) { IsBackground = true }.Start(); ThreadPool.GetMaxThreads(out Int32 maxWorkerThreads, out _);
Int32 activeThreadPoolTreads = maxWorkerThreads - availableWorkerThreads;
if(activeThreadPoolTreads < Environment.ProcessorCount / 4) {
_ = ThreadPool.QueueUserWorkItem(this.PerformContinuousReading!, this);
} else {
new Thread(this.PerformContinuousReading!) { IsBackground = true }.Start();
} }
} }
@ -171,9 +164,7 @@
/// It uses UTF8 encoding, CRLF as a new line sequence and disables a protocol block size. /// It uses UTF8 encoding, CRLF as a new line sequence and disables a protocol block size.
/// </summary> /// </summary>
/// <param name="client">The client.</param> /// <param name="client">The client.</param>
public Connection(TcpClient client) public Connection(TcpClient client) : this(client, Encoding.UTF8, "\r\n", false, 0) {
: this(client, Encoding.UTF8, "\r\n", false, 0)
{
// placeholder // placeholder
} }
@ -183,9 +174,7 @@
/// </summary> /// </summary>
/// <param name="client">The client.</param> /// <param name="client">The client.</param>
/// <param name="blockSize">Size of the block.</param> /// <param name="blockSize">Size of the block.</param>
public Connection(TcpClient client, int blockSize) public Connection(TcpClient client, Int32 blockSize) : this(client, Encoding.UTF8, new String('\n', blockSize + 1), false, blockSize) {
: this(client, Encoding.UTF8, new string('\n', blockSize + 1), false, blockSize)
{
// placeholder // placeholder
} }
@ -217,7 +206,9 @@
/// <value> /// <value>
/// The identifier. /// The identifier.
/// </value> /// </value>
public Guid Id { get; } public Guid Id {
get;
}
/// <summary> /// <summary>
/// Gets the active stream. Returns an SSL stream if the connection is secure, otherwise returns /// Gets the active stream. Returns an SSL stream if the connection is secure, otherwise returns
@ -226,7 +217,7 @@
/// <value> /// <value>
/// The active stream. /// The active stream.
/// </value> /// </value>
public Stream ActiveStream => SecureStream ?? NetworkStream as Stream; public Stream? ActiveStream => this.SecureStream ?? this.NetworkStream as Stream;
/// <summary> /// <summary>
/// Gets a value indicating whether the current connection stream is an SSL stream. /// Gets a value indicating whether the current connection stream is an SSL stream.
@ -234,7 +225,7 @@
/// <value> /// <value>
/// <c>true</c> if this instance is active stream secure; otherwise, <c>false</c>. /// <c>true</c> if this instance is active stream secure; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsActiveStreamSecure => SecureStream != null; public Boolean IsActiveStreamSecure => this.SecureStream != null;
/// <summary> /// <summary>
/// Gets the text encoding for send and receive operations. /// Gets the text encoding for send and receive operations.
@ -242,7 +233,9 @@
/// <value> /// <value>
/// The text encoding. /// The text encoding.
/// </value> /// </value>
public Encoding TextEncoding { get; } public Encoding TextEncoding {
get;
}
/// <summary> /// <summary>
/// Gets the remote end point of this TCP connection. /// Gets the remote end point of this TCP connection.
@ -250,7 +243,9 @@
/// <value> /// <value>
/// The remote end point. /// The remote end point.
/// </value> /// </value>
public IPEndPoint RemoteEndPoint { get; } public IPEndPoint? RemoteEndPoint {
get;
}
/// <summary> /// <summary>
/// Gets the local end point of this TCP connection. /// Gets the local end point of this TCP connection.
@ -258,7 +253,9 @@
/// <value> /// <value>
/// The local end point. /// The local end point.
/// </value> /// </value>
public IPEndPoint LocalEndPoint { get; } public IPEndPoint? LocalEndPoint {
get;
}
/// <summary> /// <summary>
/// Gets the remote client of this TCP connection. /// Gets the remote client of this TCP connection.
@ -266,7 +263,9 @@
/// <value> /// <value>
/// The remote client. /// The remote client.
/// </value> /// </value>
public TcpClient RemoteClient { get; private set; } public TcpClient? RemoteClient {
get; private set;
}
/// <summary> /// <summary>
/// When in continuous reading mode, and if set to greater than 0, /// When in continuous reading mode, and if set to greater than 0,
@ -276,7 +275,9 @@
/// <value> /// <value>
/// The size of the protocol block. /// The size of the protocol block.
/// </value> /// </value>
public int ProtocolBlockSize { get; } public Int32 ProtocolBlockSize {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this connection is in continuous reading mode. /// Gets a value indicating whether this connection is in continuous reading mode.
@ -288,7 +289,7 @@
/// <value> /// <value>
/// <c>true</c> if this instance is continuous reading enabled; otherwise, <c>false</c>. /// <c>true</c> if this instance is continuous reading enabled; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsContinuousReadingEnabled => _continuousReadingThread != null; public Boolean IsContinuousReadingEnabled => this._continuousReadingThread != null;
/// <summary> /// <summary>
/// Gets the start time at which the connection was started in UTC. /// Gets the start time at which the connection was started in UTC.
@ -296,7 +297,9 @@
/// <value> /// <value>
/// The connection start time UTC. /// The connection start time UTC.
/// </value> /// </value>
public DateTime ConnectionStartTimeUtc { get; } public DateTime ConnectionStartTimeUtc {
get;
}
/// <summary> /// <summary>
/// Gets the start time at which the connection was started in local time. /// Gets the start time at which the connection was started in local time.
@ -304,7 +307,7 @@
/// <value> /// <value>
/// The connection start time. /// The connection start time.
/// </value> /// </value>
public DateTime ConnectionStartTime => ConnectionStartTimeUtc.ToLocalTime(); public DateTime ConnectionStartTime => this.ConnectionStartTimeUtc.ToLocalTime();
/// <summary> /// <summary>
/// Gets the duration of the connection. /// Gets the duration of the connection.
@ -312,7 +315,7 @@
/// <value> /// <value>
/// The duration of the connection. /// The duration of the connection.
/// </value> /// </value>
public TimeSpan ConnectionDuration => DateTime.UtcNow.Subtract(ConnectionStartTimeUtc); public TimeSpan ConnectionDuration => DateTime.UtcNow.Subtract(this.ConnectionStartTimeUtc);
/// <summary> /// <summary>
/// Gets the last time data was received at in UTC. /// Gets the last time data was received at in UTC.
@ -320,12 +323,14 @@
/// <value> /// <value>
/// The data received last time UTC. /// The data received last time UTC.
/// </value> /// </value>
public DateTime DataReceivedLastTimeUtc { get; private set; } public DateTime DataReceivedLastTimeUtc {
get; private set;
}
/// <summary> /// <summary>
/// Gets how long has elapsed since data was last received. /// Gets how long has elapsed since data was last received.
/// </summary> /// </summary>
public TimeSpan DataReceivedIdleDuration => DateTime.UtcNow.Subtract(DataReceivedLastTimeUtc); public TimeSpan DataReceivedIdleDuration => DateTime.UtcNow.Subtract(this.DataReceivedLastTimeUtc);
/// <summary> /// <summary>
/// Gets the last time at which data was sent in UTC. /// Gets the last time at which data was sent in UTC.
@ -333,7 +338,9 @@
/// <value> /// <value>
/// The data sent last time UTC. /// The data sent last time UTC.
/// </value> /// </value>
public DateTime DataSentLastTimeUtc { get; private set; } public DateTime DataSentLastTimeUtc {
get; private set;
}
/// <summary> /// <summary>
/// Gets how long has elapsed since data was last sent. /// Gets how long has elapsed since data was last sent.
@ -341,7 +348,7 @@
/// <value> /// <value>
/// The duration of the data sent idle. /// The duration of the data sent idle.
/// </value> /// </value>
public TimeSpan DataSentIdleDuration => DateTime.UtcNow.Subtract(DataSentLastTimeUtc); public TimeSpan DataSentIdleDuration => DateTime.UtcNow.Subtract(this.DataSentLastTimeUtc);
/// <summary> /// <summary>
/// Gets a value indicating whether this connection is connected. /// Gets a value indicating whether this connection is connected.
@ -351,35 +358,38 @@
/// <value> /// <value>
/// <c>true</c> if this instance is connected; otherwise, <c>false</c>. /// <c>true</c> if this instance is connected; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsConnected public Boolean IsConnected {
{ get {
get if(this._disconnectCalls > 0) {
{
if (_disconnectCalls > 0)
return false; return false;
}
try try {
{ Socket? socket = this.RemoteClient?.Client;
var socket = RemoteClient.Client; if(socket == null || this.NetworkStream == null) {
var pollResult = !((socket.Poll(1000, SelectMode.SelectRead) return false;
&& (NetworkStream.DataAvailable == false)) || !socket.Connected); }
Boolean pollResult = !(socket.Poll(1000, SelectMode.SelectRead) && this.NetworkStream.DataAvailable == false || !socket.Connected);
if (pollResult == false) if(pollResult == false) {
Disconnect(); this.Disconnect();
}
return pollResult; return pollResult;
} } catch {
catch this.Disconnect();
{
Disconnect();
return false; return false;
} }
} }
} }
private NetworkStream NetworkStream { get; set; } private NetworkStream? NetworkStream {
get; set;
}
private SslStream SecureStream { get; set; } private SslStream? SecureStream {
get; set;
}
#endregion #endregion
@ -393,58 +403,45 @@
/// <returns>A byte array containing the results of encoding the specified set of characters.</returns> /// <returns>A byte array containing the results of encoding the specified set of characters.</returns>
/// <exception cref="InvalidOperationException">Read methods have been disabled because continuous reading is enabled.</exception> /// <exception cref="InvalidOperationException">Read methods have been disabled because continuous reading is enabled.</exception>
/// <exception cref="TimeoutException">Reading data from {ActiveStream} timed out in {timeout.TotalMilliseconds} m.</exception> /// <exception cref="TimeoutException">Reading data from {ActiveStream} timed out in {timeout.TotalMilliseconds} m.</exception>
public async Task<byte[]> ReadDataAsync(TimeSpan timeout, CancellationToken cancellationToken = default) public async Task<Byte[]> ReadDataAsync(TimeSpan timeout, CancellationToken cancellationToken = default) {
{ if(this.IsContinuousReadingEnabled) {
if (IsContinuousReadingEnabled) throw new InvalidOperationException("Read methods have been disabled because continuous reading is enabled.");
{
throw new InvalidOperationException(
"Read methods have been disabled because continuous reading is enabled.");
} }
if (RemoteClient == null) if(this.RemoteClient == null) {
{
throw new InvalidOperationException("An open connection is required"); throw new InvalidOperationException("An open connection is required");
} }
var receiveBuffer = new byte[RemoteClient.ReceiveBufferSize * 2]; Byte[] receiveBuffer = new Byte[this.RemoteClient.ReceiveBufferSize * 2];
var receiveBuilder = new List<byte>(receiveBuffer.Length); List<Byte> receiveBuilder = new List<Byte>(receiveBuffer.Length);
try try {
{ DateTime startTime = DateTime.UtcNow;
var startTime = DateTime.UtcNow;
while (receiveBuilder.Count <= 0) while(receiveBuilder.Count <= 0) {
{ if(DateTime.UtcNow.Subtract(startTime) >= timeout) {
if (DateTime.UtcNow.Subtract(startTime) >= timeout) throw new TimeoutException($"Reading data from {this.ActiveStream} timed out in {timeout.TotalMilliseconds} ms");
{
throw new TimeoutException(
$"Reading data from {ActiveStream} timed out in {timeout.TotalMilliseconds} ms");
} }
if (_readTask == null) if(this._readTask == null) {
_readTask = ActiveStream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length, cancellationToken); this._readTask = this.ActiveStream?.ReadAsync(receiveBuffer, 0, receiveBuffer.Length, cancellationToken);
}
if (_readTask.Wait(_continuousReadingInterval)) if(this._readTask != null && this._readTask.Wait(this._continuousReadingInterval)) {
{ Int32 bytesReceivedCount = this._readTask.Result;
var bytesReceivedCount = _readTask.Result; if(bytesReceivedCount > 0) {
if (bytesReceivedCount > 0) this.DataReceivedLastTimeUtc = DateTime.UtcNow;
{ Byte[] buffer = new Byte[bytesReceivedCount];
DataReceivedLastTimeUtc = DateTime.UtcNow;
var buffer = new byte[bytesReceivedCount];
Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount); Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount);
receiveBuilder.AddRange(buffer); receiveBuilder.AddRange(buffer);
} }
_readTask = null; this._readTask = null;
} } else {
else await Task.Delay(this._continuousReadingInterval, cancellationToken).ConfigureAwait(false);
{
await Task.Delay(_continuousReadingInterval, cancellationToken).ConfigureAwait(false);
} }
} }
} } catch(Exception ex) {
catch (Exception ex)
{
ex.Error(typeof(Connection).FullName, "Error while reading network stream data asynchronously."); ex.Error(typeof(Connection).FullName, "Error while reading network stream data asynchronously.");
throw; throw;
} }
@ -459,8 +456,7 @@
/// <returns> /// <returns>
/// A byte array containing the results the specified sequence of bytes. /// A byte array containing the results the specified sequence of bytes.
/// </returns> /// </returns>
public Task<byte[]> ReadDataAsync(CancellationToken cancellationToken = default) public Task<Byte[]> ReadDataAsync(CancellationToken cancellationToken = default) => this.ReadDataAsync(TimeSpan.FromSeconds(5), cancellationToken);
=> ReadDataAsync(TimeSpan.FromSeconds(5), cancellationToken);
/// <summary> /// <summary>
/// Asynchronously reads data as text with the given timeout. /// Asynchronously reads data as text with the given timeout.
@ -470,10 +466,9 @@
/// <returns> /// <returns>
/// A <see cref="System.String" /> that contains the results of decoding the specified sequence of bytes. /// A <see cref="System.String" /> that contains the results of decoding the specified sequence of bytes.
/// </returns> /// </returns>
public async Task<string?> ReadTextAsync(TimeSpan timeout, CancellationToken cancellationToken = default) public async Task<String?> ReadTextAsync(TimeSpan timeout, CancellationToken cancellationToken = default) {
{ Byte[] buffer = await this.ReadDataAsync(timeout, cancellationToken).ConfigureAwait(false);
var buffer = await ReadDataAsync(timeout, cancellationToken).ConfigureAwait(false); return buffer == null ? null : this.TextEncoding.GetString(buffer);
return buffer == null ? null : TextEncoding.GetString(buffer);
} }
/// <summary> /// <summary>
@ -483,8 +478,7 @@
/// <returns> /// <returns>
/// When this method completes successfully, it returns the contents of the file as a text string. /// When this method completes successfully, it returns the contents of the file as a text string.
/// </returns> /// </returns>
public Task<string?> ReadTextAsync(CancellationToken cancellationToken = default) public Task<String?> ReadTextAsync(CancellationToken cancellationToken = default) => this.ReadTextAsync(TimeSpan.FromSeconds(5), cancellationToken);
=> ReadTextAsync(TimeSpan.FromSeconds(5), cancellationToken);
/// <summary> /// <summary>
/// Performs the same task as this method's overload but it defaults to a read timeout of 30 seconds. /// Performs the same task as this method's overload but it defaults to a read timeout of 30 seconds.
@ -494,8 +488,7 @@
/// A task that represents the asynchronous read operation. The value of the TResult parameter /// A task that represents the asynchronous read operation. The value of the TResult parameter
/// contains the next line from the stream, or is null if all the characters have been read. /// contains the next line from the stream, or is null if all the characters have been read.
/// </returns> /// </returns>
public Task<string?> ReadLineAsync(CancellationToken cancellationToken = default) public Task<String?> ReadLineAsync(CancellationToken cancellationToken = default) => this.ReadLineAsync(TimeSpan.FromSeconds(30), cancellationToken);
=> ReadLineAsync(TimeSpan.FromSeconds(30), cancellationToken);
/// <summary> /// <summary>
/// Reads the next available line of text in queue. Return null when no text is read. /// Reads the next available line of text in queue. Return null when no text is read.
@ -508,39 +501,39 @@
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task with a string line from the queue.</returns> /// <returns>A task with a string line from the queue.</returns>
/// <exception cref="InvalidOperationException">Read methods have been disabled because continuous reading is enabled.</exception> /// <exception cref="InvalidOperationException">Read methods have been disabled because continuous reading is enabled.</exception>
public async Task<string?> ReadLineAsync(TimeSpan timeout, CancellationToken cancellationToken = default) public async Task<String?> ReadLineAsync(TimeSpan timeout, CancellationToken cancellationToken = default) {
{ if(this.IsContinuousReadingEnabled) {
if (IsContinuousReadingEnabled) throw new InvalidOperationException("Read methods have been disabled because continuous reading is enabled.");
{
throw new InvalidOperationException(
"Read methods have been disabled because continuous reading is enabled.");
} }
if (_readLineBuffer.Count > 0) if(this._readLineBuffer.Count > 0) {
return _readLineBuffer.Dequeue(); return this._readLineBuffer.Dequeue();
}
var builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
while (true) while(true) {
{ String? text = await this.ReadTextAsync(timeout, cancellationToken).ConfigureAwait(false);
var text = await ReadTextAsync(timeout, cancellationToken).ConfigureAwait(false);
if (string.IsNullOrEmpty(text)) if(String.IsNullOrEmpty(text)) {
break; break;
}
builder.Append(text); _ = builder.Append(text);
if (!text.EndsWith(_newLineSequence)) continue; if(!text.EndsWith(this._newLineSequence)) {
continue;
}
var lines = builder.ToString().TrimEnd(_newLineSequenceChars) String[] lines = builder.ToString().TrimEnd(this._newLineSequenceChars).Split(this._newLineSequenceLineSplitter, StringSplitOptions.None);
.Split(_newLineSequenceLineSplitter, StringSplitOptions.None); foreach(String item in lines) {
foreach (var item in lines) this._readLineBuffer.Enqueue(item);
_readLineBuffer.Enqueue(item); }
break; break;
} }
return _readLineBuffer.Count > 0 ? _readLineBuffer.Dequeue() : null; return this._readLineBuffer.Count > 0 ? this._readLineBuffer.Dequeue() : null;
} }
#endregion #endregion
@ -554,21 +547,21 @@
/// <param name="forceFlush">if set to <c>true</c> [force flush].</param> /// <param name="forceFlush">if set to <c>true</c> [force flush].</param>
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task that represents the asynchronous write operation.</returns> /// <returns>A task that represents the asynchronous write operation.</returns>
public async Task WriteDataAsync(byte[] buffer, bool forceFlush, CancellationToken cancellationToken = default) public async Task WriteDataAsync(Byte[] buffer, Boolean forceFlush, CancellationToken cancellationToken = default) {
{ try {
try _ = this._writeDone.WaitOne();
{ _ = this._writeDone.Reset();
_writeDone.WaitOne(); if(this.ActiveStream == null) {
_writeDone.Reset(); return;
await ActiveStream.WriteAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
if (forceFlush)
await ActiveStream.FlushAsync(cancellationToken).ConfigureAwait(false);
DataSentLastTimeUtc = DateTime.UtcNow;
} }
finally await this.ActiveStream.WriteAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
{ if(forceFlush) {
_writeDone.Set(); await this.ActiveStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
this.DataSentLastTimeUtc = DateTime.UtcNow;
} finally {
_ = this._writeDone.Set();
} }
} }
@ -578,8 +571,7 @@
/// <param name="text">The text.</param> /// <param name="text">The text.</param>
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task that represents the asynchronous write operation.</returns> /// <returns>A task that represents the asynchronous write operation.</returns>
public Task WriteTextAsync(string text, CancellationToken cancellationToken = default) public Task WriteTextAsync(String text, CancellationToken cancellationToken = default) => this.WriteTextAsync(text, this.TextEncoding, cancellationToken);
=> WriteTextAsync(text, TextEncoding, cancellationToken);
/// <summary> /// <summary>
/// Writes text asynchronously. /// Writes text asynchronously.
@ -588,8 +580,7 @@
/// <param name="encoding">The encoding.</param> /// <param name="encoding">The encoding.</param>
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task that represents the asynchronous write operation.</returns> /// <returns>A task that represents the asynchronous write operation.</returns>
public Task WriteTextAsync(string text, Encoding encoding, CancellationToken cancellationToken = default) public Task WriteTextAsync(String text, Encoding encoding, CancellationToken cancellationToken = default) => this.WriteDataAsync(encoding.GetBytes(text), true, cancellationToken);
=> WriteDataAsync(encoding.GetBytes(text), true, cancellationToken);
/// <summary> /// <summary>
/// Writes a line of text asynchronously. /// Writes a line of text asynchronously.
@ -599,8 +590,7 @@
/// <param name="encoding">The encoding.</param> /// <param name="encoding">The encoding.</param>
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task that represents the asynchronous write operation.</returns> /// <returns>A task that represents the asynchronous write operation.</returns>
public Task WriteLineAsync(string line, Encoding encoding, CancellationToken cancellationToken = default) public Task WriteLineAsync(String line, Encoding encoding, CancellationToken cancellationToken = default) => this.WriteDataAsync(encoding.GetBytes($"{line}{this._newLineSequence}"), true, cancellationToken);
=> WriteDataAsync(encoding.GetBytes($"{line}{_newLineSequence}"), true, cancellationToken);
/// <summary> /// <summary>
/// Writes a line of text asynchronously. /// Writes a line of text asynchronously.
@ -609,8 +599,7 @@
/// <param name="line">The line.</param> /// <param name="line">The line.</param>
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A task that represents the asynchronous write operation.</returns> /// <returns>A task that represents the asynchronous write operation.</returns>
public Task WriteLineAsync(string line, CancellationToken cancellationToken = default) public Task WriteLineAsync(String line, CancellationToken cancellationToken = default) => this.WriteLineAsync(line, this.TextEncoding, cancellationToken);
=> WriteLineAsync(line, TextEncoding, cancellationToken);
#endregion #endregion
@ -621,24 +610,21 @@
/// </summary> /// </summary>
/// <param name="serverCertificate">The server certificate.</param> /// <param name="serverCertificate">The server certificate.</param>
/// <returns><c>true</c> if the object is hosted in the server; otherwise, <c>false</c>.</returns> /// <returns><c>true</c> if the object is hosted in the server; otherwise, <c>false</c>.</returns>
public async Task<bool> UpgradeToSecureAsServerAsync(X509Certificate2 serverCertificate) public async Task<Boolean> UpgradeToSecureAsServerAsync(X509Certificate2 serverCertificate) {
{ if(this.IsActiveStreamSecure) {
if (IsActiveStreamSecure)
return true; return true;
}
_writeDone.WaitOne(); _ = this._writeDone.WaitOne();
SslStream? secureStream = null; SslStream? secureStream = null;
try try {
{ secureStream = new SslStream(this.NetworkStream, true);
secureStream = new SslStream(NetworkStream, true);
await secureStream.AuthenticateAsServerAsync(serverCertificate).ConfigureAwait(false); await secureStream.AuthenticateAsServerAsync(serverCertificate).ConfigureAwait(false);
SecureStream = secureStream; this.SecureStream = secureStream;
return true; return true;
} } catch(Exception ex) {
catch (Exception ex)
{
ConnectionFailure(this, new ConnectionFailureEventArgs(ex)); ConnectionFailure(this, new ConnectionFailureEventArgs(ex));
secureStream?.Dispose(); secureStream?.Dispose();
@ -652,24 +638,17 @@
/// <param name="hostname">The hostname.</param> /// <param name="hostname">The hostname.</param>
/// <param name="callback">The callback.</param> /// <param name="callback">The callback.</param>
/// <returns>A tasks with <c>true</c> if the upgrade to SSL was successful; otherwise, <c>false</c>.</returns> /// <returns>A tasks with <c>true</c> if the upgrade to SSL was successful; otherwise, <c>false</c>.</returns>
public async Task<bool> UpgradeToSecureAsClientAsync( public async Task<Boolean> UpgradeToSecureAsClientAsync(String? hostname = null, RemoteCertificateValidationCallback? callback = null) {
string? hostname = null, if(this.IsActiveStreamSecure) {
RemoteCertificateValidationCallback? callback = null)
{
if (IsActiveStreamSecure)
return true; return true;
var secureStream = callback == null
? new SslStream(NetworkStream, true)
: new SslStream(NetworkStream, true, callback);
try
{
await secureStream.AuthenticateAsClientAsync(hostname ?? Network.HostName.ToLowerInvariant()).ConfigureAwait(false);
SecureStream = secureStream;
} }
catch (Exception ex)
{ SslStream secureStream = callback == null ? new SslStream(this.NetworkStream, true) : new SslStream(this.NetworkStream, true, callback);
try {
await secureStream.AuthenticateAsClientAsync(hostname ?? Network.HostName.ToLowerInvariant()).ConfigureAwait(false);
this.SecureStream = secureStream;
} catch(Exception ex) {
secureStream.Dispose(); secureStream.Dispose();
ConnectionFailure(this, new ConnectionFailureEventArgs(ex)); ConnectionFailure(this, new ConnectionFailureEventArgs(ex));
return false; return false;
@ -681,41 +660,30 @@
/// <summary> /// <summary>
/// Disconnects this connection. /// Disconnects this connection.
/// </summary> /// </summary>
public void Disconnect() public void Disconnect() {
{ if(this._disconnectCalls > 0) {
if (_disconnectCalls > 0)
return; return;
_disconnectCalls++;
_writeDone.WaitOne();
try
{
ClientDisconnected(this, EventArgs.Empty);
} }
catch
{ this._disconnectCalls++;
_ = this._writeDone.WaitOne();
try {
ClientDisconnected(this, EventArgs.Empty);
} catch {
// ignore // ignore
} }
try try {
{ this.RemoteClient?.Dispose();
#if !NET461 this.SecureStream?.Dispose();
RemoteClient.Dispose(); this.NetworkStream?.Dispose();
SecureStream?.Dispose();
NetworkStream?.Dispose(); } finally {
#else this.NetworkStream = null;
RemoteClient.Close(); this.SecureStream = null;
SecureStream?.Close(); this.RemoteClient = null;
NetworkStream?.Close(); this._continuousReadingThread = null;
#endif
}
finally
{
NetworkStream = null;
SecureStream = null;
RemoteClient = null;
_continuousReadingThread = null;
} }
} }
@ -724,159 +692,141 @@
#region Dispose #region Dispose
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ if(this._hasDisposed) {
if (_hasDisposed)
return; return;
}
// Release managed resources // Release managed resources
Disconnect(); this.Disconnect();
_continuousReadingThread = null; this._continuousReadingThread = null;
_writeDone.Dispose(); this._writeDone.Dispose();
_hasDisposed = true; this._hasDisposed = true;
} }
#endregion #endregion
#region Continuous Read Methods #region Continuous Read Methods
private void RaiseReceiveBufferEvents(IEnumerable<byte> receivedData) private void RaiseReceiveBufferEvents(IEnumerable<Byte> receivedData) {
{ if(this.RemoteClient == null) {
var moreAvailable = RemoteClient.Available > 0; return;
}
Boolean moreAvailable = this.RemoteClient.Available > 0;
foreach (var data in receivedData) foreach(Byte data in receivedData) {
{ this.ProcessReceivedBlock(data, moreAvailable);
ProcessReceivedBlock(data, moreAvailable);
} }
// Check if we are left with some more stuff to handle // Check if we are left with some more stuff to handle
if (_receiveBufferPointer <= 0) if(this._receiveBufferPointer <= 0) {
return; return;
}
// Extract the segments split by newline terminated bytes // Extract the segments split by newline terminated bytes
var sequences = _receiveBuffer.Skip(0).Take(_receiveBufferPointer).ToArray() List<Byte[]> sequences = this._receiveBuffer.Skip(0).Take(this._receiveBufferPointer).ToArray().Split(0, this._newLineSequenceBytes);
.Split(0, _newLineSequenceBytes);
// Something really wrong happened // Something really wrong happened
if (sequences.Count == 0) if(sequences.Count == 0) {
throw new InvalidOperationException("Split function failed! This is terribly wrong!"); throw new InvalidOperationException("Split function failed! This is terribly wrong!");
}
// We only have one sequence and it is not newline-terminated // We only have one sequence and it is not newline-terminated
// we don't have to do anything. // we don't have to do anything.
if (sequences.Count == 1 && sequences[0].EndsWith(_newLineSequenceBytes) == false) if(sequences.Count == 1 && sequences[0].EndsWith(this._newLineSequenceBytes) == false) {
return; return;
}
// Process the events for each sequence // Process the events for each sequence
for (var i = 0; i < sequences.Count; i++) for(Int32 i = 0; i < sequences.Count; i++) {
{ Byte[] sequenceBytes = sequences[i];
var sequenceBytes = sequences[i]; Boolean isNewLineTerminated = sequences[i].EndsWith(this._newLineSequenceBytes);
var isNewLineTerminated = sequences[i].EndsWith(_newLineSequenceBytes); Boolean isLast = i == sequences.Count - 1;
var isLast = i == sequences.Count - 1;
if (isNewLineTerminated) if(isNewLineTerminated) {
{ ConnectionDataReceivedEventArgs eventArgs = new ConnectionDataReceivedEventArgs(sequenceBytes, ConnectionDataReceivedTrigger.NewLineSequenceEncountered, isLast == false);
var eventArgs = new ConnectionDataReceivedEventArgs(
sequenceBytes,
ConnectionDataReceivedTrigger.NewLineSequenceEncountered,
isLast == false);
DataReceived(this, eventArgs); DataReceived(this, eventArgs);
} }
// Depending on the last segment determine what to do with the receive buffer // Depending on the last segment determine what to do with the receive buffer
if (!isLast) continue; if(!isLast) {
continue;
if (isNewLineTerminated)
{
// Simply reset the buffer pointer if the last segment was also terminated
_receiveBufferPointer = 0;
} }
else
{ if(isNewLineTerminated) {
// Simply reset the buffer pointer if the last segment was also terminated
this._receiveBufferPointer = 0;
} else {
// If we have not received the termination sequence, then just shift the receive buffer to the left // If we have not received the termination sequence, then just shift the receive buffer to the left
// and adjust the pointer // and adjust the pointer
Array.Copy(sequenceBytes, _receiveBuffer, sequenceBytes.Length); Array.Copy(sequenceBytes, this._receiveBuffer, sequenceBytes.Length);
_receiveBufferPointer = sequenceBytes.Length; this._receiveBufferPointer = sequenceBytes.Length;
} }
} }
} }
private void ProcessReceivedBlock(byte data, bool moreAvailable) private void ProcessReceivedBlock(Byte data, Boolean moreAvailable) {
{ this._receiveBuffer[this._receiveBufferPointer] = data;
_receiveBuffer[_receiveBufferPointer] = data; this._receiveBufferPointer++;
_receiveBufferPointer++;
// Block size reached // Block size reached
if (ProtocolBlockSize > 0 && _receiveBufferPointer >= ProtocolBlockSize) if(this.ProtocolBlockSize > 0 && this._receiveBufferPointer >= this.ProtocolBlockSize) {
{ this.SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BlockSizeReached);
SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BlockSizeReached);
return; return;
} }
// The receive buffer is full. Time to flush // The receive buffer is full. Time to flush
if (_receiveBufferPointer >= _receiveBuffer.Length) if(this._receiveBufferPointer >= this._receiveBuffer.Length) {
{ this.SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BufferFull);
SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BufferFull);
} }
} }
private void SendBuffer(bool moreAvailable, ConnectionDataReceivedTrigger trigger) private void SendBuffer(Boolean moreAvailable, ConnectionDataReceivedTrigger trigger) {
{ Byte[] eventBuffer = new Byte[this._receiveBuffer.Length];
var eventBuffer = new byte[_receiveBuffer.Length]; Array.Copy(this._receiveBuffer, eventBuffer, eventBuffer.Length);
Array.Copy(_receiveBuffer, eventBuffer, eventBuffer.Length);
DataReceived(this, DataReceived(this, new ConnectionDataReceivedEventArgs(eventBuffer, trigger, moreAvailable));
new ConnectionDataReceivedEventArgs( this._receiveBufferPointer = 0;
eventBuffer,
trigger,
moreAvailable));
_receiveBufferPointer = 0;
} }
private void PerformContinuousReading(object threadContext) private void PerformContinuousReading(Object threadContext) {
{ this._continuousReadingThread = Thread.CurrentThread;
_continuousReadingThread = Thread.CurrentThread;
// Check if the RemoteClient is still there // Check if the RemoteClient is still there
if (RemoteClient == null) return; if(this.RemoteClient == null) {
return;
}
var receiveBuffer = new byte[RemoteClient.ReceiveBufferSize * 2]; Byte[] receiveBuffer = new Byte[this.RemoteClient.ReceiveBufferSize * 2];
while (IsConnected && _disconnectCalls <= 0) while(this.IsConnected && this._disconnectCalls <= 0) {
{ Boolean doThreadSleep = false;
var doThreadSleep = false;
try try {
{ if(this._readTask == null) {
if (_readTask == null) this._readTask = this.ActiveStream?.ReadAsync(receiveBuffer, 0, receiveBuffer.Length);
_readTask = ActiveStream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length); }
if (_readTask.Wait(_continuousReadingInterval)) if(this._readTask != null && this._readTask.Wait(this._continuousReadingInterval)) {
{ Int32 bytesReceivedCount = this._readTask.Result;
var bytesReceivedCount = _readTask.Result; if(bytesReceivedCount > 0) {
if (bytesReceivedCount > 0) this.DataReceivedLastTimeUtc = DateTime.UtcNow;
{ Byte[] buffer = new Byte[bytesReceivedCount];
DataReceivedLastTimeUtc = DateTime.UtcNow;
var buffer = new byte[bytesReceivedCount];
Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount); Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount);
RaiseReceiveBufferEvents(buffer); this.RaiseReceiveBufferEvents(buffer);
} }
_readTask = null; this._readTask = null;
} else {
doThreadSleep = this._disconnectCalls <= 0;
} }
else } catch(Exception ex) {
{
doThreadSleep = _disconnectCalls <= 0;
}
}
catch (Exception ex)
{
ex.Log(nameof(PerformContinuousReading), "Continuous Read operation errored"); ex.Log(nameof(PerformContinuousReading), "Continuous Read operation errored");
} finally {
if(doThreadSleep) {
Thread.Sleep(this._continuousReadingInterval);
} }
finally
{
if (doThreadSleep)
Thread.Sleep(_continuousReadingInterval);
} }
} }
} }

View File

@ -1,10 +1,8 @@
namespace Swan namespace Swan {
{
/// <summary> /// <summary>
/// Enumerates the possible causes of the DataReceived event occurring. /// Enumerates the possible causes of the DataReceived event occurring.
/// </summary> /// </summary>
public enum ConnectionDataReceivedTrigger public enum ConnectionDataReceivedTrigger {
{
/// <summary> /// <summary>
/// The trigger was a forceful flush of the buffer /// The trigger was a forceful flush of the buffer
/// </summary> /// </summary>

View File

@ -1,24 +1,24 @@
namespace Swan.Net #nullable enable
{
using System; using System;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Net {
/// <summary> /// <summary>
/// TCP Listener manager with built-in events and asynchronous functionality. /// TCP Listener manager with built-in events and asynchronous functionality.
/// This networking component is typically used when writing server software. /// This networking component is typically used when writing server software.
/// </summary> /// </summary>
/// <seealso cref="System.IDisposable" /> /// <seealso cref="System.IDisposable" />
public sealed class ConnectionListener : IDisposable public sealed class ConnectionListener : IDisposable {
{ private readonly Object _stateLock = new Object();
private readonly object _stateLock = new object(); private TcpListener? _listenerSocket;
private TcpListener _listenerSocket; private Boolean _cancellationPending;
private bool _cancellationPending; [System.Diagnostics.CodeAnalysis.SuppressMessage("Codequalität", "IDE0069:Verwerfbare Felder verwerfen", Justification = "<Ausstehend>")]
private CancellationTokenSource _cancelListening; private CancellationTokenSource? _cancelListening;
private Task? _backgroundWorkerTask; private Task? _backgroundWorkerTask;
private bool _hasDisposed; private Boolean _hasDisposed;
#region Events #region Events
@ -51,10 +51,9 @@
/// Initializes a new instance of the <see cref="ConnectionListener"/> class. /// Initializes a new instance of the <see cref="ConnectionListener"/> class.
/// </summary> /// </summary>
/// <param name="listenEndPoint">The listen end point.</param> /// <param name="listenEndPoint">The listen end point.</param>
public ConnectionListener(IPEndPoint listenEndPoint) public ConnectionListener(IPEndPoint listenEndPoint) {
{ this.Id = Guid.NewGuid();
Id = Guid.NewGuid(); this.LocalEndPoint = listenEndPoint ?? throw new ArgumentNullException(nameof(listenEndPoint));
LocalEndPoint = listenEndPoint ?? throw new ArgumentNullException(nameof(listenEndPoint));
} }
/// <summary> /// <summary>
@ -62,9 +61,7 @@
/// It uses the loopback address for listening. /// It uses the loopback address for listening.
/// </summary> /// </summary>
/// <param name="listenPort">The listen port.</param> /// <param name="listenPort">The listen port.</param>
public ConnectionListener(int listenPort) public ConnectionListener(Int32 listenPort) : this(new IPEndPoint(IPAddress.Loopback, listenPort)) {
: this(new IPEndPoint(IPAddress.Loopback, listenPort))
{
} }
/// <summary> /// <summary>
@ -72,17 +69,14 @@
/// </summary> /// </summary>
/// <param name="listenAddress">The listen address.</param> /// <param name="listenAddress">The listen address.</param>
/// <param name="listenPort">The listen port.</param> /// <param name="listenPort">The listen port.</param>
public ConnectionListener(IPAddress listenAddress, int listenPort) public ConnectionListener(IPAddress listenAddress, Int32 listenPort) : this(new IPEndPoint(listenAddress, listenPort)) {
: this(new IPEndPoint(listenAddress, listenPort))
{
} }
/// <summary> /// <summary>
/// Finalizes an instance of the <see cref="ConnectionListener"/> class. /// Finalizes an instance of the <see cref="ConnectionListener"/> class.
/// </summary> /// </summary>
~ConnectionListener() ~ConnectionListener() {
{ this.Dispose(false);
Dispose(false);
} }
#endregion #endregion
@ -95,7 +89,9 @@
/// <value> /// <value>
/// The local end point. /// The local end point.
/// </value> /// </value>
public IPEndPoint LocalEndPoint { get; } public IPEndPoint LocalEndPoint {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this listener is active. /// Gets a value indicating whether this listener is active.
@ -103,7 +99,7 @@
/// <value> /// <value>
/// <c>true</c> if this instance is listening; otherwise, <c>false</c>. /// <c>true</c> if this instance is listening; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsListening => _backgroundWorkerTask != null; public Boolean IsListening => this._backgroundWorkerTask != null;
/// <summary> /// <summary>
/// Gets a unique identifier that gets automatically assigned upon instantiation of this class. /// Gets a unique identifier that gets automatically assigned upon instantiation of this class.
@ -111,7 +107,9 @@
/// <value> /// <value>
/// The unique identifier. /// The unique identifier.
/// </value> /// </value>
public Guid Id { get; } public Guid Id {
get;
}
#endregion #endregion
@ -122,22 +120,17 @@
/// Subscribe to the events of this class to gain access to connected client sockets. /// Subscribe to the events of this class to gain access to connected client sockets.
/// </summary> /// </summary>
/// <exception cref="System.InvalidOperationException">Cancellation has already been requested. This listener is not reusable.</exception> /// <exception cref="System.InvalidOperationException">Cancellation has already been requested. This listener is not reusable.</exception>
public void Start() public void Start() {
{ lock(this._stateLock) {
lock (_stateLock) if(this._backgroundWorkerTask != null) {
{
if (_backgroundWorkerTask != null)
{
return; return;
} }
if (_cancellationPending) if(this._cancellationPending) {
{ throw new InvalidOperationException("Cancellation has already been requested. This listener is not reusable.");
throw new InvalidOperationException(
"Cancellation has already been requested. This listener is not reusable.");
} }
_backgroundWorkerTask = DoWorkAsync(); this._backgroundWorkerTask = this.DoWorkAsync();
} }
} }
@ -145,16 +138,14 @@
/// Stops the listener from receiving new connections. /// Stops the listener from receiving new connections.
/// This does not prevent the listener from . /// This does not prevent the listener from .
/// </summary> /// </summary>
public void Stop() public void Stop() {
{ lock(this._stateLock) {
lock (_stateLock) this._cancellationPending = true;
{ this._listenerSocket?.Stop();
_cancellationPending = true; this._cancelListening?.Cancel();
_listenerSocket?.Stop(); this._backgroundWorkerTask?.Wait();
_cancelListening?.Cancel(); this._backgroundWorkerTask = null;
_backgroundWorkerTask?.Wait(); this._cancellationPending = false;
_backgroundWorkerTask = null;
_cancellationPending = false;
} }
} }
@ -164,12 +155,11 @@
/// <returns> /// <returns>
/// A <see cref="System.String" /> that represents this instance. /// A <see cref="System.String" /> that represents this instance.
/// </returns> /// </returns>
public override string ToString() => LocalEndPoint.ToString(); public override String ToString() => this.LocalEndPoint.ToString();
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ this.Dispose(true);
Dispose(true);
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
@ -177,74 +167,57 @@
/// Releases unmanaged and - optionally - managed resources. /// Releases unmanaged and - optionally - managed resources.
/// </summary> /// </summary>
/// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param> /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
private void Dispose(bool disposing) private void Dispose(Boolean disposing) {
{ if(this._hasDisposed) {
if (_hasDisposed)
return; return;
if (disposing)
{
// Release managed resources
Stop();
} }
_hasDisposed = true; if(disposing) {
// Release managed resources
this.Stop();
}
this._hasDisposed = true;
} }
/// <summary> /// <summary>
/// Continuously checks for client connections until the Close method has been called. /// Continuously checks for client connections until the Close method has been called.
/// </summary> /// </summary>
/// <returns>A task that represents the asynchronous connection operation.</returns> /// <returns>A task that represents the asynchronous connection operation.</returns>
private async Task DoWorkAsync() private async Task DoWorkAsync() {
{ this._cancellationPending = false;
_cancellationPending = false; this._listenerSocket = new TcpListener(this.LocalEndPoint);
_listenerSocket = new TcpListener(LocalEndPoint); this._listenerSocket.Start();
_listenerSocket.Start(); this._cancelListening = new CancellationTokenSource();
_cancelListening = new CancellationTokenSource();
try try {
{ while(this._cancellationPending == false) {
while (_cancellationPending == false) try {
{ TcpClient client = await Task.Run(() => this._listenerSocket.AcceptTcpClientAsync(), this._cancelListening.Token).ConfigureAwait(false);
try ConnectionAcceptingEventArgs acceptingArgs = new ConnectionAcceptingEventArgs(client);
{
var client = await Task.Run(() => _listenerSocket.AcceptTcpClientAsync(), _cancelListening.Token).ConfigureAwait(false);
var acceptingArgs = new ConnectionAcceptingEventArgs(client);
OnConnectionAccepting(this, acceptingArgs); OnConnectionAccepting(this, acceptingArgs);
if (acceptingArgs.Cancel) if(acceptingArgs.Cancel) {
{
#if !NET461
client.Dispose(); client.Dispose();
#else
client.Close();
#endif
continue; continue;
} }
OnConnectionAccepted(this, new ConnectionAcceptedEventArgs(client)); OnConnectionAccepted(this, new ConnectionAcceptedEventArgs(client));
} } catch(Exception ex) {
catch (Exception ex)
{
OnConnectionFailure(this, new ConnectionFailureEventArgs(ex)); OnConnectionFailure(this, new ConnectionFailureEventArgs(ex));
} }
} }
OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(LocalEndPoint)); OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(this.LocalEndPoint));
} } catch(ObjectDisposedException) {
catch (ObjectDisposedException) OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(this.LocalEndPoint));
{ } catch(Exception ex) {
OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(LocalEndPoint));
}
catch (Exception ex)
{
OnListenerStopped(this, OnListenerStopped(this,
new ConnectionListenerStoppedEventArgs(LocalEndPoint, _cancellationPending ? null : ex)); new ConnectionListenerStoppedEventArgs(this.LocalEndPoint, this._cancellationPending ? null : ex));
} } finally {
finally this._backgroundWorkerTask = null;
{ this._cancellationPending = false;
_backgroundWorkerTask = null;
_cancellationPending = false;
} }
} }

View File

@ -1,61 +1,95 @@
namespace Swan.Net.Dns using System;
{
using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Collections.Generic; using System.Collections.Generic;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// DnsClient public interfaces. /// DnsClient public interfaces.
/// </summary> /// </summary>
internal partial class DnsClient internal partial class DnsClient {
{ public interface IDnsMessage {
public interface IDnsMessage IList<DnsQuestion> Questions {
{ get;
IList<DnsQuestion> Questions { get; }
int Size { get; }
byte[] ToArray();
} }
public interface IDnsMessageEntry Int32 Size {
{ get;
DnsDomain Name { get; } }
DnsRecordType Type { get; } Byte[] ToArray();
DnsRecordClass Class { get; }
int Size { get; }
byte[] ToArray();
} }
public interface IDnsResourceRecord : IDnsMessageEntry public interface IDnsMessageEntry {
{ DnsDomain Name {
TimeSpan TimeToLive { get; } get;
int DataLength { get; } }
byte[] Data { get; } DnsRecordType Type {
get;
}
DnsRecordClass Class {
get;
} }
public interface IDnsRequest : IDnsMessage Int32 Size {
{ get;
int Id { get; set; } }
DnsOperationCode OperationCode { get; set; } Byte[] ToArray();
bool RecursionDesired { get; set; }
} }
public interface IDnsResponse : IDnsMessage public interface IDnsResourceRecord : IDnsMessageEntry {
{ TimeSpan TimeToLive {
int Id { get; set; } get;
IList<IDnsResourceRecord> AnswerRecords { get; } }
IList<IDnsResourceRecord> AuthorityRecords { get; } Int32 DataLength {
IList<IDnsResourceRecord> AdditionalRecords { get; } get;
bool IsRecursionAvailable { get; set; } }
bool IsAuthorativeServer { get; set; } Byte[] Data {
bool IsTruncated { get; set; } get;
DnsOperationCode OperationCode { get; set; } }
DnsResponseCode ResponseCode { get; set; }
} }
public interface IDnsRequestResolver public interface IDnsRequest : IDnsMessage {
{ Int32 Id {
get; set;
}
DnsOperationCode OperationCode {
get; set;
}
Boolean RecursionDesired {
get; set;
}
}
public interface IDnsResponse : IDnsMessage {
Int32 Id {
get; set;
}
IList<IDnsResourceRecord> AnswerRecords {
get;
}
IList<IDnsResourceRecord> AuthorityRecords {
get;
}
IList<IDnsResourceRecord> AdditionalRecords {
get;
}
Boolean IsRecursionAvailable {
get; set;
}
Boolean IsAuthorativeServer {
get; set;
}
Boolean IsTruncated {
get; set;
}
DnsOperationCode OperationCode {
get; set;
}
DnsResponseCode ResponseCode {
get; set;
}
}
public interface IDnsRequestResolver {
Task<DnsClientResponse> Request(DnsClientRequest request); Task<DnsClientResponse> Request(DnsClientRequest request);
} }
} }

View File

@ -1,6 +1,5 @@
namespace Swan.Net.Dns #nullable enable
{ using Swan.Formatters;
using Formatters;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
@ -11,50 +10,47 @@
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// DnsClient Request inner class. /// DnsClient Request inner class.
/// </summary> /// </summary>
internal partial class DnsClient internal partial class DnsClient {
{ public class DnsClientRequest : IDnsRequest {
public class DnsClientRequest : IDnsRequest
{
private readonly IDnsRequestResolver _resolver; private readonly IDnsRequestResolver _resolver;
private readonly IDnsRequest _request; private readonly IDnsRequest _request;
public DnsClientRequest(IPEndPoint dns, IDnsRequest? request = null, IDnsRequestResolver? resolver = null) public DnsClientRequest(IPEndPoint dns, IDnsRequest? request = null, IDnsRequestResolver? resolver = null) {
{ this.Dns = dns;
Dns = dns; this._request = request == null ? new DnsRequest() : new DnsRequest(request);
_request = request == null ? new DnsRequest() : new DnsRequest(request); this._resolver = resolver ?? new DnsUdpRequestResolver();
_resolver = resolver ?? new DnsUdpRequestResolver();
} }
public int Id public Int32 Id {
{ get => this._request.Id;
get => _request.Id; set => this._request.Id = value;
set => _request.Id = value;
} }
public DnsOperationCode OperationCode public DnsOperationCode OperationCode {
{ get => this._request.OperationCode;
get => _request.OperationCode; set => this._request.OperationCode = value;
set => _request.OperationCode = value;
} }
public bool RecursionDesired public Boolean RecursionDesired {
{ get => this._request.RecursionDesired;
get => _request.RecursionDesired; set => this._request.RecursionDesired = value;
set => _request.RecursionDesired = value;
} }
public IList<DnsQuestion> Questions => _request.Questions; public IList<DnsQuestion> Questions => this._request.Questions;
public int Size => _request.Size; public Int32 Size => this._request.Size;
public IPEndPoint Dns { get; set; } public IPEndPoint Dns {
get; set;
}
public byte[] ToArray() => _request.ToArray(); public Byte[] ToArray() => this._request.ToArray();
public override string ToString() => _request.ToString(); public override String ToString() => this._request.ToString()!;
/// <summary> /// <summary>
/// Resolves this request into a response using the provided DNS information. The given /// Resolves this request into a response using the provided DNS information. The given
@ -64,465 +60,374 @@
/// <exception cref="IOException">Thrown if a IO error occurs.</exception> /// <exception cref="IOException">Thrown if a IO error occurs.</exception>
/// <exception cref="SocketException">Thrown if a the reading or writing to the socket fails.</exception> /// <exception cref="SocketException">Thrown if a the reading or writing to the socket fails.</exception>
/// <returns>The response received from server.</returns> /// <returns>The response received from server.</returns>
public async Task<DnsClientResponse> Resolve() public async Task<DnsClientResponse> Resolve() {
{ try {
try DnsClientResponse response = await this._resolver.Request(this).ConfigureAwait(false);
{
var response = await _resolver.Request(this).ConfigureAwait(false);
if (response.Id != Id) if(response.Id != this.Id) {
{
throw new DnsQueryException(response, "Mismatching request/response IDs"); throw new DnsQueryException(response, "Mismatching request/response IDs");
} }
if (response.ResponseCode != DnsResponseCode.NoError) if(response.ResponseCode != DnsResponseCode.NoError) {
{
throw new DnsQueryException(response); throw new DnsQueryException(response);
} }
return response; return response;
} } catch(Exception e) {
catch (Exception e) if(e is ArgumentException || e is SocketException) {
{
if (e is ArgumentException || e is SocketException)
throw new DnsQueryException("Invalid response", e); throw new DnsQueryException("Invalid response", e);
}
throw; throw;
} }
} }
} }
public class DnsRequest : IDnsRequest public class DnsRequest : IDnsRequest {
{
private static readonly Random Random = new Random(); private static readonly Random Random = new Random();
private DnsHeader header; private DnsHeader header;
public DnsRequest() public DnsRequest() {
{ this.Questions = new List<DnsQuestion>();
Questions = new List<DnsQuestion>(); this.header = new DnsHeader {
header = new DnsHeader
{
OperationCode = DnsOperationCode.Query, OperationCode = DnsOperationCode.Query,
Response = false, Response = false,
Id = Random.Next(ushort.MaxValue), Id = Random.Next(UInt16.MaxValue),
}; };
} }
public DnsRequest(IDnsRequest request) public DnsRequest(IDnsRequest request) {
{ this.header = new DnsHeader();
header = new DnsHeader(); this.Questions = new List<DnsQuestion>(request.Questions);
Questions = new List<DnsQuestion>(request.Questions);
header.Response = false; this.header.Response = false;
Id = request.Id; this.Id = request.Id;
OperationCode = request.OperationCode; this.OperationCode = request.OperationCode;
RecursionDesired = request.RecursionDesired; this.RecursionDesired = request.RecursionDesired;
} }
public IList<DnsQuestion> Questions { get; } public IList<DnsQuestion> Questions {
get;
public int Size => header.Size + Questions.Sum(q => q.Size);
public int Id
{
get => header.Id;
set => header.Id = value;
} }
public DnsOperationCode OperationCode public Int32 Size => this.header.Size + this.Questions.Sum(q => q.Size);
{
get => header.OperationCode; public Int32 Id {
set => header.OperationCode = value; get => this.header.Id;
set => this.header.Id = value;
} }
public bool RecursionDesired public DnsOperationCode OperationCode {
{ get => this.header.OperationCode;
get => header.RecursionDesired; set => this.header.OperationCode = value;
set => header.RecursionDesired = value;
} }
public byte[] ToArray() public Boolean RecursionDesired {
{ get => this.header.RecursionDesired;
UpdateHeader(); set => this.header.RecursionDesired = value;
using var result = new MemoryStream(Size);
return result
.Append(header.ToArray())
.Append(Questions.Select(q => q.ToArray()))
.ToArray();
} }
public override string ToString() public Byte[] ToArray() {
{ this.UpdateHeader();
UpdateHeader(); using MemoryStream result = new MemoryStream(this.Size);
return result.Append(this.header.ToArray()).Append(this.Questions.Select(q => q.ToArray())).ToArray();
}
public override String ToString() {
this.UpdateHeader();
return Json.Serialize(this, true); return Json.Serialize(this, true);
} }
private void UpdateHeader() private void UpdateHeader() => this.header.QuestionCount = this.Questions.Count;
{
header.QuestionCount = Questions.Count;
}
} }
public class DnsTcpRequestResolver : IDnsRequestResolver public class DnsTcpRequestResolver : IDnsRequestResolver {
{ public async Task<DnsClientResponse> Request(DnsClientRequest request) {
public async Task<DnsClientResponse> Request(DnsClientRequest request) TcpClient tcp = new TcpClient();
{
var tcp = new TcpClient();
try try {
{
#if !NET461
await tcp.Client.ConnectAsync(request.Dns).ConfigureAwait(false); await tcp.Client.ConnectAsync(request.Dns).ConfigureAwait(false);
#else
tcp.Client.Connect(request.Dns);
#endif
var stream = tcp.GetStream();
var buffer = request.ToArray();
var length = BitConverter.GetBytes((ushort)buffer.Length);
if (BitConverter.IsLittleEndian) NetworkStream stream = tcp.GetStream();
Byte[] buffer = request.ToArray();
Byte[] length = BitConverter.GetBytes((UInt16)buffer.Length);
if(BitConverter.IsLittleEndian) {
Array.Reverse(length); Array.Reverse(length);
}
await stream.WriteAsync(length, 0, length.Length).ConfigureAwait(false); await stream.WriteAsync(length, 0, length.Length).ConfigureAwait(false);
await stream.WriteAsync(buffer, 0, buffer.Length).ConfigureAwait(false); await stream.WriteAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
buffer = new byte[2]; buffer = new Byte[2];
await Read(stream, buffer).ConfigureAwait(false); await Read(stream, buffer).ConfigureAwait(false);
if (BitConverter.IsLittleEndian) if(BitConverter.IsLittleEndian) {
Array.Reverse(buffer); Array.Reverse(buffer);
}
buffer = new byte[BitConverter.ToUInt16(buffer, 0)]; buffer = new Byte[BitConverter.ToUInt16(buffer, 0)];
await Read(stream, buffer).ConfigureAwait(false); await Read(stream, buffer).ConfigureAwait(false);
var response = DnsResponse.FromArray(buffer); DnsResponse response = DnsResponse.FromArray(buffer);
return new DnsClientResponse(request, response, buffer); return new DnsClientResponse(request, response, buffer);
} } finally {
finally
{
#if NET461
tcp.Close();
#else
tcp.Dispose(); tcp.Dispose();
#endif
} }
} }
private static async Task Read(Stream stream, byte[] buffer) private static async Task Read(Stream stream, Byte[] buffer) {
{ Int32 length = buffer.Length;
var length = buffer.Length; Int32 offset = 0;
var offset = 0; Int32 size;
int size;
while (length > 0 && (size = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false)) > 0) while(length > 0 && (size = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false)) > 0) {
{
offset += size; offset += size;
length -= size; length -= size;
} }
if (length > 0) if(length > 0) {
{
throw new IOException("Unexpected end of stream"); throw new IOException("Unexpected end of stream");
} }
} }
} }
public class DnsUdpRequestResolver : IDnsRequestResolver public class DnsUdpRequestResolver : IDnsRequestResolver {
{
private readonly IDnsRequestResolver _fallback; private readonly IDnsRequestResolver _fallback;
public DnsUdpRequestResolver(IDnsRequestResolver fallback) public DnsUdpRequestResolver(IDnsRequestResolver fallback) => this._fallback = fallback;
{
_fallback = fallback;
}
public DnsUdpRequestResolver() public DnsUdpRequestResolver() => this._fallback = new DnsNullRequestResolver();
{
_fallback = new DnsNullRequestResolver();
}
public async Task<DnsClientResponse> Request(DnsClientRequest request) public async Task<DnsClientResponse> Request(DnsClientRequest request) {
{ UdpClient udp = new UdpClient();
var udp = new UdpClient(); IPEndPoint dns = request.Dns;
var dns = request.Dns;
try try {
{
udp.Client.SendTimeout = 7000; udp.Client.SendTimeout = 7000;
udp.Client.ReceiveTimeout = 7000; udp.Client.ReceiveTimeout = 7000;
#if !NET461
await udp.Client.ConnectAsync(dns).ConfigureAwait(false); await udp.Client.ConnectAsync(dns).ConfigureAwait(false);
#else
udp.Client.Connect(dns);
#endif
await udp.SendAsync(request.ToArray(), request.Size).ConfigureAwait(false);
var bufferList = new List<byte>(); _ = await udp.SendAsync(request.ToArray(), request.Size).ConfigureAwait(false);
do List<Byte> bufferList = new List<Byte>();
{
var tempBuffer = new byte[1024]; do {
var receiveCount = udp.Client.Receive(tempBuffer); Byte[] tempBuffer = new Byte[1024];
Int32 receiveCount = udp.Client.Receive(tempBuffer);
bufferList.AddRange(tempBuffer.Skip(0).Take(receiveCount)); bufferList.AddRange(tempBuffer.Skip(0).Take(receiveCount));
} }
while(udp.Client.Available > 0 || bufferList.Count == 0); while(udp.Client.Available > 0 || bufferList.Count == 0);
var buffer = bufferList.ToArray(); Byte[] buffer = bufferList.ToArray();
var response = DnsResponse.FromArray(buffer); DnsResponse response = DnsResponse.FromArray(buffer);
return response.IsTruncated return response.IsTruncated
? await _fallback.Request(request).ConfigureAwait(false) ? await this._fallback.Request(request).ConfigureAwait(false)
: new DnsClientResponse(request, response, buffer); : new DnsClientResponse(request, response, buffer);
} } finally {
finally
{
#if NET461
udp.Close();
#else
udp.Dispose(); udp.Dispose();
#endif
} }
} }
} }
public class DnsNullRequestResolver : IDnsRequestResolver public class DnsNullRequestResolver : IDnsRequestResolver {
{
public Task<DnsClientResponse> Request(DnsClientRequest request) => throw new DnsQueryException("Request failed"); public Task<DnsClientResponse> Request(DnsClientRequest request) => throw new DnsQueryException("Request failed");
} }
// 12 bytes message header // 12 bytes message header
[StructEndianness(Endianness.Big)] [StructEndianness(Endianness.Big)]
[StructLayout(LayoutKind.Sequential, Pack = 1)] [StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct DnsHeader public struct DnsHeader {
{ public const Int32 SIZE = 12;
public const int SIZE = 12;
private ushort id; private UInt16 id;
private byte flag0;
private byte flag1;
// Question count: number of questions in the Question section // Question count: number of questions in the Question section
private ushort questionCount; private UInt16 questionCount;
// Answer record count: number of records in the Answer section // Answer record count: number of records in the Answer section
private ushort answerCount; private UInt16 answerCount;
// Authority record count: number of records in the Authority section // Authority record count: number of records in the Authority section
private ushort authorityCount; private UInt16 authorityCount;
// Additional record count: number of records in the Additional section // Additional record count: number of records in the Additional section
private ushort addtionalCount; private UInt16 addtionalCount;
public int Id public Int32 Id {
{ get => this.id;
get => id; set => this.id = (UInt16)value;
set => id = (ushort)value;
} }
public int QuestionCount public Int32 QuestionCount {
{ get => this.questionCount;
get => questionCount; set => this.questionCount = (UInt16)value;
set => questionCount = (ushort)value;
} }
public int AnswerRecordCount public Int32 AnswerRecordCount {
{ get => this.answerCount;
get => answerCount; set => this.answerCount = (UInt16)value;
set => answerCount = (ushort)value;
} }
public int AuthorityRecordCount public Int32 AuthorityRecordCount {
{ get => this.authorityCount;
get => authorityCount; set => this.authorityCount = (UInt16)value;
set => authorityCount = (ushort)value;
} }
public int AdditionalRecordCount public Int32 AdditionalRecordCount {
{ get => this.addtionalCount;
get => addtionalCount; set => this.addtionalCount = (UInt16)value;
set => addtionalCount = (ushort)value;
} }
public bool Response public Boolean Response {
{ get => this.Qr == 1;
get => Qr == 1; set => this.Qr = Convert.ToByte(value);
set => Qr = Convert.ToByte(value);
} }
public DnsOperationCode OperationCode public DnsOperationCode OperationCode {
{ get => (DnsOperationCode)this.Opcode;
get => (DnsOperationCode)Opcode; set => this.Opcode = (Byte)value;
set => Opcode = (byte)value;
} }
public bool AuthorativeServer public Boolean AuthorativeServer {
{ get => this.Aa == 1;
get => Aa == 1; set => this.Aa = Convert.ToByte(value);
set => Aa = Convert.ToByte(value);
} }
public bool Truncated public Boolean Truncated {
{ get => this.Tc == 1;
get => Tc == 1; set => this.Tc = Convert.ToByte(value);
set => Tc = Convert.ToByte(value);
} }
public bool RecursionDesired public Boolean RecursionDesired {
{ get => this.Rd == 1;
get => Rd == 1; set => this.Rd = Convert.ToByte(value);
set => Rd = Convert.ToByte(value);
} }
public bool RecursionAvailable public Boolean RecursionAvailable {
{ get => this.Ra == 1;
get => Ra == 1; set => this.Ra = Convert.ToByte(value);
set => Ra = Convert.ToByte(value);
} }
public DnsResponseCode ResponseCode public DnsResponseCode ResponseCode {
{ get => (DnsResponseCode)this.RCode;
get => (DnsResponseCode)RCode; set => this.RCode = (Byte)value;
set => RCode = (byte)value;
} }
public int Size => SIZE; public Int32 Size => SIZE;
// Query/Response Flag // Query/Response Flag
private byte Qr private Byte Qr {
{ get => this.Flag0.GetBitValueAt(7);
get => Flag0.GetBitValueAt(7); set => this.Flag0 = this.Flag0.SetBitValueAt(7, 1, value);
set => Flag0 = Flag0.SetBitValueAt(7, 1, value);
} }
// Operation Code // Operation Code
private byte Opcode private Byte Opcode {
{ get => this.Flag0.GetBitValueAt(3, 4);
get => Flag0.GetBitValueAt(3, 4); set => this.Flag0 = this.Flag0.SetBitValueAt(3, 4, value);
set => Flag0 = Flag0.SetBitValueAt(3, 4, value);
} }
// Authorative Answer Flag // Authorative Answer Flag
private byte Aa private Byte Aa {
{ get => this.Flag0.GetBitValueAt(2);
get => Flag0.GetBitValueAt(2); set => this.Flag0 = this.Flag0.SetBitValueAt(2, 1, value);
set => Flag0 = Flag0.SetBitValueAt(2, 1, value);
} }
// Truncation Flag // Truncation Flag
private byte Tc private Byte Tc {
{ get => this.Flag0.GetBitValueAt(1);
get => Flag0.GetBitValueAt(1); set => this.Flag0 = this.Flag0.SetBitValueAt(1, 1, value);
set => Flag0 = Flag0.SetBitValueAt(1, 1, value);
} }
// Recursion Desired // Recursion Desired
private byte Rd private Byte Rd {
{ get => this.Flag0.GetBitValueAt(0);
get => Flag0.GetBitValueAt(0); set => this.Flag0 = this.Flag0.SetBitValueAt(0, 1, value);
set => Flag0 = Flag0.SetBitValueAt(0, 1, value);
} }
// Recursion Available // Recursion Available
private byte Ra private Byte Ra {
{ get => this.Flag1.GetBitValueAt(7);
get => Flag1.GetBitValueAt(7); set => this.Flag1 = this.Flag1.SetBitValueAt(7, 1, value);
set => Flag1 = Flag1.SetBitValueAt(7, 1, value);
} }
// Zero (Reserved) // Zero (Reserved)
private byte Z private Byte Z {
{ get => this.Flag1.GetBitValueAt(4, 3);
get => Flag1.GetBitValueAt(4, 3); set {
set { } }
} }
// Response Code // Response Code
private byte RCode private Byte RCode {
{ get => this.Flag1.GetBitValueAt(0, 4);
get => Flag1.GetBitValueAt(0, 4); set => this.Flag1 = this.Flag1.SetBitValueAt(0, 4, value);
set => Flag1 = Flag1.SetBitValueAt(0, 4, value);
} }
private byte Flag0 private Byte Flag0 {
{ get;
get => flag0; set;
set => flag0 = value;
} }
private byte Flag1 private Byte Flag1 {
{ get;
get => flag1; set;
set => flag1 = value;
} }
public static DnsHeader FromArray(byte[] header) => public static DnsHeader FromArray(Byte[] header) => header.Length < SIZE ? throw new ArgumentException("Header length too small") : header.ToStruct<DnsHeader>(0, SIZE);
header.Length < SIZE
? throw new ArgumentException("Header length too small")
: header.ToStruct<DnsHeader>(0, SIZE);
public byte[] ToArray() => this.ToBytes(); public Byte[] ToArray() => this.ToBytes();
public override string ToString() public override String ToString() => Json.SerializeExcluding(this, true, nameof(this.Size));
=> Json.SerializeExcluding(this, true, nameof(Size));
} }
public class DnsDomain : IComparable<DnsDomain> public class DnsDomain : IComparable<DnsDomain> {
{ private readonly String[] _labels;
private readonly string[] _labels;
public DnsDomain(string domain) public DnsDomain(String domain) : this(domain.Split('.')) {
: this(domain.Split('.'))
{
} }
public DnsDomain(string[] labels) public DnsDomain(String[] labels) => this._labels = labels;
{
_labels = labels;
}
public int Size => _labels.Sum(l => l.Length) + _labels.Length + 1; public Int32 Size => this._labels.Sum(l => l.Length) + this._labels.Length + 1;
public static DnsDomain FromArray(byte[] message, int offset) public static DnsDomain FromArray(Byte[] message, Int32 offset) => FromArray(message, offset, out _);
=> FromArray(message, offset, out offset);
public static DnsDomain FromArray(byte[] message, int offset, out int endOffset) public static DnsDomain FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
{ List<Byte[]> labels = new List<Byte[]>();
var labels = new List<byte[]>(); Boolean endOffsetAssigned = false;
var endOffsetAssigned = false;
endOffset = 0; endOffset = 0;
byte lengthOrPointer; Byte lengthOrPointer;
while ((lengthOrPointer = message[offset++]) > 0) while((lengthOrPointer = message[offset++]) > 0) {
{
// Two heighest bits are set (pointer) // Two heighest bits are set (pointer)
if (lengthOrPointer.GetBitValueAt(6, 2) == 3) if(lengthOrPointer.GetBitValueAt(6, 2) == 3) {
{ if(!endOffsetAssigned) {
if (!endOffsetAssigned)
{
endOffsetAssigned = true; endOffsetAssigned = true;
endOffset = offset + 1; endOffset = offset + 1;
} }
ushort pointer = lengthOrPointer.GetBitValueAt(0, 6); UInt16 pointer = lengthOrPointer.GetBitValueAt(0, 6);
offset = (pointer << 8) | message[offset]; offset = (pointer << 8) | message[offset];
continue; continue;
} }
if (lengthOrPointer.GetBitValueAt(6, 2) != 0) if(lengthOrPointer.GetBitValueAt(6, 2) != 0) {
{
throw new ArgumentException("Unexpected bit pattern in label length"); throw new ArgumentException("Unexpected bit pattern in label length");
} }
var length = lengthOrPointer; Byte length = lengthOrPointer;
var label = new byte[length]; Byte[] label = new Byte[length];
Array.Copy(message, offset, label, 0, length); Array.Copy(message, offset, label, 0, length);
labels.Add(label); labels.Add(label);
@ -530,25 +435,21 @@
offset += length; offset += length;
} }
if (!endOffsetAssigned) if(!endOffsetAssigned) {
{
endOffset = offset; endOffset = offset;
} }
return new DnsDomain(labels.Select(l => l.ToText(Encoding.ASCII)).ToArray()); return new DnsDomain(labels.Select(l => l.ToText(Encoding.ASCII)).ToArray());
} }
public static DnsDomain PointerName(IPAddress ip) public static DnsDomain PointerName(IPAddress ip) => new DnsDomain(FormatReverseIP(ip));
=> new DnsDomain(FormatReverseIP(ip));
public byte[] ToArray() public Byte[] ToArray() {
{ Byte[] result = new Byte[this.Size];
var result = new byte[Size]; Int32 offset = 0;
var offset = 0;
foreach (var l in _labels.Select(label => Encoding.ASCII.GetBytes(label))) foreach(Byte[] l in this._labels.Select(label => Encoding.ASCII.GetBytes(label))) {
{ result[offset++] = (Byte)l.Length;
result[offset++] = (byte)l.Length;
l.CopyTo(result, offset); l.CopyTo(result, offset);
offset += l.Length; offset += l.Length;
@ -559,58 +460,41 @@
return result; return result;
} }
public override string ToString() public override String ToString() => String.Join(".", this._labels);
=> string.Join(".", _labels);
public int CompareTo(DnsDomain other) public Int32 CompareTo(DnsDomain other) => String.Compare(this.ToString(), other.ToString(), StringComparison.Ordinal);
=> string.Compare(ToString(), other.ToString(), StringComparison.Ordinal);
public override bool Equals(object obj) public override Boolean Equals(Object? obj) => obj is DnsDomain domain && this.CompareTo(domain) == 0;
=> obj is DnsDomain domain && CompareTo(domain) == 0;
public override int GetHashCode() => ToString().GetHashCode(); public override Int32 GetHashCode() => this.ToString().GetHashCode();
private static string FormatReverseIP(IPAddress ip) private static String FormatReverseIP(IPAddress ip) {
{ Byte[] address = ip.GetAddressBytes();
var address = ip.GetAddressBytes();
if (address.Length == 4) if(address.Length == 4) {
{ return String.Join(".", address.Reverse().Select(b => b.ToString())) + ".in-addr.arpa";
return string.Join(".", address.Reverse().Select(b => b.ToString())) + ".in-addr.arpa";
} }
var nibbles = new byte[address.Length * 2]; Byte[] nibbles = new Byte[address.Length * 2];
for (int i = 0, j = 0; i < address.Length; i++, j = 2 * i) for(Int32 i = 0, j = 0; i < address.Length; i++, j = 2 * i) {
{ Byte b = address[i];
var b = address[i];
nibbles[j] = b.GetBitValueAt(4, 4); nibbles[j] = b.GetBitValueAt(4, 4);
nibbles[j + 1] = b.GetBitValueAt(0, 4); nibbles[j + 1] = b.GetBitValueAt(0, 4);
} }
return string.Join(".", nibbles.Reverse().Select(b => b.ToString("x"))) + ".ip6.arpa"; return String.Join(".", nibbles.Reverse().Select(b => b.ToString("x"))) + ".ip6.arpa";
} }
} }
public class DnsQuestion : IDnsMessageEntry public class DnsQuestion : IDnsMessageEntry {
{ public static IList<DnsQuestion> GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount) => GetAllFromArray(message, offset, questionCount, out _);
private readonly DnsRecordType _type;
private readonly DnsRecordClass _klass;
public static IList<DnsQuestion> GetAllFromArray(byte[] message, int offset, int questionCount) => public static IList<DnsQuestion> GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount, out Int32 endOffset) {
GetAllFromArray(message, offset, questionCount, out offset);
public static IList<DnsQuestion> GetAllFromArray(
byte[] message,
int offset,
int questionCount,
out int endOffset)
{
IList<DnsQuestion> questions = new List<DnsQuestion>(questionCount); IList<DnsQuestion> questions = new List<DnsQuestion>(questionCount);
for (var i = 0; i < questionCount; i++) for(Int32 i = 0; i < questionCount; i++) {
{
questions.Add(FromArray(message, offset, out offset)); questions.Add(FromArray(message, offset, out offset));
} }
@ -618,62 +502,55 @@
return questions; return questions;
} }
public static DnsQuestion FromArray(byte[] message, int offset, out int endOffset) public static DnsQuestion FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
{ DnsDomain domain = DnsDomain.FromArray(message, offset, out offset);
var domain = DnsDomain.FromArray(message, offset, out offset); Tail tail = message.ToStruct<Tail>(offset, Tail.SIZE);
var tail = message.ToStruct<Tail>(offset, Tail.SIZE);
endOffset = offset + Tail.SIZE; endOffset = offset + Tail.SIZE;
return new DnsQuestion(domain, tail.Type, tail.Class); return new DnsQuestion(domain, tail.Type, tail.Class);
} }
public DnsQuestion( public DnsQuestion(DnsDomain domain, DnsRecordType type = DnsRecordType.A, DnsRecordClass klass = DnsRecordClass.IN) {
DnsDomain domain, this.Name = domain;
DnsRecordType type = DnsRecordType.A, this.Type = type;
DnsRecordClass klass = DnsRecordClass.IN) this.Class = klass;
{
Name = domain;
_type = type;
_klass = klass;
} }
public DnsDomain Name { get; } public DnsDomain Name {
get;
}
public DnsRecordType Type => _type; public DnsRecordType Type {
get;
}
public DnsRecordClass Class => _klass; public DnsRecordClass Class {
get;
}
public int Size => Name.Size + Tail.SIZE; public Int32 Size => this.Name.Size + Tail.SIZE;
public byte[] ToArray() => public Byte[] ToArray() => new MemoryStream(this.Size).Append(this.Name.ToArray()).Append(new Tail { Type = Type, Class = Class }.ToBytes()).ToArray();
new MemoryStream(Size)
.Append(Name.ToArray())
.Append(new Tail { Type = Type, Class = Class }.ToBytes())
.ToArray();
public override string ToString() public override String ToString() => Json.SerializeOnly(this, true, nameof(this.Name), nameof(this.Type), nameof(this.Class));
=> Json.SerializeOnly(this, true, nameof(Name), nameof(Type), nameof(Class));
[StructEndianness(Endianness.Big)] [StructEndianness(Endianness.Big)]
[StructLayout(LayoutKind.Sequential, Pack = 2)] [StructLayout(LayoutKind.Sequential, Pack = 2)]
private struct Tail private struct Tail {
{ public const Int32 SIZE = 4;
public const int SIZE = 4;
private ushort type; private UInt16 type;
private ushort klass; private UInt16 klass;
public DnsRecordType Type public DnsRecordType Type {
{ get => (DnsRecordType)this.type;
get => (DnsRecordType)type; set => this.type = (UInt16)value;
set => type = (ushort)value;
} }
public DnsRecordClass Class public DnsRecordClass Class {
{ get => (DnsRecordClass)this.klass;
get => (DnsRecordClass)klass; set => this.klass = (UInt16)value;
set => klass = (ushort)value;
} }
} }
} }

View File

@ -1,85 +1,79 @@
namespace Swan.Net.Dns using Swan.Formatters;
{
using Formatters;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// DnsClient public methods. /// DnsClient public methods.
/// </summary> /// </summary>
internal partial class DnsClient internal partial class DnsClient {
{ public abstract class DnsResourceRecordBase : IDnsResourceRecord {
public abstract class DnsResourceRecordBase : IDnsResourceRecord
{
private readonly IDnsResourceRecord _record; private readonly IDnsResourceRecord _record;
protected DnsResourceRecordBase(IDnsResourceRecord record) protected DnsResourceRecordBase(IDnsResourceRecord record) => this._record = record;
{
_record = record; public DnsDomain Name => this._record.Name;
public DnsRecordType Type => this._record.Type;
public DnsRecordClass Class => this._record.Class;
public TimeSpan TimeToLive => this._record.TimeToLive;
public Int32 DataLength => this._record.DataLength;
public Byte[] Data => this._record.Data;
public Int32 Size => this._record.Size;
protected virtual String[] IncludedProperties => new[] { nameof(this.Name), nameof(this.Type), nameof(this.Class), nameof(this.TimeToLive), nameof(this.DataLength) };
public Byte[] ToArray() => this._record.ToArray();
public override String ToString() => Json.SerializeOnly(this, true, this.IncludedProperties);
} }
public DnsDomain Name => _record.Name; public class DnsResourceRecord : IDnsResourceRecord {
public DnsResourceRecord(DnsDomain domain, Byte[] data, DnsRecordType type, DnsRecordClass klass = DnsRecordClass.IN, TimeSpan ttl = default) {
public DnsRecordType Type => _record.Type; this.Name = domain;
this.Type = type;
public DnsRecordClass Class => _record.Class; this.Class = klass;
this.TimeToLive = ttl;
public TimeSpan TimeToLive => _record.TimeToLive; this.Data = data;
public int DataLength => _record.DataLength;
public byte[] Data => _record.Data;
public int Size => _record.Size;
protected virtual string[] IncludedProperties
=> new[] {nameof(Name), nameof(Type), nameof(Class), nameof(TimeToLive), nameof(DataLength)};
public byte[] ToArray() => _record.ToArray();
public override string ToString()
=> Json.SerializeOnly(this, true, IncludedProperties);
} }
public class DnsResourceRecord : IDnsResourceRecord public DnsDomain Name {
{ get;
public DnsResourceRecord(
DnsDomain domain,
byte[] data,
DnsRecordType type,
DnsRecordClass klass = DnsRecordClass.IN,
TimeSpan ttl = default)
{
Name = domain;
Type = type;
Class = klass;
TimeToLive = ttl;
Data = data;
} }
public DnsDomain Name { get; } public DnsRecordType Type {
get;
}
public DnsRecordType Type { get; } public DnsRecordClass Class {
get;
}
public DnsRecordClass Class { get; } public TimeSpan TimeToLive {
get;
}
public TimeSpan TimeToLive { get; } public Int32 DataLength => this.Data.Length;
public int DataLength => Data.Length; public Byte[] Data {
get;
}
public byte[] Data { get; } public Int32 Size => this.Name.Size + Tail.SIZE + this.Data.Length;
public int Size => Name.Size + Tail.SIZE + Data.Length; public static DnsResourceRecord FromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
DnsDomain domain = DnsDomain.FromArray(message, offset, out offset);
Tail tail = message.ToStruct<Tail>(offset, Tail.SIZE);
public static DnsResourceRecord FromArray(byte[] message, int offset, out int endOffset) Byte[] data = new Byte[tail.DataLength];
{
var domain = DnsDomain.FromArray(message, offset, out offset);
var tail = message.ToStruct<Tail>(offset, Tail.SIZE);
var data = new byte[tail.DataLength];
offset += Tail.SIZE; offset += Tail.SIZE;
Array.Copy(message, offset, data, 0, data.Length); Array.Copy(message, offset, data, 0, data.Length);
@ -89,239 +83,184 @@
return new DnsResourceRecord(domain, data, tail.Type, tail.Class, tail.TimeToLive); return new DnsResourceRecord(domain, data, tail.Type, tail.Class, tail.TimeToLive);
} }
public byte[] ToArray() => public Byte[] ToArray() => new MemoryStream(this.Size).Append(this.Name.ToArray()).Append(new Tail() { Type = Type, Class = Class, TimeToLive = TimeToLive, DataLength = this.Data.Length, }.ToBytes()).Append(this.Data).ToArray();
new MemoryStream(Size)
.Append(Name.ToArray())
.Append(new Tail()
{
Type = Type,
Class = Class,
TimeToLive = TimeToLive,
DataLength = Data.Length,
}.ToBytes())
.Append(Data)
.ToArray();
public override string ToString() public override String ToString() => Json.SerializeOnly(this, true, nameof(this.Name), nameof(this.Type), nameof(this.Class), nameof(this.TimeToLive), nameof(this.DataLength));
{
return Json.SerializeOnly(
this,
true,
nameof(Name),
nameof(Type),
nameof(Class),
nameof(TimeToLive),
nameof(DataLength));
}
[StructEndianness(Endianness.Big)] [StructEndianness(Endianness.Big)]
[StructLayout(LayoutKind.Sequential, Pack = 2)] [StructLayout(LayoutKind.Sequential, Pack = 2)]
private struct Tail private struct Tail {
{ public const Int32 SIZE = 10;
public const int SIZE = 10;
private ushort type; private UInt16 type;
private ushort klass; private UInt16 klass;
private uint ttl; private UInt32 ttl;
private ushort dataLength; private UInt16 dataLength;
public DnsRecordType Type public DnsRecordType Type {
{ get => (DnsRecordType)this.type;
get => (DnsRecordType) type; set => this.type = (UInt16)value;
set => type = (ushort) value;
} }
public DnsRecordClass Class public DnsRecordClass Class {
{ get => (DnsRecordClass)this.klass;
get => (DnsRecordClass) klass; set => this.klass = (UInt16)value;
set => klass = (ushort) value;
} }
public TimeSpan TimeToLive public TimeSpan TimeToLive {
{ get => TimeSpan.FromSeconds(this.ttl);
get => TimeSpan.FromSeconds(ttl); set => this.ttl = (UInt32)value.TotalSeconds;
set => ttl = (uint) value.TotalSeconds;
} }
public int DataLength public Int32 DataLength {
{ get => this.dataLength;
get => dataLength; set => this.dataLength = (UInt16)value;
set => dataLength = (ushort) value;
} }
} }
} }
public class DnsPointerResourceRecord : DnsResourceRecordBase public class DnsPointerResourceRecord : DnsResourceRecordBase {
{ public DnsPointerResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) => this.PointerDomainName = DnsDomain.FromArray(message, dataOffset);
public DnsPointerResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset)
: base(record) public DnsDomain PointerDomainName {
{ get;
PointerDomainName = DnsDomain.FromArray(message, dataOffset);
} }
public DnsDomain PointerDomainName { get; } protected override String[] IncludedProperties {
get {
protected override string[] IncludedProperties List<String> temp = new List<String>(base.IncludedProperties) { nameof(this.PointerDomainName) };
{
get
{
var temp = new List<string>(base.IncludedProperties) {nameof(PointerDomainName)};
return temp.ToArray(); return temp.ToArray();
} }
} }
} }
public class DnsIPAddressResourceRecord : DnsResourceRecordBase public class DnsIPAddressResourceRecord : DnsResourceRecordBase {
{ public DnsIPAddressResourceRecord(IDnsResourceRecord record) : base(record) => this.IPAddress = new IPAddress(this.Data);
public DnsIPAddressResourceRecord(IDnsResourceRecord record)
: base(record) public IPAddress IPAddress {
{ get;
IPAddress = new IPAddress(Data);
} }
public IPAddress IPAddress { get; } protected override String[] IncludedProperties => new List<String>(base.IncludedProperties) { nameof(this.IPAddress) }.ToArray();
protected override string[] IncludedProperties
=> new List<string>(base.IncludedProperties) {nameof(IPAddress)}.ToArray();
} }
public class DnsNameServerResourceRecord : DnsResourceRecordBase public class DnsNameServerResourceRecord : DnsResourceRecordBase {
{ public DnsNameServerResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) => this.NSDomainName = DnsDomain.FromArray(message, dataOffset);
public DnsNameServerResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset)
: base(record) public DnsDomain NSDomainName {
{ get;
NSDomainName = DnsDomain.FromArray(message, dataOffset);
} }
public DnsDomain NSDomainName { get; } protected override String[] IncludedProperties => new List<String>(base.IncludedProperties) { nameof(this.NSDomainName) }.ToArray();
protected override string[] IncludedProperties
=> new List<string>(base.IncludedProperties) {nameof(NSDomainName)}.ToArray();
} }
public class DnsCanonicalNameResourceRecord : DnsResourceRecordBase public class DnsCanonicalNameResourceRecord : DnsResourceRecordBase {
{ public DnsCanonicalNameResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) => this.CanonicalDomainName = DnsDomain.FromArray(message, dataOffset);
public DnsCanonicalNameResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset)
: base(record) public DnsDomain CanonicalDomainName {
{ get;
CanonicalDomainName = DnsDomain.FromArray(message, dataOffset);
} }
public DnsDomain CanonicalDomainName { get; } protected override String[] IncludedProperties => new List<String>(base.IncludedProperties) { nameof(this.CanonicalDomainName) }.ToArray();
protected override string[] IncludedProperties
=> new List<string>(base.IncludedProperties) {nameof(CanonicalDomainName)}.ToArray();
} }
public class DnsMailExchangeResourceRecord : DnsResourceRecordBase public class DnsMailExchangeResourceRecord : DnsResourceRecordBase {
{ private const Int32 PreferenceSize = 2;
private const int PreferenceSize = 2;
public DnsMailExchangeResourceRecord( public DnsMailExchangeResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset)
IDnsResourceRecord record, : base(record) {
byte[] message, Byte[] preference = new Byte[PreferenceSize];
int dataOffset)
: base(record)
{
var preference = new byte[PreferenceSize];
Array.Copy(message, dataOffset, preference, 0, preference.Length); Array.Copy(message, dataOffset, preference, 0, preference.Length);
if (BitConverter.IsLittleEndian) if(BitConverter.IsLittleEndian) {
{
Array.Reverse(preference); Array.Reverse(preference);
} }
dataOffset += PreferenceSize; dataOffset += PreferenceSize;
Preference = BitConverter.ToUInt16(preference, 0); this.Preference = BitConverter.ToUInt16(preference, 0);
ExchangeDomainName = DnsDomain.FromArray(message, dataOffset); this.ExchangeDomainName = DnsDomain.FromArray(message, dataOffset);
} }
public int Preference { get; } public Int32 Preference {
get;
}
public DnsDomain ExchangeDomainName { get; } public DnsDomain ExchangeDomainName {
get;
}
protected override string[] IncludedProperties => new List<string>(base.IncludedProperties) protected override String[] IncludedProperties => new List<String>(base.IncludedProperties)
{ {
nameof(Preference), nameof(this.Preference),
nameof(ExchangeDomainName), nameof(this.ExchangeDomainName),
}.ToArray(); }.ToArray();
} }
public class DnsStartOfAuthorityResourceRecord : DnsResourceRecordBase public class DnsStartOfAuthorityResourceRecord : DnsResourceRecordBase {
{ public DnsStartOfAuthorityResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) {
public DnsStartOfAuthorityResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset) this.MasterDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset);
: base(record) this.ResponsibleDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset);
{
MasterDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset);
ResponsibleDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset);
var tail = message.ToStruct<Options>(dataOffset, Options.SIZE); Options tail = message.ToStruct<Options>(dataOffset, Options.SIZE);
SerialNumber = tail.SerialNumber; this.SerialNumber = tail.SerialNumber;
RefreshInterval = tail.RefreshInterval; this.RefreshInterval = tail.RefreshInterval;
RetryInterval = tail.RetryInterval; this.RetryInterval = tail.RetryInterval;
ExpireInterval = tail.ExpireInterval; this.ExpireInterval = tail.ExpireInterval;
MinimumTimeToLive = tail.MinimumTimeToLive; this.MinimumTimeToLive = tail.MinimumTimeToLive;
} }
public DnsStartOfAuthorityResourceRecord( public DnsStartOfAuthorityResourceRecord(DnsDomain domain, DnsDomain master, DnsDomain responsible, Int64 serial, TimeSpan refresh, TimeSpan retry, TimeSpan expire, TimeSpan minTtl, TimeSpan ttl = default)
DnsDomain domain, : base(Create(domain, master, responsible, serial, refresh, retry, expire, minTtl, ttl)) {
DnsDomain master, this.MasterDomainName = master;
DnsDomain responsible, this.ResponsibleDomainName = responsible;
long serial,
TimeSpan refresh,
TimeSpan retry,
TimeSpan expire,
TimeSpan minTtl,
TimeSpan ttl = default)
: base(Create(domain, master, responsible, serial, refresh, retry, expire, minTtl, ttl))
{
MasterDomainName = master;
ResponsibleDomainName = responsible;
SerialNumber = serial; this.SerialNumber = serial;
RefreshInterval = refresh; this.RefreshInterval = refresh;
RetryInterval = retry; this.RetryInterval = retry;
ExpireInterval = expire; this.ExpireInterval = expire;
MinimumTimeToLive = minTtl; this.MinimumTimeToLive = minTtl;
} }
public DnsDomain MasterDomainName { get; } public DnsDomain MasterDomainName {
get;
}
public DnsDomain ResponsibleDomainName { get; } public DnsDomain ResponsibleDomainName {
get;
}
public long SerialNumber { get; } public Int64 SerialNumber {
get;
}
public TimeSpan RefreshInterval { get; } public TimeSpan RefreshInterval {
get;
}
public TimeSpan RetryInterval { get; } public TimeSpan RetryInterval {
get;
}
public TimeSpan ExpireInterval { get; } public TimeSpan ExpireInterval {
get;
}
public TimeSpan MinimumTimeToLive { get; } public TimeSpan MinimumTimeToLive {
get;
}
protected override string[] IncludedProperties => new List<string>(base.IncludedProperties) protected override String[] IncludedProperties => new List<String>(base.IncludedProperties)
{ {
nameof(MasterDomainName), nameof(this.MasterDomainName),
nameof(ResponsibleDomainName), nameof(this.ResponsibleDomainName),
nameof(SerialNumber), nameof(this.SerialNumber),
}.ToArray(); }.ToArray();
private static IDnsResourceRecord Create( private static IDnsResourceRecord Create(DnsDomain domain, DnsDomain master, DnsDomain responsible, Int64 serial, TimeSpan refresh, TimeSpan retry, TimeSpan expire, TimeSpan minTtl, TimeSpan ttl) {
DnsDomain domain, MemoryStream data = new MemoryStream(Options.SIZE + master.Size + responsible.Size);
DnsDomain master, Options tail = new Options {
DnsDomain responsible,
long serial,
TimeSpan refresh,
TimeSpan retry,
TimeSpan expire,
TimeSpan minTtl,
TimeSpan ttl)
{
var data = new MemoryStream(Options.SIZE + master.Size + responsible.Size);
var tail = new Options
{
SerialNumber = serial, SerialNumber = serial,
RefreshInterval = refresh, RefreshInterval = refresh,
RetryInterval = retry, RetryInterval = retry,
@ -329,67 +268,54 @@
MinimumTimeToLive = minTtl, MinimumTimeToLive = minTtl,
}; };
data.Append(master.ToArray()).Append(responsible.ToArray()).Append(tail.ToBytes()); _ = data.Append(master.ToArray()).Append(responsible.ToArray()).Append(tail.ToBytes());
return new DnsResourceRecord(domain, data.ToArray(), DnsRecordType.SOA, DnsRecordClass.IN, ttl); return new DnsResourceRecord(domain, data.ToArray(), DnsRecordType.SOA, DnsRecordClass.IN, ttl);
} }
[StructEndianness(Endianness.Big)] [StructEndianness(Endianness.Big)]
[StructLayout(LayoutKind.Sequential, Pack = 4)] [StructLayout(LayoutKind.Sequential, Pack = 4)]
public struct Options public struct Options {
{ public const Int32 SIZE = 20;
public const int SIZE = 20;
private uint serialNumber; private UInt32 serialNumber;
private uint refreshInterval; private UInt32 refreshInterval;
private uint retryInterval; private UInt32 retryInterval;
private uint expireInterval; private UInt32 expireInterval;
private uint ttl; private UInt32 ttl;
public long SerialNumber public Int64 SerialNumber {
{ get => this.serialNumber;
get => serialNumber; set => this.serialNumber = (UInt32)value;
set => serialNumber = (uint) value;
} }
public TimeSpan RefreshInterval public TimeSpan RefreshInterval {
{ get => TimeSpan.FromSeconds(this.refreshInterval);
get => TimeSpan.FromSeconds(refreshInterval); set => this.refreshInterval = (UInt32)value.TotalSeconds;
set => refreshInterval = (uint) value.TotalSeconds;
} }
public TimeSpan RetryInterval public TimeSpan RetryInterval {
{ get => TimeSpan.FromSeconds(this.retryInterval);
get => TimeSpan.FromSeconds(retryInterval); set => this.retryInterval = (UInt32)value.TotalSeconds;
set => retryInterval = (uint) value.TotalSeconds;
} }
public TimeSpan ExpireInterval public TimeSpan ExpireInterval {
{ get => TimeSpan.FromSeconds(this.expireInterval);
get => TimeSpan.FromSeconds(expireInterval); set => this.expireInterval = (UInt32)value.TotalSeconds;
set => expireInterval = (uint) value.TotalSeconds;
} }
public TimeSpan MinimumTimeToLive public TimeSpan MinimumTimeToLive {
{ get => TimeSpan.FromSeconds(this.ttl);
get => TimeSpan.FromSeconds(ttl); set => this.ttl = (UInt32)value.TotalSeconds;
set => ttl = (uint) value.TotalSeconds;
} }
} }
} }
private static class DnsResourceRecordFactory private static class DnsResourceRecordFactory {
{ public static IList<IDnsResourceRecord> GetAllFromArray(Byte[] message, Int32 offset, Int32 count, out Int32 endOffset) {
public static IList<IDnsResourceRecord> GetAllFromArray( List<IDnsResourceRecord> result = new List<IDnsResourceRecord>(count);
byte[] message,
int offset,
int count,
out int endOffset)
{
var result = new List<IDnsResourceRecord>(count);
for (var i = 0; i < count; i++) for(Int32 i = 0; i < count; i++) {
{
result.Add(GetFromArray(message, offset, out offset)); result.Add(GetFromArray(message, offset, out offset));
} }
@ -397,14 +323,13 @@
return result; return result;
} }
private static IDnsResourceRecord GetFromArray(byte[] message, int offset, out int endOffset) private static IDnsResourceRecord GetFromArray(Byte[] message, Int32 offset, out Int32 endOffset) {
{ DnsResourceRecord record = DnsResourceRecord.FromArray(message, offset, out endOffset);
var record = DnsResourceRecord.FromArray(message, offset, out endOffset); Int32 dataOffset = endOffset - record.DataLength;
var dataOffset = endOffset - record.DataLength;
return record.Type switch return record.Type switch
{ {
DnsRecordType.A => (IDnsResourceRecord) new DnsIPAddressResourceRecord(record), DnsRecordType.A => (new DnsIPAddressResourceRecord(record)),
DnsRecordType.AAAA => new DnsIPAddressResourceRecord(record), DnsRecordType.AAAA => new DnsIPAddressResourceRecord(record),
DnsRecordType.NS => new DnsNameServerResourceRecord(record, message, dataOffset), DnsRecordType.NS => new DnsNameServerResourceRecord(record, message, dataOffset),
DnsRecordType.CNAME => new DnsCanonicalNameResourceRecord(record, message, dataOffset), DnsRecordType.CNAME => new DnsCanonicalNameResourceRecord(record, message, dataOffset),

View File

@ -1,214 +1,173 @@
namespace Swan.Net.Dns using Swan.Formatters;
{
using Formatters;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// DnsClient Response inner class. /// DnsClient Response inner class.
/// </summary> /// </summary>
internal partial class DnsClient internal partial class DnsClient {
{ public class DnsClientResponse : IDnsResponse {
public class DnsClientResponse : IDnsResponse
{
private readonly DnsResponse _response; private readonly DnsResponse _response;
private readonly byte[] _message; private readonly Byte[] _message;
internal DnsClientResponse(DnsClientRequest request, DnsResponse response, byte[] message) internal DnsClientResponse(DnsClientRequest request, DnsResponse response, Byte[] message) {
{ this.Request = request;
Request = request;
_message = message; this._message = message;
_response = response; this._response = response;
} }
public DnsClientRequest Request { get; } public DnsClientRequest Request {
get;
public int Id
{
get { return _response.Id; }
set { }
} }
public IList<IDnsResourceRecord> AnswerRecords => _response.AnswerRecords; public Int32 Id {
get => this._response.Id;
public IList<IDnsResourceRecord> AuthorityRecords => set {
new ReadOnlyCollection<IDnsResourceRecord>(_response.AuthorityRecords); }
public IList<IDnsResourceRecord> AdditionalRecords =>
new ReadOnlyCollection<IDnsResourceRecord>(_response.AdditionalRecords);
public bool IsRecursionAvailable
{
get { return _response.IsRecursionAvailable; }
set { }
} }
public bool IsAuthorativeServer public IList<IDnsResourceRecord> AnswerRecords => this._response.AnswerRecords;
{
get { return _response.IsAuthorativeServer; } public IList<IDnsResourceRecord> AuthorityRecords => new ReadOnlyCollection<IDnsResourceRecord>(this._response.AuthorityRecords);
set { }
public IList<IDnsResourceRecord> AdditionalRecords => new ReadOnlyCollection<IDnsResourceRecord>(this._response.AdditionalRecords);
public Boolean IsRecursionAvailable {
get => this._response.IsRecursionAvailable;
set {
}
} }
public bool IsTruncated public Boolean IsAuthorativeServer {
{ get => this._response.IsAuthorativeServer;
get { return _response.IsTruncated; } set {
set { } }
} }
public DnsOperationCode OperationCode public Boolean IsTruncated {
{ get => this._response.IsTruncated;
get { return _response.OperationCode; } set {
set { } }
} }
public DnsResponseCode ResponseCode public DnsOperationCode OperationCode {
{ get => this._response.OperationCode;
get { return _response.ResponseCode; } set {
set { } }
} }
public IList<DnsQuestion> Questions => new ReadOnlyCollection<DnsQuestion>(_response.Questions); public DnsResponseCode ResponseCode {
get => this._response.ResponseCode;
public int Size => _message.Length; set {
}
public byte[] ToArray() => _message;
public override string ToString() => _response.ToString();
} }
public class DnsResponse : IDnsResponse public IList<DnsQuestion> Questions => new ReadOnlyCollection<DnsQuestion>(this._response.Questions);
{
public Int32 Size => this._message.Length;
public Byte[] ToArray() => this._message;
public override String ToString() => this._response.ToString();
}
public class DnsResponse : IDnsResponse {
private DnsHeader _header; private DnsHeader _header;
public DnsResponse( public DnsResponse(DnsHeader header, IList<DnsQuestion> questions, IList<IDnsResourceRecord> answers, IList<IDnsResourceRecord> authority, IList<IDnsResourceRecord> additional) {
DnsHeader header, this._header = header;
IList<DnsQuestion> questions, this.Questions = questions;
IList<IDnsResourceRecord> answers, this.AnswerRecords = answers;
IList<IDnsResourceRecord> authority, this.AuthorityRecords = authority;
IList<IDnsResourceRecord> additional) this.AdditionalRecords = additional;
{
_header = header;
Questions = questions;
AnswerRecords = answers;
AuthorityRecords = authority;
AdditionalRecords = additional;
} }
public IList<DnsQuestion> Questions { get; } public IList<DnsQuestion> Questions {
get;
public IList<IDnsResourceRecord> AnswerRecords { get; }
public IList<IDnsResourceRecord> AuthorityRecords { get; }
public IList<IDnsResourceRecord> AdditionalRecords { get; }
public int Id
{
get => _header.Id;
set => _header.Id = value;
} }
public bool IsRecursionAvailable public IList<IDnsResourceRecord> AnswerRecords {
{ get;
get => _header.RecursionAvailable;
set => _header.RecursionAvailable = value;
} }
public bool IsAuthorativeServer public IList<IDnsResourceRecord> AuthorityRecords {
{ get;
get => _header.AuthorativeServer;
set => _header.AuthorativeServer = value;
} }
public bool IsTruncated public IList<IDnsResourceRecord> AdditionalRecords {
{ get;
get => _header.Truncated;
set => _header.Truncated = value;
} }
public DnsOperationCode OperationCode public Int32 Id {
{ get => this._header.Id;
get => _header.OperationCode; set => this._header.Id = value;
set => _header.OperationCode = value;
} }
public DnsResponseCode ResponseCode public Boolean IsRecursionAvailable {
{ get => this._header.RecursionAvailable;
get => _header.ResponseCode; set => this._header.RecursionAvailable = value;
set => _header.ResponseCode = value;
} }
public int Size public Boolean IsAuthorativeServer {
=> _header.Size + get => this._header.AuthorativeServer;
Questions.Sum(q => q.Size) + set => this._header.AuthorativeServer = value;
AnswerRecords.Sum(a => a.Size) + }
AuthorityRecords.Sum(a => a.Size) +
AdditionalRecords.Sum(a => a.Size);
public static DnsResponse FromArray(byte[] message) public Boolean IsTruncated {
{ get => this._header.Truncated;
var header = DnsHeader.FromArray(message); set => this._header.Truncated = value;
var offset = header.Size; }
if (!header.Response || header.QuestionCount == 0) public DnsOperationCode OperationCode {
{ get => this._header.OperationCode;
set => this._header.OperationCode = value;
}
public DnsResponseCode ResponseCode {
get => this._header.ResponseCode;
set => this._header.ResponseCode = value;
}
public Int32 Size => this._header.Size + this.Questions.Sum(q => q.Size) + this.AnswerRecords.Sum(a => a.Size) + this.AuthorityRecords.Sum(a => a.Size) + this.AdditionalRecords.Sum(a => a.Size);
public static DnsResponse FromArray(Byte[] message) {
DnsHeader header = DnsHeader.FromArray(message);
Int32 offset = header.Size;
if(!header.Response || header.QuestionCount == 0) {
throw new ArgumentException("Invalid response message"); throw new ArgumentException("Invalid response message");
} }
if (header.Truncated) return header.Truncated
{ ? new DnsResponse(header, DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount), new List<IDnsResourceRecord>(), new List<IDnsResourceRecord>(), new List<IDnsResourceRecord>())
return new DnsResponse(header, : new DnsResponse(header, DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount, out offset), DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AnswerRecordCount, out offset), DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AuthorityRecordCount, out offset), DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AdditionalRecordCount, out _));
DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount),
new List<IDnsResourceRecord>(),
new List<IDnsResourceRecord>(),
new List<IDnsResourceRecord>());
} }
return new DnsResponse(header, public Byte[] ToArray() {
DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount, out offset), this.UpdateHeader();
DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AnswerRecordCount, out offset), MemoryStream result = new MemoryStream(this.Size);
DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AuthorityRecordCount, out offset),
DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AdditionalRecordCount, out offset));
}
public byte[] ToArray() _ = result.Append(this._header.ToArray()).Append(this.Questions.Select(q => q.ToArray())).Append(this.AnswerRecords.Select(a => a.ToArray())).Append(this.AuthorityRecords.Select(a => a.ToArray())).Append(this.AdditionalRecords.Select(a => a.ToArray()));
{
UpdateHeader();
var result = new MemoryStream(Size);
result
.Append(_header.ToArray())
.Append(Questions.Select(q => q.ToArray()))
.Append(AnswerRecords.Select(a => a.ToArray()))
.Append(AuthorityRecords.Select(a => a.ToArray()))
.Append(AdditionalRecords.Select(a => a.ToArray()));
return result.ToArray(); return result.ToArray();
} }
public override string ToString() public override String ToString() {
{ this.UpdateHeader();
UpdateHeader();
return Json.SerializeOnly( return Json.SerializeOnly(this, true, nameof(this.Questions), nameof(this.AnswerRecords), nameof(this.AuthorityRecords), nameof(this.AdditionalRecords));
this,
true,
nameof(Questions),
nameof(AnswerRecords),
nameof(AuthorityRecords),
nameof(AdditionalRecords));
} }
private void UpdateHeader() private void UpdateHeader() {
{ this._header.QuestionCount = this.Questions.Count;
_header.QuestionCount = Questions.Count; this._header.AnswerRecordCount = this.AnswerRecords.Count;
_header.AnswerRecordCount = AnswerRecords.Count; this._header.AuthorityRecordCount = this.AuthorityRecords.Count;
_header.AuthorityRecordCount = AuthorityRecords.Count; this._header.AdditionalRecordCount = this.AdditionalRecords.Count;
_header.AdditionalRecordCount = AdditionalRecords.Count;
} }
} }
} }

View File

@ -1,73 +1,59 @@
namespace Swan.Net.Dns using System;
{
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
#nullable enable
using System.Net; using System.Net;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// DnsClient public methods. /// DnsClient public methods.
/// </summary> /// </summary>
internal partial class DnsClient internal partial class DnsClient {
{
private readonly IPEndPoint _dns; private readonly IPEndPoint _dns;
private readonly IDnsRequestResolver _resolver; private readonly IDnsRequestResolver _resolver;
public DnsClient(IPEndPoint dns, IDnsRequestResolver? resolver = null) public DnsClient(IPEndPoint dns, IDnsRequestResolver? resolver = null) {
{ this._dns = dns;
_dns = dns; this._resolver = resolver ?? new DnsUdpRequestResolver(new DnsTcpRequestResolver());
_resolver = resolver ?? new DnsUdpRequestResolver(new DnsTcpRequestResolver());
} }
public DnsClient(IPAddress ip, int port = Network.DnsDefaultPort, IDnsRequestResolver? resolver = null) public DnsClient(IPAddress ip, Int32 port = Network.DnsDefaultPort, IDnsRequestResolver? resolver = null) : this(new IPEndPoint(ip, port), resolver) {
: this(new IPEndPoint(ip, port), resolver)
{
} }
public DnsClientRequest Create(IDnsRequest? request = null) public DnsClientRequest Create(IDnsRequest? request = null) => new DnsClientRequest(this._dns, request, this._resolver);
=> new DnsClientRequest(_dns, request, _resolver);
public async Task<IList<IPAddress>> Lookup(string domain, DnsRecordType type = DnsRecordType.A) public async Task<IList<IPAddress>> Lookup(String domain, DnsRecordType type = DnsRecordType.A) {
{ if(String.IsNullOrWhiteSpace(domain)) {
if (string.IsNullOrWhiteSpace(domain))
throw new ArgumentNullException(nameof(domain)); throw new ArgumentNullException(nameof(domain));
}
if (type != DnsRecordType.A && type != DnsRecordType.AAAA) if(type != DnsRecordType.A && type != DnsRecordType.AAAA) {
{
throw new ArgumentException("Invalid record type " + type); throw new ArgumentException("Invalid record type " + type);
} }
var response = await Resolve(domain, type).ConfigureAwait(false); DnsClientResponse response = await this.Resolve(domain, type).ConfigureAwait(false);
var ips = response.AnswerRecords List<IPAddress> ips = response.AnswerRecords.Where(r => r.Type == type).Cast<DnsIPAddressResourceRecord>().Select(r => r.IPAddress).ToList();
.Where(r => r.Type == type)
.Cast<DnsIPAddressResourceRecord>()
.Select(r => r.IPAddress)
.ToList();
return ips.Count == 0 ? throw new DnsQueryException(response, "No matching records") : ips; return ips.Count == 0 ? throw new DnsQueryException(response, "No matching records") : ips;
} }
public async Task<string> Reverse(IPAddress ip) public async Task<String> Reverse(IPAddress ip) {
{ if(ip == null) {
if (ip == null)
throw new ArgumentNullException(nameof(ip)); throw new ArgumentNullException(nameof(ip));
var response = await Resolve(DnsDomain.PointerName(ip), DnsRecordType.PTR);
var ptr = response.AnswerRecords.FirstOrDefault(r => r.Type == DnsRecordType.PTR);
return ptr == null
? throw new DnsQueryException(response, "No matching records")
: ((DnsPointerResourceRecord) ptr).PointerDomainName.ToString();
} }
public Task<DnsClientResponse> Resolve(string domain, DnsRecordType type) => DnsClientResponse response = await this.Resolve(DnsDomain.PointerName(ip), DnsRecordType.PTR);
Resolve(new DnsDomain(domain), type); IDnsResourceRecord ptr = response.AnswerRecords.FirstOrDefault(r => r.Type == DnsRecordType.PTR);
public Task<DnsClientResponse> Resolve(DnsDomain domain, DnsRecordType type) return ptr == null ? throw new DnsQueryException(response, "No matching records") : ((DnsPointerResourceRecord)ptr).PointerDomainName.ToString();
{ }
var request = Create();
var question = new DnsQuestion(domain, type); public Task<DnsClientResponse> Resolve(String domain, DnsRecordType type) => this.Resolve(new DnsDomain(domain), type);
public Task<DnsClientResponse> Resolve(DnsDomain domain, DnsRecordType type) {
DnsClientRequest request = this.Create();
DnsQuestion question = new DnsQuestion(domain, type);
request.Questions.Add(question); request.Questions.Add(question);
request.OperationCode = DnsOperationCode.Query; request.OperationCode = DnsOperationCode.Query;

View File

@ -1,37 +1,28 @@
namespace Swan.Net.Dns #nullable enable
{
using System; using System;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// An exception thrown when the DNS query fails. /// An exception thrown when the DNS query fails.
/// </summary> /// </summary>
/// <seealso cref="Exception" /> /// <seealso cref="Exception" />
[Serializable] [Serializable]
public class DnsQueryException : Exception public class DnsQueryException : Exception {
{ internal DnsQueryException(String message) : base(message) {
internal DnsQueryException(string message)
: base(message)
{
} }
internal DnsQueryException(string message, Exception e) internal DnsQueryException(String message, Exception e) : base(message, e) {
: base(message, e)
{
} }
internal DnsQueryException(DnsClient.IDnsResponse response) internal DnsQueryException(DnsClient.IDnsResponse response) : this(response, Format(response)) {
: this(response, Format(response))
{
} }
internal DnsQueryException(DnsClient.IDnsResponse response, string message) internal DnsQueryException(DnsClient.IDnsResponse response, String message) : base(message) => this.Response = response;
: base(message)
{ internal DnsClient.IDnsResponse? Response {
Response = response; get;
} }
internal DnsClient.IDnsResponse? Response { get; } private static String Format(DnsClient.IDnsResponse response) => $"Invalid response received with code {response.ResponseCode}";
private static string Format(DnsClient.IDnsResponse response) => $"Invalid response received with code {response.ResponseCode}";
} }
} }

View File

@ -1,12 +1,10 @@
namespace Swan.Net.Dns namespace Swan.Net.Dns {
{
using System.Collections.Generic; using System.Collections.Generic;
/// <summary> /// <summary>
/// Represents a response from a DNS server. /// Represents a response from a DNS server.
/// </summary> /// </summary>
public class DnsQueryResult public class DnsQueryResult {
{
private readonly List<DnsRecord> _mAnswerRecords = new List<DnsRecord>(); private readonly List<DnsRecord> _mAnswerRecords = new List<DnsRecord>();
private readonly List<DnsRecord> _mAdditionalRecords = new List<DnsRecord>(); private readonly List<DnsRecord> _mAdditionalRecords = new List<DnsRecord>();
private readonly List<DnsRecord> _mAuthorityRecords = new List<DnsRecord>(); private readonly List<DnsRecord> _mAuthorityRecords = new List<DnsRecord>();
@ -15,37 +13,34 @@
/// Initializes a new instance of the <see cref="DnsQueryResult"/> class. /// Initializes a new instance of the <see cref="DnsQueryResult"/> class.
/// </summary> /// </summary>
/// <param name="response">The response.</param> /// <param name="response">The response.</param>
internal DnsQueryResult(DnsClient.IDnsResponse response) internal DnsQueryResult(DnsClient.IDnsResponse response) : this() {
: this() this.Id = response.Id;
{ this.IsAuthoritativeServer = response.IsAuthorativeServer;
Id = response.Id; this.IsRecursionAvailable = response.IsRecursionAvailable;
IsAuthoritativeServer = response.IsAuthorativeServer; this.IsTruncated = response.IsTruncated;
IsRecursionAvailable = response.IsRecursionAvailable; this.OperationCode = response.OperationCode;
IsTruncated = response.IsTruncated; this.ResponseCode = response.ResponseCode;
OperationCode = response.OperationCode;
ResponseCode = response.ResponseCode;
if (response.AnswerRecords != null) if(response.AnswerRecords != null) {
{ foreach(DnsClient.IDnsResourceRecord record in response.AnswerRecords) {
foreach (var record in response.AnswerRecords) this.AnswerRecords.Add(new DnsRecord(record));
AnswerRecords.Add(new DnsRecord(record));
}
if (response.AuthorityRecords != null)
{
foreach (var record in response.AuthorityRecords)
AuthorityRecords.Add(new DnsRecord(record));
}
if (response.AdditionalRecords != null)
{
foreach (var record in response.AdditionalRecords)
AdditionalRecords.Add(new DnsRecord(record));
} }
} }
private DnsQueryResult() if(response.AuthorityRecords != null) {
{ foreach(DnsClient.IDnsResourceRecord record in response.AuthorityRecords) {
this.AuthorityRecords.Add(new DnsRecord(record));
}
}
if(response.AdditionalRecords != null) {
foreach(DnsClient.IDnsResourceRecord record in response.AdditionalRecords) {
this.AdditionalRecords.Add(new DnsRecord(record));
}
}
}
private DnsQueryResult() {
} }
/// <summary> /// <summary>
@ -54,7 +49,9 @@
/// <value> /// <value>
/// The identifier. /// The identifier.
/// </value> /// </value>
public int Id { get; } public System.Int32 Id {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this instance is authoritative server. /// Gets a value indicating whether this instance is authoritative server.
@ -62,7 +59,9 @@
/// <value> /// <value>
/// <c>true</c> if this instance is authoritative server; otherwise, <c>false</c>. /// <c>true</c> if this instance is authoritative server; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsAuthoritativeServer { get; } public System.Boolean IsAuthoritativeServer {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this instance is truncated. /// Gets a value indicating whether this instance is truncated.
@ -70,7 +69,9 @@
/// <value> /// <value>
/// <c>true</c> if this instance is truncated; otherwise, <c>false</c>. /// <c>true</c> if this instance is truncated; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsTruncated { get; } public System.Boolean IsTruncated {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this instance is recursion available. /// Gets a value indicating whether this instance is recursion available.
@ -78,7 +79,9 @@
/// <value> /// <value>
/// <c>true</c> if this instance is recursion available; otherwise, <c>false</c>. /// <c>true</c> if this instance is recursion available; otherwise, <c>false</c>.
/// </value> /// </value>
public bool IsRecursionAvailable { get; } public System.Boolean IsRecursionAvailable {
get;
}
/// <summary> /// <summary>
/// Gets the operation code. /// Gets the operation code.
@ -86,7 +89,9 @@
/// <value> /// <value>
/// The operation code. /// The operation code.
/// </value> /// </value>
public DnsOperationCode OperationCode { get; } public DnsOperationCode OperationCode {
get;
}
/// <summary> /// <summary>
/// Gets the response code. /// Gets the response code.
@ -94,7 +99,9 @@
/// <value> /// <value>
/// The response code. /// The response code.
/// </value> /// </value>
public DnsResponseCode ResponseCode { get; } public DnsResponseCode ResponseCode {
get;
}
/// <summary> /// <summary>
/// Gets the answer records. /// Gets the answer records.
@ -102,7 +109,7 @@
/// <value> /// <value>
/// The answer records. /// The answer records.
/// </value> /// </value>
public IList<DnsRecord> AnswerRecords => _mAnswerRecords; public IList<DnsRecord> AnswerRecords => this._mAnswerRecords;
/// <summary> /// <summary>
/// Gets the additional records. /// Gets the additional records.
@ -110,7 +117,7 @@
/// <value> /// <value>
/// The additional records. /// The additional records.
/// </value> /// </value>
public IList<DnsRecord> AdditionalRecords => _mAdditionalRecords; public IList<DnsRecord> AdditionalRecords => this._mAdditionalRecords;
/// <summary> /// <summary>
/// Gets the authority records. /// Gets the authority records.
@ -118,6 +125,6 @@
/// <value> /// <value>
/// The authority records. /// The authority records.
/// </value> /// </value>
public IList<DnsRecord> AuthorityRecords => _mAuthorityRecords; public IList<DnsRecord> AuthorityRecords => this._mAuthorityRecords;
} }
} }

View File

@ -1,55 +1,50 @@
namespace Swan.Net.Dns using System;
{
using System;
using System.Net; using System.Net;
using System.Text; using System.Text;
namespace Swan.Net.Dns {
/// <summary> /// <summary>
/// Represents a DNS record entry. /// Represents a DNS record entry.
/// </summary> /// </summary>
public class DnsRecord public class DnsRecord {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="DnsRecord"/> class. /// Initializes a new instance of the <see cref="DnsRecord"/> class.
/// </summary> /// </summary>
/// <param name="record">The record.</param> /// <param name="record">The record.</param>
internal DnsRecord(DnsClient.IDnsResourceRecord record) internal DnsRecord(DnsClient.IDnsResourceRecord record) : this() {
: this() this.Name = record.Name.ToString();
{ this.Type = record.Type;
Name = record.Name.ToString(); this.Class = record.Class;
Type = record.Type; this.TimeToLive = record.TimeToLive;
Class = record.Class; this.Data = record.Data;
TimeToLive = record.TimeToLive;
Data = record.Data;
// PTR // PTR
PointerDomainName = (record as DnsClient.DnsPointerResourceRecord)?.PointerDomainName?.ToString(); this.PointerDomainName = (record as DnsClient.DnsPointerResourceRecord)?.PointerDomainName?.ToString();
// A // A
IPAddress = (record as DnsClient.DnsIPAddressResourceRecord)?.IPAddress; this.IPAddress = (record as DnsClient.DnsIPAddressResourceRecord)?.IPAddress;
// NS // NS
NameServerDomainName = (record as DnsClient.DnsNameServerResourceRecord)?.NSDomainName?.ToString(); this.NameServerDomainName = (record as DnsClient.DnsNameServerResourceRecord)?.NSDomainName?.ToString();
// CNAME // CNAME
CanonicalDomainName = (record as DnsClient.DnsCanonicalNameResourceRecord)?.CanonicalDomainName.ToString(); this.CanonicalDomainName = (record as DnsClient.DnsCanonicalNameResourceRecord)?.CanonicalDomainName.ToString();
// MX // MX
MailExchangerDomainName = (record as DnsClient.DnsMailExchangeResourceRecord)?.ExchangeDomainName.ToString(); this.MailExchangerDomainName = (record as DnsClient.DnsMailExchangeResourceRecord)?.ExchangeDomainName.ToString();
MailExchangerPreference = (record as DnsClient.DnsMailExchangeResourceRecord)?.Preference; this.MailExchangerPreference = (record as DnsClient.DnsMailExchangeResourceRecord)?.Preference;
// SOA // SOA
SoaMasterDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MasterDomainName.ToString(); this.SoaMasterDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MasterDomainName.ToString();
SoaResponsibleDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ResponsibleDomainName.ToString(); this.SoaResponsibleDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ResponsibleDomainName.ToString();
SoaSerialNumber = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.SerialNumber; this.SoaSerialNumber = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.SerialNumber;
SoaRefreshInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RefreshInterval; this.SoaRefreshInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RefreshInterval;
SoaRetryInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RetryInterval; this.SoaRetryInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RetryInterval;
SoaExpireInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ExpireInterval; this.SoaExpireInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ExpireInterval;
SoaMinimumTimeToLive = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MinimumTimeToLive; this.SoaMinimumTimeToLive = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MinimumTimeToLive;
} }
private DnsRecord() private DnsRecord() {
{
// placeholder // placeholder
} }
@ -59,7 +54,9 @@
/// <value> /// <value>
/// The name. /// The name.
/// </value> /// </value>
public string Name { get; } public String Name {
get;
}
/// <summary> /// <summary>
/// Gets the type. /// Gets the type.
@ -67,7 +64,9 @@
/// <value> /// <value>
/// The type. /// The type.
/// </value> /// </value>
public DnsRecordType Type { get; } public DnsRecordType Type {
get;
}
/// <summary> /// <summary>
/// Gets the class. /// Gets the class.
@ -75,7 +74,9 @@
/// <value> /// <value>
/// The class. /// The class.
/// </value> /// </value>
public DnsRecordClass Class { get; } public DnsRecordClass Class {
get;
}
/// <summary> /// <summary>
/// Gets the time to live. /// Gets the time to live.
@ -83,7 +84,9 @@
/// <value> /// <value>
/// The time to live. /// The time to live.
/// </value> /// </value>
public TimeSpan TimeToLive { get; } public TimeSpan TimeToLive {
get;
}
/// <summary> /// <summary>
/// Gets the raw data of the record. /// Gets the raw data of the record.
@ -91,7 +94,9 @@
/// <value> /// <value>
/// The data. /// The data.
/// </value> /// </value>
public byte[] Data { get; } public Byte[] Data {
get;
}
/// <summary> /// <summary>
/// Gets the data text bytes in ASCII encoding. /// Gets the data text bytes in ASCII encoding.
@ -99,7 +104,7 @@
/// <value> /// <value>
/// The data text. /// The data text.
/// </value> /// </value>
public string DataText => Data == null ? string.Empty : Encoding.ASCII.GetString(Data); public String DataText => this.Data == null ? String.Empty : Encoding.ASCII.GetString(this.Data);
/// <summary> /// <summary>
/// Gets the name of the pointer domain. /// Gets the name of the pointer domain.
@ -107,7 +112,9 @@
/// <value> /// <value>
/// The name of the pointer domain. /// The name of the pointer domain.
/// </value> /// </value>
public string PointerDomainName { get; } public String PointerDomainName {
get;
}
/// <summary> /// <summary>
/// Gets the ip address. /// Gets the ip address.
@ -115,7 +122,9 @@
/// <value> /// <value>
/// The ip address. /// The ip address.
/// </value> /// </value>
public IPAddress IPAddress { get; } public IPAddress IPAddress {
get;
}
/// <summary> /// <summary>
/// Gets the name of the name server domain. /// Gets the name of the name server domain.
@ -123,7 +132,9 @@
/// <value> /// <value>
/// The name of the name server domain. /// The name of the name server domain.
/// </value> /// </value>
public string NameServerDomainName { get; } public String NameServerDomainName {
get;
}
/// <summary> /// <summary>
/// Gets the name of the canonical domain. /// Gets the name of the canonical domain.
@ -131,7 +142,9 @@
/// <value> /// <value>
/// The name of the canonical domain. /// The name of the canonical domain.
/// </value> /// </value>
public string CanonicalDomainName { get; } public String CanonicalDomainName {
get;
}
/// <summary> /// <summary>
/// Gets the mail exchanger preference. /// Gets the mail exchanger preference.
@ -139,7 +152,9 @@
/// <value> /// <value>
/// The mail exchanger preference. /// The mail exchanger preference.
/// </value> /// </value>
public int? MailExchangerPreference { get; } public Int32? MailExchangerPreference {
get;
}
/// <summary> /// <summary>
/// Gets the name of the mail exchanger domain. /// Gets the name of the mail exchanger domain.
@ -147,7 +162,9 @@
/// <value> /// <value>
/// The name of the mail exchanger domain. /// The name of the mail exchanger domain.
/// </value> /// </value>
public string MailExchangerDomainName { get; } public String MailExchangerDomainName {
get;
}
/// <summary> /// <summary>
/// Gets the name of the soa master domain. /// Gets the name of the soa master domain.
@ -155,7 +172,9 @@
/// <value> /// <value>
/// The name of the soa master domain. /// The name of the soa master domain.
/// </value> /// </value>
public string SoaMasterDomainName { get; } public String SoaMasterDomainName {
get;
}
/// <summary> /// <summary>
/// Gets the name of the soa responsible domain. /// Gets the name of the soa responsible domain.
@ -163,7 +182,9 @@
/// <value> /// <value>
/// The name of the soa responsible domain. /// The name of the soa responsible domain.
/// </value> /// </value>
public string SoaResponsibleDomainName { get; } public String SoaResponsibleDomainName {
get;
}
/// <summary> /// <summary>
/// Gets the soa serial number. /// Gets the soa serial number.
@ -171,7 +192,9 @@
/// <value> /// <value>
/// The soa serial number. /// The soa serial number.
/// </value> /// </value>
public long? SoaSerialNumber { get; } public Int64? SoaSerialNumber {
get;
}
/// <summary> /// <summary>
/// Gets the soa refresh interval. /// Gets the soa refresh interval.
@ -179,7 +202,9 @@
/// <value> /// <value>
/// The soa refresh interval. /// The soa refresh interval.
/// </value> /// </value>
public TimeSpan? SoaRefreshInterval { get; } public TimeSpan? SoaRefreshInterval {
get;
}
/// <summary> /// <summary>
/// Gets the soa retry interval. /// Gets the soa retry interval.
@ -187,7 +212,9 @@
/// <value> /// <value>
/// The soa retry interval. /// The soa retry interval.
/// </value> /// </value>
public TimeSpan? SoaRetryInterval { get; } public TimeSpan? SoaRetryInterval {
get;
}
/// <summary> /// <summary>
/// Gets the soa expire interval. /// Gets the soa expire interval.
@ -195,7 +222,9 @@
/// <value> /// <value>
/// The soa expire interval. /// The soa expire interval.
/// </value> /// </value>
public TimeSpan? SoaExpireInterval { get; } public TimeSpan? SoaExpireInterval {
get;
}
/// <summary> /// <summary>
/// Gets the soa minimum time to live. /// Gets the soa minimum time to live.
@ -203,6 +232,8 @@
/// <value> /// <value>
/// The soa minimum time to live. /// The soa minimum time to live.
/// </value> /// </value>
public TimeSpan? SoaMinimumTimeToLive { get; } public TimeSpan? SoaMinimumTimeToLive {
get;
}
} }
} }

View File

@ -1,11 +1,9 @@
// ReSharper disable InconsistentNaming // ReSharper disable InconsistentNaming
namespace Swan.Net.Dns namespace Swan.Net.Dns {
{
/// <summary> /// <summary>
/// Enumerates the different DNS record types. /// Enumerates the different DNS record types.
/// </summary> /// </summary>
public enum DnsRecordType public enum DnsRecordType {
{
/// <summary> /// <summary>
/// A records /// A records
/// </summary> /// </summary>
@ -65,8 +63,7 @@ namespace Swan.Net.Dns
/// <summary> /// <summary>
/// Enumerates the different DNS record classes. /// Enumerates the different DNS record classes.
/// </summary> /// </summary>
public enum DnsRecordClass public enum DnsRecordClass {
{
/// <summary> /// <summary>
/// IN records /// IN records
/// </summary> /// </summary>
@ -81,8 +78,7 @@ namespace Swan.Net.Dns
/// <summary> /// <summary>
/// Enumerates the different DNS operation codes. /// Enumerates the different DNS operation codes.
/// </summary> /// </summary>
public enum DnsOperationCode public enum DnsOperationCode {
{
/// <summary> /// <summary>
/// Query operation /// Query operation
/// </summary> /// </summary>
@ -112,8 +108,7 @@ namespace Swan.Net.Dns
/// <summary> /// <summary>
/// Enumerates the different DNS query response codes. /// Enumerates the different DNS query response codes.
/// </summary> /// </summary>
public enum DnsResponseCode public enum DnsResponseCode {
{
/// <summary> /// <summary>
/// No error /// No error
/// </summary> /// </summary>

View File

@ -1,24 +1,20 @@
namespace Swan.Net #nullable enable
{
using System; using System;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
namespace Swan.Net {
/// <summary> /// <summary>
/// The event arguments for when connections are accepted. /// The event arguments for when connections are accepted.
/// </summary> /// </summary>
/// <seealso cref="System.EventArgs" /> /// <seealso cref="System.EventArgs" />
public class ConnectionAcceptedEventArgs : EventArgs public class ConnectionAcceptedEventArgs : EventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionAcceptedEventArgs" /> class. /// Initializes a new instance of the <see cref="ConnectionAcceptedEventArgs" /> class.
/// </summary> /// </summary>
/// <param name="client">The client.</param> /// <param name="client">The client.</param>
/// <exception cref="ArgumentNullException">client.</exception> /// <exception cref="ArgumentNullException">client.</exception>
public ConnectionAcceptedEventArgs(TcpClient client) public ConnectionAcceptedEventArgs(TcpClient client) => this.Client = client ?? throw new ArgumentNullException(nameof(client));
{
Client = client ?? throw new ArgumentNullException(nameof(client));
}
/// <summary> /// <summary>
/// Gets the client. /// Gets the client.
@ -26,22 +22,21 @@
/// <value> /// <value>
/// The client. /// The client.
/// </value> /// </value>
public TcpClient Client { get; } public TcpClient Client {
get;
}
} }
/// <summary> /// <summary>
/// Occurs before a connection is accepted. Set the Cancel property to true to prevent the connection from being accepted. /// Occurs before a connection is accepted. Set the Cancel property to true to prevent the connection from being accepted.
/// </summary> /// </summary>
/// <seealso cref="ConnectionAcceptedEventArgs" /> /// <seealso cref="ConnectionAcceptedEventArgs" />
public class ConnectionAcceptingEventArgs : ConnectionAcceptedEventArgs public class ConnectionAcceptingEventArgs : ConnectionAcceptedEventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionAcceptingEventArgs"/> class. /// Initializes a new instance of the <see cref="ConnectionAcceptingEventArgs"/> class.
/// </summary> /// </summary>
/// <param name="client">The client.</param> /// <param name="client">The client.</param>
public ConnectionAcceptingEventArgs(TcpClient client) public ConnectionAcceptingEventArgs(TcpClient client) : base(client) {
: base(client)
{
} }
/// <summary> /// <summary>
@ -50,24 +45,22 @@
/// <value> /// <value>
/// <c>true</c> if cancel; otherwise, <c>false</c>. /// <c>true</c> if cancel; otherwise, <c>false</c>.
/// </value> /// </value>
public bool Cancel { get; set; } public Boolean Cancel {
get; set;
}
} }
/// <summary> /// <summary>
/// Event arguments for when a server listener is started. /// Event arguments for when a server listener is started.
/// </summary> /// </summary>
/// <seealso cref="System.EventArgs" /> /// <seealso cref="System.EventArgs" />
public class ConnectionListenerStartedEventArgs : EventArgs public class ConnectionListenerStartedEventArgs : EventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionListenerStartedEventArgs" /> class. /// Initializes a new instance of the <see cref="ConnectionListenerStartedEventArgs" /> class.
/// </summary> /// </summary>
/// <param name="listenerEndPoint">The listener end point.</param> /// <param name="listenerEndPoint">The listener end point.</param>
/// <exception cref="ArgumentNullException">listenerEndPoint.</exception> /// <exception cref="ArgumentNullException">listenerEndPoint.</exception>
public ConnectionListenerStartedEventArgs(IPEndPoint listenerEndPoint) public ConnectionListenerStartedEventArgs(IPEndPoint listenerEndPoint) => this.EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint));
{
EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint));
}
/// <summary> /// <summary>
/// Gets the end point. /// Gets the end point.
@ -75,15 +68,16 @@
/// <value> /// <value>
/// The end point. /// The end point.
/// </value> /// </value>
public IPEndPoint EndPoint { get; } public IPEndPoint EndPoint {
get;
}
} }
/// <summary> /// <summary>
/// Event arguments for when a server listener fails to start. /// Event arguments for when a server listener fails to start.
/// </summary> /// </summary>
/// <seealso cref="System.EventArgs" /> /// <seealso cref="System.EventArgs" />
public class ConnectionListenerFailedEventArgs : EventArgs public class ConnectionListenerFailedEventArgs : EventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionListenerFailedEventArgs" /> class. /// Initializes a new instance of the <see cref="ConnectionListenerFailedEventArgs" /> class.
/// </summary> /// </summary>
@ -94,10 +88,9 @@
/// or /// or
/// ex. /// ex.
/// </exception> /// </exception>
public ConnectionListenerFailedEventArgs(IPEndPoint listenerEndPoint, Exception ex) public ConnectionListenerFailedEventArgs(IPEndPoint listenerEndPoint, Exception ex) {
{ this.EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint));
EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); this.Error = ex ?? throw new ArgumentNullException(nameof(ex));
Error = ex ?? throw new ArgumentNullException(nameof(ex));
} }
/// <summary> /// <summary>
@ -106,7 +99,9 @@
/// <value> /// <value>
/// The end point. /// The end point.
/// </value> /// </value>
public IPEndPoint EndPoint { get; } public IPEndPoint EndPoint {
get;
}
/// <summary> /// <summary>
/// Gets the error. /// Gets the error.
@ -114,15 +109,16 @@
/// <value> /// <value>
/// The error. /// The error.
/// </value> /// </value>
public Exception Error { get; } public Exception Error {
get;
}
} }
/// <summary> /// <summary>
/// Event arguments for when a server listener stopped. /// Event arguments for when a server listener stopped.
/// </summary> /// </summary>
/// <seealso cref="System.EventArgs" /> /// <seealso cref="System.EventArgs" />
public class ConnectionListenerStoppedEventArgs : EventArgs public class ConnectionListenerStoppedEventArgs : EventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionListenerStoppedEventArgs" /> class. /// Initializes a new instance of the <see cref="ConnectionListenerStoppedEventArgs" /> class.
/// </summary> /// </summary>
@ -133,10 +129,9 @@
/// or /// or
/// ex. /// ex.
/// </exception> /// </exception>
public ConnectionListenerStoppedEventArgs(IPEndPoint listenerEndPoint, Exception? ex = null) public ConnectionListenerStoppedEventArgs(IPEndPoint listenerEndPoint, Exception? ex = null) {
{ this.EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint));
EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); this.Error = ex;
Error = ex;
} }
/// <summary> /// <summary>
@ -145,7 +140,9 @@
/// <value> /// <value>
/// The end point. /// The end point.
/// </value> /// </value>
public IPEndPoint EndPoint { get; } public IPEndPoint EndPoint {
get;
}
/// <summary> /// <summary>
/// Gets the error. /// Gets the error.
@ -153,6 +150,8 @@
/// <value> /// <value>
/// The error. /// The error.
/// </value> /// </value>
public Exception? Error { get; } public Exception? Error {
get;
}
} }
} }

View File

@ -1,22 +1,17 @@
namespace Swan.Net using System;
{
using System;
using System.Text; using System.Text;
namespace Swan.Net {
/// <summary> /// <summary>
/// The event arguments for connection failure events. /// The event arguments for connection failure events.
/// </summary> /// </summary>
/// <seealso cref="System.EventArgs" /> /// <seealso cref="System.EventArgs" />
public class ConnectionFailureEventArgs : EventArgs public class ConnectionFailureEventArgs : EventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionFailureEventArgs"/> class. /// Initializes a new instance of the <see cref="ConnectionFailureEventArgs"/> class.
/// </summary> /// </summary>
/// <param name="ex">The ex.</param> /// <param name="ex">The ex.</param>
public ConnectionFailureEventArgs(Exception ex) public ConnectionFailureEventArgs(Exception ex) => this.Error = ex;
{
Error = ex;
}
/// <summary> /// <summary>
/// Gets the error. /// Gets the error.
@ -24,26 +19,26 @@
/// <value> /// <value>
/// The error. /// The error.
/// </value> /// </value>
public Exception Error { get; } public Exception Error {
get;
}
} }
/// <summary> /// <summary>
/// Event arguments for when data is received. /// Event arguments for when data is received.
/// </summary> /// </summary>
/// <seealso cref="System.EventArgs" /> /// <seealso cref="System.EventArgs" />
public class ConnectionDataReceivedEventArgs : EventArgs public class ConnectionDataReceivedEventArgs : EventArgs {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ConnectionDataReceivedEventArgs"/> class. /// Initializes a new instance of the <see cref="ConnectionDataReceivedEventArgs"/> class.
/// </summary> /// </summary>
/// <param name="buffer">The buffer.</param> /// <param name="buffer">The buffer.</param>
/// <param name="trigger">The trigger.</param> /// <param name="trigger">The trigger.</param>
/// <param name="moreAvailable">if set to <c>true</c> [more available].</param> /// <param name="moreAvailable">if set to <c>true</c> [more available].</param>
public ConnectionDataReceivedEventArgs(byte[] buffer, ConnectionDataReceivedTrigger trigger, bool moreAvailable) public ConnectionDataReceivedEventArgs(Byte[] buffer, ConnectionDataReceivedTrigger trigger, Boolean moreAvailable) {
{ this.Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer));
Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer)); this.Trigger = trigger;
Trigger = trigger; this.HasMoreAvailable = moreAvailable;
HasMoreAvailable = moreAvailable;
} }
/// <summary> /// <summary>
@ -52,7 +47,9 @@
/// <value> /// <value>
/// The buffer. /// The buffer.
/// </value> /// </value>
public byte[] Buffer { get; } public Byte[] Buffer {
get;
}
/// <summary> /// <summary>
/// Gets the cause as to why this event was thrown. /// Gets the cause as to why this event was thrown.
@ -60,7 +57,9 @@
/// <value> /// <value>
/// The trigger. /// The trigger.
/// </value> /// </value>
public ConnectionDataReceivedTrigger Trigger { get; } public ConnectionDataReceivedTrigger Trigger {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether the receive buffer has more bytes available. /// Gets a value indicating whether the receive buffer has more bytes available.
@ -68,7 +67,9 @@
/// <value> /// <value>
/// <c>true</c> if this instance has more available; otherwise, <c>false</c>. /// <c>true</c> if this instance has more available; otherwise, <c>false</c>.
/// </value> /// </value>
public bool HasMoreAvailable { get; } public Boolean HasMoreAvailable {
get;
}
/// <summary> /// <summary>
/// Gets the string from buffer. /// Gets the string from buffer.
@ -78,7 +79,6 @@
/// A <see cref="System.String" /> that contains the results of decoding the specified sequence of bytes. /// A <see cref="System.String" /> that contains the results of decoding the specified sequence of bytes.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">encoding</exception> /// <exception cref="ArgumentNullException">encoding</exception>
public string GetStringFromBuffer(Encoding encoding) public String GetStringFromBuffer(Encoding encoding) => encoding?.GetString(this.Buffer).TrimEnd('\r', '\n') ?? throw new ArgumentNullException(nameof(encoding));
=> encoding?.GetString(Buffer).TrimEnd('\r', '\n') ?? throw new ArgumentNullException(nameof(encoding));
} }
} }

View File

@ -1,6 +1,5 @@
namespace Swan.Net #nullable enable
{ using Swan.Formatters;
using Formatters;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Net.Http; using System.Net.Http;
@ -10,14 +9,14 @@
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Net {
/// <summary> /// <summary>
/// Represents a HttpClient with extended methods to use with JSON payloads /// Represents a HttpClient with extended methods to use with JSON payloads
/// and bearer tokens authentication. /// and bearer tokens authentication.
/// </summary> /// </summary>
public static class JsonClient public static class JsonClient {
{ private const String JsonMimeType = "application/json";
private const string JsonMimeType = "application/json"; private const String FormType = "application/x-www-form-urlencoded";
private const string FormType = "application/x-www-form-urlencoded";
private static readonly HttpClient HttpClient = new HttpClient(); private static readonly HttpClient HttpClient = new HttpClient();
@ -32,16 +31,10 @@
/// <returns> /// <returns>
/// A task with a result of the requested type. /// A task with a result of the requested type.
/// </returns> /// </returns>
public static async Task<T> Post<T>( public static async Task<T> Post<T>(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) where T : notnull {
Uri requestUri, String jsonString = await PostString(requestUri, payload, authorization, cancellationToken).ConfigureAwait(false);
object payload,
string? authorization = null,
CancellationToken cancellationToken = default)
{
var jsonString = await PostString(requestUri, payload, authorization, cancellationToken)
.ConfigureAwait(false);
return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default; return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default;
} }
/// <summary> /// <summary>
@ -54,18 +47,10 @@
/// <returns> /// <returns>
/// A task with a result as a collection of key/value pairs. /// A task with a result as a collection of key/value pairs.
/// </returns> /// </returns>
public static async Task<IDictionary<string, object>?> Post( public static async Task<IDictionary<String, Object>?> Post(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) {
Uri requestUri, String jsonString = await PostString(requestUri, payload, authorization, cancellationToken).ConfigureAwait(false);
object payload,
string? authorization = null,
CancellationToken cancellationToken = default)
{
var jsonString = await PostString(requestUri, payload, authorization, cancellationToken)
.ConfigureAwait(false);
return string.IsNullOrWhiteSpace(jsonString) return String.IsNullOrWhiteSpace(jsonString) ? default : Json.Deserialize(jsonString) as IDictionary<String, Object>;
? default
: Json.Deserialize(jsonString) as IDictionary<string, object>;
} }
/// <summary> /// <summary>
@ -80,12 +65,7 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">url.</exception> /// <exception cref="ArgumentNullException">url.</exception>
/// <exception cref="JsonRequestException">Error POST JSON.</exception> /// <exception cref="JsonRequestException">Error POST JSON.</exception>
public static Task<string> PostString( public static Task<String> PostString(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) => SendAsync(HttpMethod.Post, requestUri, payload, authorization, cancellationToken);
Uri requestUri,
object payload,
string? authorization = null,
CancellationToken cancellationToken = default)
=> SendAsync(HttpMethod.Post, requestUri, payload, authorization, cancellationToken);
/// <summary> /// <summary>
/// Puts the specified URL. /// Puts the specified URL.
@ -98,16 +78,10 @@
/// <returns> /// <returns>
/// A task with a result of the requested type. /// A task with a result of the requested type.
/// </returns> /// </returns>
public static async Task<T> Put<T>( public static async Task<T> Put<T>(Uri requestUri, Object payload, String? authorization = null, CancellationToken ct = default) where T : notnull {
Uri requestUri, String jsonString = await PutString(requestUri, payload, authorization, ct).ConfigureAwait(false);
object payload,
string? authorization = null,
CancellationToken ct = default)
{
var jsonString = await PutString(requestUri, payload, authorization, ct)
.ConfigureAwait(false);
return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default; return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default;
} }
/// <summary> /// <summary>
@ -120,16 +94,10 @@
/// <returns> /// <returns>
/// A task with a result of the requested collection of key/value pairs. /// A task with a result of the requested collection of key/value pairs.
/// </returns> /// </returns>
public static async Task<IDictionary<string, object>?> Put( public static async Task<IDictionary<String, Object>?> Put(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) {
Uri requestUri, Object response = await Put<Object>(requestUri, payload, authorization, cancellationToken).ConfigureAwait(false);
object payload,
string? authorization = null,
CancellationToken cancellationToken = default)
{
var response = await Put<object>(requestUri, payload, authorization, cancellationToken)
.ConfigureAwait(false);
return response as IDictionary<string, object>; return response as IDictionary<String, Object>;
} }
/// <summary> /// <summary>
@ -144,11 +112,7 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">url.</exception> /// <exception cref="ArgumentNullException">url.</exception>
/// <exception cref="JsonRequestException">Error PUT JSON.</exception> /// <exception cref="JsonRequestException">Error PUT JSON.</exception>
public static Task<string> PutString( public static Task<String> PutString(Uri requestUri, Object payload, String? authorization = null, CancellationToken ct = default) => SendAsync(HttpMethod.Put, requestUri, payload, authorization, ct);
Uri requestUri,
object payload,
string? authorization = null,
CancellationToken ct = default) => SendAsync(HttpMethod.Put, requestUri, payload, authorization, ct);
/// <summary> /// <summary>
/// Gets as string. /// Gets as string.
@ -161,11 +125,7 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">url.</exception> /// <exception cref="ArgumentNullException">url.</exception>
/// <exception cref="JsonRequestException">Error GET JSON.</exception> /// <exception cref="JsonRequestException">Error GET JSON.</exception>
public static Task<string> GetString( public static Task<String> GetString(Uri requestUri, String? authorization = null, CancellationToken ct = default) => GetString(requestUri, null, authorization, ct);
Uri requestUri,
string? authorization = null,
CancellationToken ct = default)
=> GetString(requestUri, null, authorization, ct);
/// <summary> /// <summary>
/// Gets the string. /// Gets the string.
@ -177,17 +137,10 @@
/// <returns> /// <returns>
/// A task with a result of the requested string. /// A task with a result of the requested string.
/// </returns> /// </returns>
public static async Task<string> GetString( public static async Task<String> GetString(Uri uri, IDictionary<String, IEnumerable<String>>? headers, String? authorization = null, CancellationToken ct = default) {
Uri uri, HttpContent response = await GetHttpContent(uri, ct, authorization, headers).ConfigureAwait(false);
IDictionary<string, IEnumerable<string>>? headers,
string? authorization = null,
CancellationToken ct = default)
{
var response = await GetHttpContent(uri, ct, authorization, headers)
.ConfigureAwait(false);
return await response.ReadAsStringAsync() return await response.ReadAsStringAsync().ConfigureAwait(false);
.ConfigureAwait(false);
} }
/// <summary> /// <summary>
@ -201,15 +154,10 @@
/// <returns> /// <returns>
/// A task with a result of the requested type. /// A task with a result of the requested type.
/// </returns> /// </returns>
public static async Task<T> Get<T>( public static async Task<T> Get<T>(Uri requestUri, String? authorization = null, CancellationToken ct = default) where T : notnull {
Uri requestUri, String jsonString = await GetString(requestUri, authorization, ct).ConfigureAwait(false);
string? authorization = null,
CancellationToken ct = default)
{
var jsonString = await GetString(requestUri, authorization, ct)
.ConfigureAwait(false);
return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default; return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default;
} }
/// <summary> /// <summary>
@ -224,16 +172,10 @@
/// <returns> /// <returns>
/// A task with a result of the requested type. /// A task with a result of the requested type.
/// </returns> /// </returns>
public static async Task<T> Get<T>( public static async Task<T> Get<T>(Uri requestUri, IDictionary<String, IEnumerable<String>>? headers, String? authorization = null, CancellationToken ct = default) where T : notnull {
Uri requestUri, String jsonString = await GetString(requestUri, headers, authorization, ct).ConfigureAwait(false);
IDictionary<string, IEnumerable<string>>? headers,
string? authorization = null,
CancellationToken ct = default)
{
var jsonString = await GetString(requestUri, headers, authorization, ct)
.ConfigureAwait(false);
return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default; return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize<T>(jsonString) : default;
} }
/// <summary> /// <summary>
@ -247,16 +189,10 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">url.</exception> /// <exception cref="ArgumentNullException">url.</exception>
/// <exception cref="JsonRequestException">Error GET Binary.</exception> /// <exception cref="JsonRequestException">Error GET Binary.</exception>
public static async Task<byte[]> GetBinary( public static async Task<Byte[]> GetBinary(Uri requestUri, String? authorization = null, CancellationToken ct = default) {
Uri requestUri, HttpContent response = await GetHttpContent(requestUri, ct, authorization).ConfigureAwait(false);
string? authorization = null,
CancellationToken ct = default)
{
var response = await GetHttpContent(requestUri, ct, authorization)
.ConfigureAwait(false);
return await response.ReadAsByteArrayAsync() return await response.ReadAsByteArrayAsync().ConfigureAwait(false);
.ConfigureAwait(false);
} }
/// <summary> /// <summary>
@ -273,26 +209,23 @@
/// or /// or
/// username.</exception> /// username.</exception>
/// <exception cref="SecurityException">Error Authenticating.</exception> /// <exception cref="SecurityException">Error Authenticating.</exception>
public static async Task<IDictionary<string, object>?> Authenticate( public static async Task<IDictionary<String, Object>?> Authenticate(Uri requestUri, String username, String password, CancellationToken ct = default) {
Uri requestUri, if(String.IsNullOrWhiteSpace(username)) {
string username,
string password,
CancellationToken ct = default)
{
if (string.IsNullOrWhiteSpace(username))
throw new ArgumentNullException(nameof(username)); throw new ArgumentNullException(nameof(username));
}
// ignore empty password for now // ignore empty password for now
var content = $"grant_type=password&username={username}&password={password}"; String content = $"grant_type=password&username={username}&password={password}";
using var requestContent = new StringContent(content, Encoding.UTF8, FormType); using StringContent requestContent = new StringContent(content, Encoding.UTF8, FormType);
var response = await HttpClient.PostAsync(requestUri, requestContent, ct).ConfigureAwait(false); HttpResponseMessage response = await HttpClient.PostAsync(requestUri, requestContent, ct).ConfigureAwait(false);
if (!response.IsSuccessStatusCode) if(!response.IsSuccessStatusCode) {
throw new SecurityException($"Error Authenticating. Status code: {response.StatusCode}."); throw new SecurityException($"Error Authenticating. Status code: {response.StatusCode}.");
}
var jsonPayload = await response.Content.ReadAsStringAsync().ConfigureAwait(false); String jsonPayload = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
return Json.Deserialize(jsonPayload) as IDictionary<string, object>; return Json.Deserialize(jsonPayload) as IDictionary<String, Object>;
} }
/// <summary> /// <summary>
@ -306,13 +239,7 @@
/// <returns> /// <returns>
/// A task with a result of the requested string. /// A task with a result of the requested string.
/// </returns> /// </returns>
public static Task<string> PostFileString( public static Task<String> PostFileString(Uri requestUri, Byte[] buffer, String fileName, String? authorization = null, CancellationToken ct = default) => PostString(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct);
Uri requestUri,
byte[] buffer,
string fileName,
string? authorization = null,
CancellationToken ct = default) =>
PostString(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct);
/// <summary> /// <summary>
/// Posts the file. /// Posts the file.
@ -326,13 +253,7 @@
/// <returns> /// <returns>
/// A task with a result of the requested string. /// A task with a result of the requested string.
/// </returns> /// </returns>
public static Task<T> PostFile<T>( public static Task<T> PostFile<T>(Uri requestUri, Byte[] buffer, String fileName, String? authorization = null, CancellationToken ct = default) where T : notnull => Post<T>(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct);
Uri requestUri,
byte[] buffer,
string fileName,
string? authorization = null,
CancellationToken ct = default) =>
Post<T>(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct);
/// <summary> /// <summary>
/// Sends the asynchronous request. /// Sends the asynchronous request.
@ -347,72 +268,46 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">requestUri.</exception> /// <exception cref="ArgumentNullException">requestUri.</exception>
/// <exception cref="JsonRequestException">Error {method} JSON.</exception> /// <exception cref="JsonRequestException">Error {method} JSON.</exception>
public static async Task<string> SendAsync( public static async Task<String> SendAsync(HttpMethod method, Uri requestUri, Object payload, String? authorization = null, CancellationToken ct = default) {
HttpMethod method, using HttpResponseMessage response = await GetResponse(requestUri, authorization, null, payload, method, ct).ConfigureAwait(false);
Uri requestUri, if(!response.IsSuccessStatusCode) {
object payload,
string? authorization = null,
CancellationToken ct = default)
{
using var response = await GetResponse(requestUri, authorization, null, payload, method, ct).ConfigureAwait(false);
if (!response.IsSuccessStatusCode)
{
throw new JsonRequestException( throw new JsonRequestException(
$"Error {method} JSON", $"Error {method} JSON",
(int)response.StatusCode, (Int32)response.StatusCode,
await response.Content.ReadAsStringAsync().ConfigureAwait(false)); await response.Content.ReadAsStringAsync().ConfigureAwait(false));
} }
return await response.Content.ReadAsStringAsync() return await response.Content.ReadAsStringAsync().ConfigureAwait(false);
.ConfigureAwait(false);
} }
private static async Task<HttpContent> GetHttpContent( private static async Task<HttpContent> GetHttpContent(Uri uri, CancellationToken ct, String? authorization = null, IDictionary<String, IEnumerable<String>>? headers = null) {
Uri uri, HttpResponseMessage response = await GetResponse(uri, authorization, headers, ct: ct).ConfigureAwait(false);
CancellationToken ct,
string? authorization = null,
IDictionary<string, IEnumerable<string>>? headers = null)
{
var response = await GetResponse(uri, authorization, headers, ct: ct)
.ConfigureAwait(false);
return response.IsSuccessStatusCode return response.IsSuccessStatusCode ? response.Content : throw new JsonRequestException("Error GET", (Int32)response.StatusCode);
? response.Content
: throw new JsonRequestException("Error GET", (int)response.StatusCode);
} }
private static async Task<HttpResponseMessage> GetResponse( private static async Task<HttpResponseMessage> GetResponse(Uri uri, String? authorization, IDictionary<String, IEnumerable<String>>? headers, Object? payload = null, HttpMethod? method = default, CancellationToken ct = default) {
Uri uri, if(uri == null) {
string? authorization,
IDictionary<string, IEnumerable<string>>? headers,
object? payload = null,
HttpMethod? method = default,
CancellationToken ct = default)
{
if (uri == null)
throw new ArgumentNullException(nameof(uri)); throw new ArgumentNullException(nameof(uri));
using var requestMessage = new HttpRequestMessage(method ?? HttpMethod.Get, uri);
if (!string.IsNullOrWhiteSpace(authorization))
{
requestMessage.Headers.Authorization
= new AuthenticationHeaderValue("Bearer", authorization);
} }
if (headers != null) using HttpRequestMessage requestMessage = new HttpRequestMessage(method ?? HttpMethod.Get, uri);
{
foreach (var header in headers) if(!String.IsNullOrWhiteSpace(authorization)) {
requestMessage.Headers.Authorization = new AuthenticationHeaderValue("Bearer", authorization);
}
if(headers != null) {
foreach(KeyValuePair<String, IEnumerable<String>> header in headers) {
requestMessage.Headers.Add(header.Key, header.Value); requestMessage.Headers.Add(header.Key, header.Value);
} }
}
if (payload != null && requestMessage.Method != HttpMethod.Get) if(payload != null && requestMessage.Method != HttpMethod.Get) {
{
requestMessage.Content = new StringContent(Json.Serialize(payload), Encoding.UTF8, JsonMimeType); requestMessage.Content = new StringContent(Json.Serialize(payload), Encoding.UTF8, JsonMimeType);
} }
return await HttpClient.SendAsync(requestMessage, ct) return await HttpClient.SendAsync(requestMessage, ct).ConfigureAwait(false);
.ConfigureAwait(false);
} }
} }
} }

View File

@ -1,26 +1,21 @@
namespace Swan.Net using System;
{
using System;
namespace Swan.Net {
/// <summary> /// <summary>
/// Represents errors that occurs requesting a JSON file through HTTP. /// Represents errors that occurs requesting a JSON file through HTTP.
/// </summary> /// </summary>
/// <seealso cref="System.Exception" /> /// <seealso cref="System.Exception" />
[Serializable] [Serializable]
public class JsonRequestException public class JsonRequestException : Exception {
: Exception
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="JsonRequestException"/> class. /// Initializes a new instance of the <see cref="JsonRequestException"/> class.
/// </summary> /// </summary>
/// <param name="message">The message.</param> /// <param name="message">The message.</param>
/// <param name="httpErrorCode">The HTTP error code.</param> /// <param name="httpErrorCode">The HTTP error code.</param>
/// <param name="errorContent">Content of the error.</param> /// <param name="errorContent">Content of the error.</param>
public JsonRequestException(string message, int httpErrorCode = 500, string errorContent = null) public JsonRequestException(String message, Int32 httpErrorCode = 500, String errorContent = null) : base(message) {
: base(message) this.HttpErrorCode = httpErrorCode;
{ this.HttpErrorContent = errorContent;
HttpErrorCode = httpErrorCode;
HttpErrorContent = errorContent;
} }
/// <summary> /// <summary>
@ -29,7 +24,9 @@
/// <value> /// <value>
/// The HTTP error code. /// The HTTP error code.
/// </value> /// </value>
public int HttpErrorCode { get; } public Int32 HttpErrorCode {
get;
}
/// <summary> /// <summary>
/// Gets the content of the HTTP error. /// Gets the content of the HTTP error.
@ -37,11 +34,11 @@
/// <value> /// <value>
/// The content of the HTTP error. /// The content of the HTTP error.
/// </value> /// </value>
public string HttpErrorContent { get; } public String HttpErrorContent {
get;
}
/// <inheritdoc /> /// <inheritdoc />
public override string ToString() => string.IsNullOrEmpty(HttpErrorContent) public override String ToString() => String.IsNullOrEmpty(this.HttpErrorContent) ? $"HTTP Response Status Code {this.HttpErrorCode} Error Message: {this.HttpErrorContent}" : base.ToString();
? $"HTTP Response Status Code {HttpErrorCode} Error Message: {HttpErrorContent}"
: base.ToString();
} }
} }

View File

@ -1,6 +1,4 @@
namespace Swan.Net using Swan.Net.Dns;
{
using Net.Dns;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
@ -11,21 +9,21 @@
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Net {
/// <summary> /// <summary>
/// Provides miscellaneous network utilities such as a Public IP finder, /// Provides miscellaneous network utilities such as a Public IP finder,
/// a DNS client to query DNS records of any kind, and an NTP client. /// a DNS client to query DNS records of any kind, and an NTP client.
/// </summary> /// </summary>
public static class Network public static class Network {
{
/// <summary> /// <summary>
/// The DNS default port. /// The DNS default port.
/// </summary> /// </summary>
public const int DnsDefaultPort = 53; public const Int32 DnsDefaultPort = 53;
/// <summary> /// <summary>
/// The NTP default port. /// The NTP default port.
/// </summary> /// </summary>
public const int NtpDefaultPort = 123; public const Int32 NtpDefaultPort = 123;
/// <summary> /// <summary>
/// Gets the name of the host. /// Gets the name of the host.
@ -33,7 +31,7 @@
/// <value> /// <value>
/// The name of the host. /// The name of the host.
/// </value> /// </value>
public static string HostName => IPGlobalProperties.GetIPGlobalProperties().HostName; public static String HostName => IPGlobalProperties.GetIPGlobalProperties().HostName;
/// <summary> /// <summary>
/// Gets the name of the network domain. /// Gets the name of the network domain.
@ -41,7 +39,7 @@
/// <value> /// <value>
/// The name of the network domain. /// The name of the network domain.
/// </value> /// </value>
public static string DomainName => IPGlobalProperties.GetIPGlobalProperties().DomainName; public static String DomainName => IPGlobalProperties.GetIPGlobalProperties().DomainName;
#region IP Addresses and Adapters Information Methods #region IP Addresses and Adapters Information Methods
@ -53,31 +51,19 @@
/// A collection of NetworkInterface/IPInterfaceProperties pairs /// A collection of NetworkInterface/IPInterfaceProperties pairs
/// that represents the active IPv4 interfaces. /// that represents the active IPv4 interfaces.
/// </returns> /// </returns>
public static Dictionary<NetworkInterface, IPInterfaceProperties> GetIPv4Interfaces() public static Dictionary<NetworkInterface, IPInterfaceProperties> GetIPv4Interfaces() {
{
// zero conf ip address // zero conf ip address
var zeroConf = new IPAddress(0); IPAddress zeroConf = new IPAddress(0);
var adapters = NetworkInterface.GetAllNetworkInterfaces() NetworkInterface[] adapters = NetworkInterface.GetAllNetworkInterfaces().Where(network => network.OperationalStatus == OperationalStatus.Up && network.NetworkInterfaceType != NetworkInterfaceType.Unknown && network.NetworkInterfaceType != NetworkInterfaceType.Loopback).ToArray();
.Where(network =>
network.OperationalStatus == OperationalStatus.Up
&& network.NetworkInterfaceType != NetworkInterfaceType.Unknown
&& network.NetworkInterfaceType != NetworkInterfaceType.Loopback)
.ToArray();
var result = new Dictionary<NetworkInterface, IPInterfaceProperties>(); Dictionary<NetworkInterface, IPInterfaceProperties> result = new Dictionary<NetworkInterface, IPInterfaceProperties>();
foreach (var adapter in adapters) foreach(NetworkInterface adapter in adapters) {
{ IPInterfaceProperties properties = adapter.GetIPProperties();
var properties = adapter.GetIPProperties(); if(properties == null || properties.GatewayAddresses.Count == 0 || properties.GatewayAddresses.All(gateway => Equals(gateway.Address, zeroConf)) || properties.UnicastAddresses.Count == 0 || properties.GatewayAddresses.All(address => Equals(address.Address, zeroConf)) || properties.UnicastAddresses.Any(a => a.Address.AddressFamily == AddressFamily.InterNetwork) == false) {
if (properties == null
|| properties.GatewayAddresses.Count == 0
|| properties.GatewayAddresses.All(gateway => Equals(gateway.Address, zeroConf))
|| properties.UnicastAddresses.Count == 0
|| properties.GatewayAddresses.All(address => Equals(address.Address, zeroConf))
|| properties.UnicastAddresses.Any(a => a.Address.AddressFamily == AddressFamily.InterNetwork) ==
false)
continue; continue;
}
result[adapter] = properties; result[adapter] = properties;
} }
@ -90,8 +76,7 @@
/// </summary> /// </summary>
/// <param name="includeLoopback">if set to <c>true</c> [include loopback].</param> /// <param name="includeLoopback">if set to <c>true</c> [include loopback].</param>
/// <returns>An array of local ip addresses.</returns> /// <returns>An array of local ip addresses.</returns>
public static IPAddress[] GetIPv4Addresses(bool includeLoopback = true) => public static IPAddress[] GetIPv4Addresses(Boolean includeLoopback = true) => GetIPv4Addresses(NetworkInterfaceType.Unknown, true, includeLoopback);
GetIPv4Addresses(NetworkInterfaceType.Unknown, true, includeLoopback);
/// <summary> /// <summary>
/// Retrieves the local ip addresses. /// Retrieves the local ip addresses.
@ -100,35 +85,24 @@
/// <param name="skipTypeFilter">if set to <c>true</c> [skip type filter].</param> /// <param name="skipTypeFilter">if set to <c>true</c> [skip type filter].</param>
/// <param name="includeLoopback">if set to <c>true</c> [include loopback].</param> /// <param name="includeLoopback">if set to <c>true</c> [include loopback].</param>
/// <returns>An array of local ip addresses.</returns> /// <returns>An array of local ip addresses.</returns>
public static IPAddress[] GetIPv4Addresses( public static IPAddress[] GetIPv4Addresses(NetworkInterfaceType interfaceType, Boolean skipTypeFilter = false, Boolean includeLoopback = false) {
NetworkInterfaceType interfaceType, List<IPAddress> addressList = new List<IPAddress>();
bool skipTypeFilter = false, NetworkInterface[] interfaces = NetworkInterface.GetAllNetworkInterfaces()
bool includeLoopback = false) .Where(ni => (skipTypeFilter || ni.NetworkInterfaceType == interfaceType) && ni.OperationalStatus == OperationalStatus.Up).ToArray();
{
var addressList = new List<IPAddress>();
var interfaces = NetworkInterface.GetAllNetworkInterfaces()
.Where(ni =>
#if NET461
ni.IsReceiveOnly == false &&
#endif
(skipTypeFilter || ni.NetworkInterfaceType == interfaceType) &&
ni.OperationalStatus == OperationalStatus.Up)
.ToArray();
foreach (var networkInterface in interfaces) foreach(NetworkInterface networkInterface in interfaces) {
{ IPInterfaceProperties properties = networkInterface.GetIPProperties();
var properties = networkInterface.GetIPProperties();
if (properties.GatewayAddresses.All(g => g.Address.AddressFamily != AddressFamily.InterNetwork)) if(properties.GatewayAddresses.All(g => g.Address.AddressFamily != AddressFamily.InterNetwork)) {
continue; continue;
addressList.AddRange(properties.UnicastAddresses
.Where(i => i.Address.AddressFamily == AddressFamily.InterNetwork)
.Select(i => i.Address));
} }
if (includeLoopback || interfaceType == NetworkInterfaceType.Loopback) addressList.AddRange(properties.UnicastAddresses.Where(i => i.Address.AddressFamily == AddressFamily.InterNetwork).Select(i => i.Address));
}
if(includeLoopback || interfaceType == NetworkInterfaceType.Loopback) {
addressList.Add(IPAddress.Loopback); addressList.Add(IPAddress.Loopback);
}
return addressList.ToArray(); return addressList.ToArray();
} }
@ -138,10 +112,9 @@
/// </summary> /// </summary>
/// <param name="cancellationToken">The cancellation token.</param> /// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A public IP address of the result produced by this Task.</returns> /// <returns>A public IP address of the result produced by this Task.</returns>
public static async Task<IPAddress> GetPublicIPAddressAsync(CancellationToken cancellationToken = default) public static async Task<IPAddress> GetPublicIPAddressAsync(CancellationToken cancellationToken = default) {
{ using HttpClient client = new HttpClient();
using var client = new HttpClient(); HttpResponseMessage response = await client.GetAsync("https://api.ipify.org", cancellationToken).ConfigureAwait(false);
var response = await client.GetAsync("https://api.ipify.org", cancellationToken).ConfigureAwait(false);
return IPAddress.Parse(await response.Content.ReadAsStringAsync().ConfigureAwait(false)); return IPAddress.Parse(await response.Content.ReadAsStringAsync().ConfigureAwait(false));
} }
@ -152,11 +125,7 @@
/// A collection of NetworkInterface/IPInterfaceProperties pairs /// A collection of NetworkInterface/IPInterfaceProperties pairs
/// that represents the active IPv4 interfaces. /// that represents the active IPv4 interfaces.
/// </returns> /// </returns>
public static IPAddress[] GetIPv4DnsServers() public static IPAddress[] GetIPv4DnsServers() => GetIPv4Interfaces().Select(a => a.Value.DnsAddresses.Where(d => d.AddressFamily == AddressFamily.InterNetwork)).SelectMany(d => d).ToArray();
=> GetIPv4Interfaces()
.Select(a => a.Value.DnsAddresses.Where(d => d.AddressFamily == AddressFamily.InterNetwork))
.SelectMany(d => d)
.ToArray();
#endregion #endregion
@ -167,9 +136,8 @@
/// </summary> /// </summary>
/// <param name="fqdn">The FQDN.</param> /// <param name="fqdn">The FQDN.</param>
/// <returns>An array of local ip addresses of the result produced by this task.</returns> /// <returns>An array of local ip addresses of the result produced by this task.</returns>
public static Task<IPAddress[]> GetDnsHostEntryAsync(string fqdn) public static Task<IPAddress[]> GetDnsHostEntryAsync(String fqdn) {
{ IPAddress dnsServer = GetIPv4DnsServers().FirstOrDefault() ?? IPAddress.Parse("8.8.8.8");
var dnsServer = GetIPv4DnsServers().FirstOrDefault() ?? IPAddress.Parse("8.8.8.8");
return GetDnsHostEntryAsync(fqdn, dnsServer, DnsDefaultPort); return GetDnsHostEntryAsync(fqdn, dnsServer, DnsDefaultPort);
} }
@ -183,25 +151,25 @@
/// An array of local ip addresses of the result produced by this task. /// An array of local ip addresses of the result produced by this task.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">fqdn.</exception> /// <exception cref="ArgumentNullException">fqdn.</exception>
public static async Task<IPAddress[]> GetDnsHostEntryAsync(string fqdn, IPAddress dnsServer, int port) public static async Task<IPAddress[]> GetDnsHostEntryAsync(String fqdn, IPAddress dnsServer, Int32 port) {
{ if(fqdn == null) {
if (fqdn == null)
throw new ArgumentNullException(nameof(fqdn)); throw new ArgumentNullException(nameof(fqdn));
}
if (fqdn.IndexOf(".", StringComparison.Ordinal) == -1) if(fqdn.IndexOf(".", StringComparison.Ordinal) == -1) {
{
fqdn += "." + IPGlobalProperties.GetIPGlobalProperties().DomainName; fqdn += "." + IPGlobalProperties.GetIPGlobalProperties().DomainName;
} }
while (true) while(true) {
{ if(!fqdn.EndsWith(".", StringComparison.OrdinalIgnoreCase)) {
if (!fqdn.EndsWith(".", StringComparison.OrdinalIgnoreCase)) break; break;
fqdn = fqdn.Substring(0, fqdn.Length - 1);
} }
var client = new DnsClient(dnsServer, port); fqdn = fqdn[0..^1];
var result = await client.Lookup(fqdn).ConfigureAwait(false); }
DnsClient client = new DnsClient(dnsServer, port);
IList<IPAddress> result = await client.Lookup(fqdn).ConfigureAwait(false);
return result.ToArray(); return result.ToArray();
} }
@ -212,9 +180,8 @@
/// <param name="dnsServer">The DNS server.</param> /// <param name="dnsServer">The DNS server.</param>
/// <param name="port">The port.</param> /// <param name="port">The port.</param>
/// <returns>A <see cref="System.String" /> that represents the current object.</returns> /// <returns>A <see cref="System.String" /> that represents the current object.</returns>
public static Task<string> GetDnsPointerEntryAsync(IPAddress query, IPAddress dnsServer, int port) public static Task<String> GetDnsPointerEntryAsync(IPAddress query, IPAddress dnsServer, Int32 port) {
{ DnsClient client = new DnsClient(dnsServer, port);
var client = new DnsClient(dnsServer, port);
return client.Reverse(query); return client.Reverse(query);
} }
@ -223,9 +190,8 @@
/// </summary> /// </summary>
/// <param name="query">The query.</param> /// <param name="query">The query.</param>
/// <returns>A <see cref="System.String" /> that represents the current object.</returns> /// <returns>A <see cref="System.String" /> that represents the current object.</returns>
public static Task<string> GetDnsPointerEntryAsync(IPAddress query) public static Task<String> GetDnsPointerEntryAsync(IPAddress query) {
{ DnsClient client = new DnsClient(GetIPv4DnsServers().FirstOrDefault());
var client = new DnsClient(GetIPv4DnsServers().FirstOrDefault());
return client.Reverse(query); return client.Reverse(query);
} }
@ -237,13 +203,13 @@
/// <param name="dnsServer">The DNS server.</param> /// <param name="dnsServer">The DNS server.</param>
/// <param name="port">The port.</param> /// <param name="port">The port.</param>
/// <returns>Queries the DNS server for the specified record type of the result produced by this Task.</returns> /// <returns>Queries the DNS server for the specified record type of the result produced by this Task.</returns>
public static async Task<DnsQueryResult> QueryDnsAsync(string query, DnsRecordType recordType, IPAddress dnsServer, int port) public static async Task<DnsQueryResult> QueryDnsAsync(String query, DnsRecordType recordType, IPAddress dnsServer, Int32 port) {
{ if(query == null) {
if (query == null)
throw new ArgumentNullException(nameof(query)); throw new ArgumentNullException(nameof(query));
}
var client = new DnsClient(dnsServer, port); DnsClient client = new DnsClient(dnsServer, port);
var response = await client.Resolve(query, recordType).ConfigureAwait(false); DnsClient.DnsClientResponse response = await client.Resolve(query, recordType).ConfigureAwait(false);
return new DnsQueryResult(response); return new DnsQueryResult(response);
} }
@ -253,7 +219,7 @@
/// <param name="query">The query.</param> /// <param name="query">The query.</param>
/// <param name="recordType">Type of the record.</param> /// <param name="recordType">Type of the record.</param>
/// <returns>Queries the DNS server for the specified record type of the result produced by this Task.</returns> /// <returns>Queries the DNS server for the specified record type of the result produced by this Task.</returns>
public static Task<DnsQueryResult> QueryDnsAsync(string query, DnsRecordType recordType) => QueryDnsAsync(query, recordType, GetIPv4DnsServers().FirstOrDefault(), DnsDefaultPort); public static Task<DnsQueryResult> QueryDnsAsync(String query, DnsRecordType recordType) => QueryDnsAsync(query, recordType, GetIPv4DnsServers().FirstOrDefault(), DnsDefaultPort);
/// <summary> /// <summary>
/// Gets the UTC time by querying from an NTP server. /// Gets the UTC time by querying from an NTP server.
@ -261,53 +227,50 @@
/// <param name="ntpServerAddress">The NTP server address.</param> /// <param name="ntpServerAddress">The NTP server address.</param>
/// <param name="port">The port.</param> /// <param name="port">The port.</param>
/// <returns>The UTC time by querying from an NTP server of the result produced by this Task.</returns> /// <returns>The UTC time by querying from an NTP server of the result produced by this Task.</returns>
public static async Task<DateTime> GetNetworkTimeUtcAsync(IPAddress ntpServerAddress, int port = NtpDefaultPort) public static async Task<DateTime> GetNetworkTimeUtcAsync(IPAddress ntpServerAddress, Int32 port = NtpDefaultPort) {
{ if(ntpServerAddress == null) {
if (ntpServerAddress == null)
throw new ArgumentNullException(nameof(ntpServerAddress)); throw new ArgumentNullException(nameof(ntpServerAddress));
}
// NTP message size - 16 bytes of the digest (RFC 2030) // NTP message size - 16 bytes of the digest (RFC 2030)
var ntpData = new byte[48]; Byte[] ntpData = new Byte[48];
// Setting the Leap Indicator, Version Number and Mode values // Setting the Leap Indicator, Version Number and Mode values
ntpData[0] = 0x1B; // LI = 0 (no warning), VN = 3 (IPv4 only), Mode = 3 (Client Mode) ntpData[0] = 0x1B; // LI = 0 (no warning), VN = 3 (IPv4 only), Mode = 3 (Client Mode)
// The UDP port number assigned to NTP is 123 // The UDP port number assigned to NTP is 123
var endPoint = new IPEndPoint(ntpServerAddress, port); IPEndPoint endPoint = new IPEndPoint(ntpServerAddress, port);
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
#if !NET461
await socket.ConnectAsync(endPoint).ConfigureAwait(false); await socket.ConnectAsync(endPoint).ConfigureAwait(false);
#else
socket.Connect(endPoint);
#endif
socket.ReceiveTimeout = 3000; // Stops code hang if NTP is blocked socket.ReceiveTimeout = 3000; // Stops code hang if NTP is blocked
socket.Send(ntpData); _ = socket.Send(ntpData);
socket.Receive(ntpData); _ = socket.Receive(ntpData);
socket.Dispose(); socket.Dispose();
// Offset to get to the "Transmit Timestamp" field (time at which the reply // Offset to get to the "Transmit Timestamp" field (time at which the reply
// departed the server for the client, in 64-bit timestamp format." // departed the server for the client, in 64-bit timestamp format."
const byte serverReplyTime = 40; const Byte serverReplyTime = 40;
// Get the seconds part // Get the seconds part
ulong intPart = BitConverter.ToUInt32(ntpData, serverReplyTime); UInt64 intPart = BitConverter.ToUInt32(ntpData, serverReplyTime);
// Get the seconds fraction // Get the seconds fraction
ulong fractPart = BitConverter.ToUInt32(ntpData, serverReplyTime + 4); UInt64 fractPart = BitConverter.ToUInt32(ntpData, serverReplyTime + 4);
// Convert From big-endian to little-endian to match the platform // Convert From big-endian to little-endian to match the platform
if (BitConverter.IsLittleEndian) if(BitConverter.IsLittleEndian) {
{
intPart = intPart.SwapEndianness(); intPart = intPart.SwapEndianness();
fractPart = intPart.SwapEndianness(); fractPart = intPart.SwapEndianness();
} }
var milliseconds = (intPart * 1000) + ((fractPart * 1000) / 0x100000000L); UInt64 milliseconds = intPart * 1000 + fractPart * 1000 / 0x100000000L;
// The time is given in UTC // The time is given in UTC
return new DateTime(1900, 1, 1, 0, 0, 0, DateTimeKind.Utc).AddMilliseconds((long) milliseconds); return new DateTime(1900, 1, 1, 0, 0, 0, DateTimeKind.Utc).AddMilliseconds((Int64)milliseconds);
} }
/// <summary> /// <summary>
@ -316,10 +279,8 @@
/// <param name="ntpServerName">The NTP server, by default pool.ntp.org.</param> /// <param name="ntpServerName">The NTP server, by default pool.ntp.org.</param>
/// <param name="port">The port, by default NTP 123.</param> /// <param name="port">The port, by default NTP 123.</param>
/// <returns>The UTC time by querying from an NTP server of the result produced by this Task.</returns> /// <returns>The UTC time by querying from an NTP server of the result produced by this Task.</returns>
public static async Task<DateTime> GetNetworkTimeUtcAsync(string ntpServerName = "pool.ntp.org", public static async Task<DateTime> GetNetworkTimeUtcAsync(String ntpServerName = "pool.ntp.org", Int32 port = NtpDefaultPort) {
int port = NtpDefaultPort) IPAddress[] addresses = await GetDnsHostEntryAsync(ntpServerName).ConfigureAwait(false);
{
var addresses = await GetDnsHostEntryAsync(ntpServerName).ConfigureAwait(false);
return await GetNetworkTimeUtcAsync(addresses.First(), port).ConfigureAwait(false); return await GetNetworkTimeUtcAsync(addresses.First(), port).ConfigureAwait(false);
} }

View File

@ -1,11 +1,9 @@
// ReSharper disable InconsistentNaming // ReSharper disable InconsistentNaming
namespace Swan.Net.Smtp namespace Swan.Net.Smtp {
{
/// <summary> /// <summary>
/// Enumerates all of the well-known SMTP command names. /// Enumerates all of the well-known SMTP command names.
/// </summary> /// </summary>
public enum SmtpCommandNames public enum SmtpCommandNames {
{
/// <summary> /// <summary>
/// An unknown command /// An unknown command
/// </summary> /// </summary>
@ -95,8 +93,7 @@ namespace Swan.Net.Smtp
/// <summary> /// <summary>
/// Enumerates the reply code severities. /// Enumerates the reply code severities.
/// </summary> /// </summary>
public enum SmtpReplyCodeSeverities public enum SmtpReplyCodeSeverities {
{
/// <summary> /// <summary>
/// The unknown severity /// The unknown severity
/// </summary> /// </summary>
@ -126,8 +123,7 @@ namespace Swan.Net.Smtp
/// <summary> /// <summary>
/// Enumerates the reply code categories. /// Enumerates the reply code categories.
/// </summary> /// </summary>
public enum SmtpReplyCodeCategories public enum SmtpReplyCodeCategories {
{
/// <summary> /// <summary>
/// The unknown category /// The unknown category
/// </summary> /// </summary>

View File

@ -1,5 +1,4 @@
namespace Swan.Net.Smtp #nullable enable
{
using System.Threading; using System.Threading;
using System; using System;
using System.Linq; using System.Linq;
@ -12,6 +11,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Net.Mail; using System.Net.Mail;
namespace Swan.Net.Smtp {
/// <summary> /// <summary>
/// Represents a basic SMTP client that is capable of submitting messages to an SMTP server. /// Represents a basic SMTP client that is capable of submitting messages to an SMTP server.
/// </summary> /// </summary>
@ -92,19 +92,17 @@
/// } /// }
/// </code> /// </code>
/// </example> /// </example>
public class SmtpClient public class SmtpClient {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="SmtpClient" /> class. /// Initializes a new instance of the <see cref="SmtpClient" /> class.
/// </summary> /// </summary>
/// <param name="host">The host.</param> /// <param name="host">The host.</param>
/// <param name="port">The port.</param> /// <param name="port">The port.</param>
/// <exception cref="ArgumentNullException">host.</exception> /// <exception cref="ArgumentNullException">host.</exception>
public SmtpClient(string host, int port) public SmtpClient(String host, Int32 port) {
{ this.Host = host ?? throw new ArgumentNullException(nameof(host));
Host = host ?? throw new ArgumentNullException(nameof(host)); this.Port = port;
Port = port; this.ClientHostname = Network.HostName;
ClientHostname = Network.HostName;
} }
/// <summary> /// <summary>
@ -113,7 +111,9 @@
/// <value> /// <value>
/// The credentials. /// The credentials.
/// </value> /// </value>
public NetworkCredential Credentials { get; set; } public NetworkCredential? Credentials {
get; set;
}
/// <summary> /// <summary>
/// Gets the host. /// Gets the host.
@ -121,7 +121,9 @@
/// <value> /// <value>
/// The host. /// The host.
/// </value> /// </value>
public string Host { get; } public String Host {
get;
}
/// <summary> /// <summary>
/// Gets the port. /// Gets the port.
@ -129,7 +131,9 @@
/// <value> /// <value>
/// The port. /// The port.
/// </value> /// </value>
public int Port { get; } public Int32 Port {
get;
}
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the SSL is enabled. /// Gets or sets a value indicating whether the SSL is enabled.
@ -138,7 +142,9 @@
/// <value> /// <value>
/// <c>true</c> if [enable SSL]; otherwise, <c>false</c>. /// <c>true</c> if [enable SSL]; otherwise, <c>false</c>.
/// </value> /// </value>
public bool EnableSsl { get; set; } public Boolean EnableSsl {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the name of the client that gets announced to the server. /// Gets or sets the name of the client that gets announced to the server.
@ -146,7 +152,10 @@
/// <value> /// <value>
/// The client hostname. /// The client hostname.
/// </value> /// </value>
public string ClientHostname { get; set; } public String ClientHostname {
get; set;
}
/// <summary> /// <summary>
/// Sends an email message asynchronously. /// Sends an email message asynchronously.
@ -159,37 +168,31 @@
/// A task that represents the asynchronous of send email operation. /// A task that represents the asynchronous of send email operation.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">message.</exception> /// <exception cref="ArgumentNullException">message.</exception>
public Task SendMailAsync( [System.Diagnostics.CodeAnalysis.SuppressMessage("Codequalität", "IDE0067:Objekte verwerfen, bevor Bereich verloren geht", Justification = "<Ausstehend>")]
MailMessage message, public Task SendMailAsync(MailMessage message, String? sessionId = null, RemoteCertificateValidationCallback? callback = null, CancellationToken cancellationToken = default) {
string? sessionId = null, if(message == null) {
RemoteCertificateValidationCallback? callback = null,
CancellationToken cancellationToken = default)
{
if (message == null)
throw new ArgumentNullException(nameof(message)); throw new ArgumentNullException(nameof(message));
}
var state = new SmtpSessionState SmtpSessionState state = new SmtpSessionState {
{ AuthMode = this.Credentials == null ? String.Empty : SmtpDefinitions.SmtpAuthMethods.Login,
AuthMode = Credentials == null ? string.Empty : SmtpDefinitions.SmtpAuthMethods.Login,
ClientHostname = ClientHostname, ClientHostname = ClientHostname,
IsChannelSecure = EnableSsl, IsChannelSecure = EnableSsl,
SenderAddress = message.From.Address, SenderAddress = message.From.Address,
}; };
if (Credentials != null) if(this.Credentials != null) {
{ state.Username = this.Credentials.UserName;
state.Username = Credentials.UserName; state.Password = this.Credentials.Password;
state.Password = Credentials.Password;
} }
foreach (var recipient in message.To) foreach(MailAddress recipient in message.To) {
{
state.Recipients.Add(recipient.Address); state.Recipients.Add(recipient.Address);
} }
state.DataBuffer.AddRange(message.ToMimeMessage().ToArray()); state.DataBuffer.AddRange(message.ToMimeMessage().ToArray());
return SendMailAsync(state, sessionId, callback, cancellationToken); return this.SendMailAsync(state, sessionId, callback, cancellationToken);
} }
/// <summary> /// <summary>
@ -205,16 +208,12 @@
/// A task that represents the asynchronous of send email operation. /// A task that represents the asynchronous of send email operation.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">sessionState.</exception> /// <exception cref="ArgumentNullException">sessionState.</exception>
public Task SendMailAsync( public Task SendMailAsync(SmtpSessionState sessionState, String? sessionId = null, RemoteCertificateValidationCallback? callback = null, CancellationToken cancellationToken = default) {
SmtpSessionState sessionState, if(sessionState == null) {
string? sessionId = null,
RemoteCertificateValidationCallback? callback = null,
CancellationToken cancellationToken = default)
{
if (sessionState == null)
throw new ArgumentNullException(nameof(sessionState)); throw new ArgumentNullException(nameof(sessionState));
}
return SendMailAsync(new[] { sessionState }, sessionId, callback, cancellationToken); return this.SendMailAsync(new[] { sessionState }, sessionId, callback, cancellationToken);
} }
/// <summary> /// <summary>
@ -232,54 +231,47 @@
/// <exception cref="ArgumentNullException">sessionStates.</exception> /// <exception cref="ArgumentNullException">sessionStates.</exception>
/// <exception cref="SecurityException">Could not upgrade the channel to SSL.</exception> /// <exception cref="SecurityException">Could not upgrade the channel to SSL.</exception>
/// <exception cref="SmtpException">Defines an SMTP Exceptions class.</exception> /// <exception cref="SmtpException">Defines an SMTP Exceptions class.</exception>
public async Task SendMailAsync( public async Task SendMailAsync(IEnumerable<SmtpSessionState> sessionStates, String? sessionId = null, RemoteCertificateValidationCallback? callback = null, CancellationToken cancellationToken = default) {
IEnumerable<SmtpSessionState> sessionStates, if(sessionStates == null) {
string? sessionId = null,
RemoteCertificateValidationCallback? callback = null,
CancellationToken cancellationToken = default)
{
if (sessionStates == null)
throw new ArgumentNullException(nameof(sessionStates)); throw new ArgumentNullException(nameof(sessionStates));
}
using var tcpClient = new TcpClient(); using TcpClient tcpClient = new TcpClient();
await tcpClient.ConnectAsync(Host, Port).ConfigureAwait(false); await tcpClient.ConnectAsync(this.Host, this.Port).ConfigureAwait(false);
using var connection = new Connection(tcpClient, Encoding.UTF8, "\r\n", true, 1000); using Connection connection = new Connection(tcpClient, Encoding.UTF8, "\r\n", true, 1000);
var sender = new SmtpSender(sessionId); SmtpSender sender = new SmtpSender(sessionId);
try try {
{
// Read the greeting message // Read the greeting message
sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false);
// EHLO 1 // EHLO 1
await SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false); await this.SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false);
// STARTTLS // STARTTLS
if (EnableSsl) if(this.EnableSsl) {
{
sender.RequestText = $"{SmtpCommandNames.STARTTLS}"; sender.RequestText = $"{SmtpCommandNames.STARTTLS}";
await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false);
sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false);
sender.ValidateReply(); sender.ValidateReply();
if (await connection.UpgradeToSecureAsClientAsync(callback: callback).ConfigureAwait(false) == false) if(await connection.UpgradeToSecureAsClientAsync(callback: callback).ConfigureAwait(false) == false) {
throw new SecurityException("Could not upgrade the channel to SSL."); throw new SecurityException("Could not upgrade the channel to SSL.");
} }
}
// EHLO 2 // EHLO 2
await SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false); await this.SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false);
// AUTH // AUTH
if (Credentials != null) if(this.Credentials != null) {
{ ConnectionAuth auth = new ConnectionAuth(connection, sender, this.Credentials);
var auth = new ConnectionAuth(connection, sender, Credentials);
await auth.AuthenticateAsync(cancellationToken).ConfigureAwait(false); await auth.AuthenticateAsync(cancellationToken).ConfigureAwait(false);
} }
foreach (var sessionState in sessionStates) foreach(SmtpSessionState sessionState in sessionStates) {
{
{ {
// MAIL FROM // MAIL FROM
sender.RequestText = $"{SmtpCommandNames.MAIL} FROM:<{sessionState.SenderAddress}>"; sender.RequestText = $"{SmtpCommandNames.MAIL} FROM:<{sessionState.SenderAddress}>";
@ -290,8 +282,7 @@
} }
// RCPT TO // RCPT TO
foreach (var recipient in sessionState.Recipients) foreach(String recipient in sessionState.Recipients) {
{
sender.RequestText = $"{SmtpCommandNames.RCPT} TO:<{recipient}>"; sender.RequestText = $"{SmtpCommandNames.RCPT} TO:<{recipient}>";
await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false);
@ -310,16 +301,15 @@
{ {
// CONTENT // CONTENT
var dataTerminator = sessionState.DataBuffer String dataTerminator = sessionState.DataBuffer.Skip(sessionState.DataBuffer.Count - 5).ToText();
.Skip(sessionState.DataBuffer.Count - 5)
.ToText();
sender.RequestText = $"Buffer ({sessionState.DataBuffer.Count} bytes)"; sender.RequestText = $"Buffer ({sessionState.DataBuffer.Count} bytes)";
await connection.WriteDataAsync(sessionState.DataBuffer.ToArray(), true, cancellationToken).ConfigureAwait(false); await connection.WriteDataAsync(sessionState.DataBuffer.ToArray(), true, cancellationToken).ConfigureAwait(false);
if (!dataTerminator.EndsWith(SmtpDefinitions.SmtpDataCommandTerminator)) if(!dataTerminator.EndsWith(SmtpDefinitions.SmtpDataCommandTerminator)) {
await connection.WriteTextAsync(SmtpDefinitions.SmtpDataCommandTerminator, cancellationToken).ConfigureAwait(false); await connection.WriteTextAsync(SmtpDefinitions.SmtpDataCommandTerminator, cancellationToken).ConfigureAwait(false);
}
sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false);
sender.ValidateReply(); sender.ValidateReply();
@ -334,21 +324,17 @@
sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false);
sender.ValidateReply(); sender.ValidateReply();
} }
} } catch(Exception ex) {
catch (Exception ex)
{
throw new SmtpException($"Could not send email - Session ID {sessionId}. {ex.Message}\r\n Last Request: {sender.RequestText}\r\n Last Reply: {sender.ReplyText}"); throw new SmtpException($"Could not send email - Session ID {sessionId}. {ex.Message}\r\n Last Request: {sender.RequestText}\r\n Last Reply: {sender.ReplyText}");
} }
} }
private async Task SendEhlo(SmtpSender sender, Connection connection, CancellationToken cancellationToken) private async Task SendEhlo(SmtpSender sender, Connection connection, CancellationToken cancellationToken) {
{ sender.RequestText = $"{SmtpCommandNames.EHLO} {this.ClientHostname}";
sender.RequestText = $"{SmtpCommandNames.EHLO} {ClientHostname}";
await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false);
do do {
{
sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false);
} }
while(!sender.IsReplyOk); while(!sender.IsReplyOk);
@ -356,32 +342,28 @@
sender.ValidateReply(); sender.ValidateReply();
} }
private class ConnectionAuth private class ConnectionAuth {
{
private readonly SmtpSender _sender; private readonly SmtpSender _sender;
private readonly Connection _connection; private readonly Connection _connection;
private readonly NetworkCredential _credentials; private readonly NetworkCredential _credentials;
public ConnectionAuth(Connection connection, SmtpSender sender, NetworkCredential credentials) public ConnectionAuth(Connection connection, SmtpSender sender, NetworkCredential credentials) {
{ this._connection = connection;
_connection = connection; this._sender = sender;
_sender = sender; this._credentials = credentials;
_credentials = credentials;
} }
public async Task AuthenticateAsync(CancellationToken ct) public async Task AuthenticateAsync(CancellationToken ct) {
{ this._sender.RequestText = $"{SmtpCommandNames.AUTH} {SmtpDefinitions.SmtpAuthMethods.Login} {Convert.ToBase64String(Encoding.UTF8.GetBytes(this._credentials.UserName))}";
_sender.RequestText =
$"{SmtpCommandNames.AUTH} {SmtpDefinitions.SmtpAuthMethods.Login} {Convert.ToBase64String(Encoding.UTF8.GetBytes(_credentials.UserName))}";
await _connection.WriteLineAsync(_sender.RequestText, ct).ConfigureAwait(false); await this._connection.WriteLineAsync(this._sender.RequestText, ct).ConfigureAwait(false);
_sender.ReplyText = await _connection.ReadLineAsync(ct).ConfigureAwait(false); this._sender.ReplyText = await this._connection.ReadLineAsync(ct).ConfigureAwait(false);
_sender.ValidateReply(); this._sender.ValidateReply();
_sender.RequestText = Convert.ToBase64String(Encoding.UTF8.GetBytes(_credentials.Password)); this._sender.RequestText = Convert.ToBase64String(Encoding.UTF8.GetBytes(this._credentials.Password));
await _connection.WriteLineAsync(_sender.RequestText, ct).ConfigureAwait(false); await this._connection.WriteLineAsync(this._sender.RequestText, ct).ConfigureAwait(false);
_sender.ReplyText = await _connection.ReadLineAsync(ct).ConfigureAwait(false); this._sender.ReplyText = await this._connection.ReadLineAsync(ct).ConfigureAwait(false);
_sender.ValidateReply(); this._sender.ValidateReply();
} }
} }
} }

View File

@ -1,29 +1,28 @@
namespace Swan.Net.Smtp using System;
{
namespace Swan.Net.Smtp {
/// <summary> /// <summary>
/// Contains useful constants and definitions. /// Contains useful constants and definitions.
/// </summary> /// </summary>
public static class SmtpDefinitions public static class SmtpDefinitions {
{
/// <summary> /// <summary>
/// The string sequence that delimits the end of the DATA command. /// The string sequence that delimits the end of the DATA command.
/// </summary> /// </summary>
public const string SmtpDataCommandTerminator = "\r\n.\r\n"; public const String SmtpDataCommandTerminator = "\r\n.\r\n";
/// <summary> /// <summary>
/// Lists the AUTH methods supported by default. /// Lists the AUTH methods supported by default.
/// </summary> /// </summary>
public static class SmtpAuthMethods public static class SmtpAuthMethods {
{
/// <summary> /// <summary>
/// The plain method. /// The plain method.
/// </summary> /// </summary>
public const string Plain = "PLAIN"; public const String Plain = "PLAIN";
/// <summary> /// <summary>
/// The login method. /// The login method.
/// </summary> /// </summary>
public const string Login = "LOGIN"; public const String Login = "LOGIN";
} }
} }
} }

View File

@ -1,59 +1,52 @@
namespace Swan.Net.Smtp using Swan.Logging;
{
using Logging;
using System; using System;
using System.Linq; using System.Linq;
using System.Net.Mail; using System.Net.Mail;
namespace Swan.Net.Smtp {
/// <summary> /// <summary>
/// Use this class to store the sender session data. /// Use this class to store the sender session data.
/// </summary> /// </summary>
internal class SmtpSender internal class SmtpSender {
{ private readonly String _sessionId;
private readonly string _sessionId; private String _requestText;
private string _requestText;
public SmtpSender(string sessionId) public SmtpSender(String sessionId) => this._sessionId = sessionId;
{
_sessionId = sessionId;
}
public string RequestText public String RequestText {
{ get => this._requestText;
get => _requestText; set {
set this._requestText = value;
{ $" TX {this._requestText}".Trace(typeof(SmtpClient), this._sessionId);
_requestText = value;
$" TX {_requestText}".Trace(typeof(SmtpClient), _sessionId);
} }
} }
public string ReplyText { get; set; } public String ReplyText {
get; set;
}
public bool IsReplyOk => ReplyText.StartsWith("250 ", StringComparison.OrdinalIgnoreCase); public Boolean IsReplyOk => this.ReplyText.StartsWith("250 ", StringComparison.OrdinalIgnoreCase);
public void ValidateReply() public void ValidateReply() {
{ if(this.ReplyText == null) {
if (ReplyText == null)
throw new SmtpException("There was no response from the server"); throw new SmtpException("There was no response from the server");
}
try try {
{ SmtpServerReply response = SmtpServerReply.Parse(this.ReplyText);
var response = SmtpServerReply.Parse(ReplyText); $" RX {this.ReplyText} - {response.IsPositive}".Trace(typeof(SmtpClient), this._sessionId);
$" RX {ReplyText} - {response.IsPositive}".Trace(typeof(SmtpClient), _sessionId);
if (response.IsPositive) return; if(response.IsPositive) {
return;
}
var responseContent = response.Content.Any() String responseContent = response.Content.Any() ? String.Join(";", response.Content.ToArray()) : String.Empty;
? string.Join(";", response.Content.ToArray())
: string.Empty;
throw new SmtpException((SmtpStatusCode)response.ReplyCode, responseContent); throw new SmtpException((SmtpStatusCode)response.ReplyCode, responseContent);
} } catch(Exception ex) {
catch (Exception ex) if(!(ex is SmtpException)) {
{ throw new SmtpException($"Could not parse server response: {this.ReplyText}");
if (!(ex is SmtpException)) }
throw new SmtpException($"Could not parse server response: {ReplyText}");
} }
} }
} }

View File

@ -1,16 +1,14 @@
namespace Swan.Net.Smtp using System;
{
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Globalization; using System.Globalization;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
namespace Swan.Net.Smtp {
/// <summary> /// <summary>
/// Represents an SMTP server response object. /// Represents an SMTP server response object.
/// </summary> /// </summary>
public class SmtpServerReply public class SmtpServerReply {
{
#region Constructors #region Constructors
/// <summary> /// <summary>
@ -19,36 +17,50 @@
/// <param name="responseCode">The response code.</param> /// <param name="responseCode">The response code.</param>
/// <param name="statusCode">The status code.</param> /// <param name="statusCode">The status code.</param>
/// <param name="content">The content.</param> /// <param name="content">The content.</param>
public SmtpServerReply(int responseCode, string statusCode, params string[] content) public SmtpServerReply(Int32 responseCode, String statusCode, params String[] content) {
{ this.Content = new List<String>();
Content = new List<string>(); this.ReplyCode = responseCode;
ReplyCode = responseCode; this.EnhancedStatusCode = statusCode;
EnhancedStatusCode = statusCode; this.Content.AddRange(content);
Content.AddRange(content); this.IsValid = responseCode >= 200 && responseCode < 600;
IsValid = responseCode >= 200 && responseCode < 600; this.ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown;
ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown; this.ReplyCodeCategory = SmtpReplyCodeCategories.Unknown;
ReplyCodeCategory = SmtpReplyCodeCategories.Unknown;
if (!IsValid) return; if(!this.IsValid) {
if (responseCode >= 200) ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveCompletion; return;
if (responseCode >= 300) ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveIntermediate; }
if (responseCode >= 400) ReplyCodeSeverity = SmtpReplyCodeSeverities.TransientNegative;
if (responseCode >= 500) ReplyCodeSeverity = SmtpReplyCodeSeverities.PermanentNegative;
if (responseCode >= 600) ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown;
if (int.TryParse(responseCode.ToString(CultureInfo.InvariantCulture).Substring(1, 1), out var middleDigit)) if(responseCode >= 200) {
{ this.ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveCompletion;
if (middleDigit >= 0 && middleDigit <= 5) }
ReplyCodeCategory = (SmtpReplyCodeCategories) middleDigit;
if(responseCode >= 300) {
this.ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveIntermediate;
}
if(responseCode >= 400) {
this.ReplyCodeSeverity = SmtpReplyCodeSeverities.TransientNegative;
}
if(responseCode >= 500) {
this.ReplyCodeSeverity = SmtpReplyCodeSeverities.PermanentNegative;
}
if(responseCode >= 600) {
this.ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown;
}
if(Int32.TryParse(responseCode.ToString(CultureInfo.InvariantCulture).Substring(1, 1), out Int32 middleDigit)) {
if(middleDigit >= 0 && middleDigit <= 5) {
this.ReplyCodeCategory = (SmtpReplyCodeCategories)middleDigit;
}
} }
} }
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="SmtpServerReply"/> class. /// Initializes a new instance of the <see cref="SmtpServerReply"/> class.
/// </summary> /// </summary>
public SmtpServerReply() public SmtpServerReply() : this(0, String.Empty, String.Empty) {
: this(0, string.Empty, string.Empty)
{
// placeholder // placeholder
} }
@ -58,9 +70,7 @@
/// <param name="responseCode">The response code.</param> /// <param name="responseCode">The response code.</param>
/// <param name="statusCode">The status code.</param> /// <param name="statusCode">The status code.</param>
/// <param name="content">The content.</param> /// <param name="content">The content.</param>
public SmtpServerReply(int responseCode, string statusCode, string content) public SmtpServerReply(Int32 responseCode, String statusCode, String content) : this(responseCode, statusCode, new[] { content }) {
: this(responseCode, statusCode, new[] {content})
{
} }
/// <summary> /// <summary>
@ -68,9 +78,7 @@
/// </summary> /// </summary>
/// <param name="responseCode">The response code.</param> /// <param name="responseCode">The response code.</param>
/// <param name="content">The content.</param> /// <param name="content">The content.</param>
public SmtpServerReply(int responseCode, string content) public SmtpServerReply(Int32 responseCode, String content) : this(responseCode, String.Empty, content) {
: this(responseCode, string.Empty, content)
{
} }
#endregion #endregion
@ -80,14 +88,12 @@
/// <summary> /// <summary>
/// Gets the command unrecognized reply. /// Gets the command unrecognized reply.
/// </summary> /// </summary>
public static SmtpServerReply CommandUnrecognized => public static SmtpServerReply CommandUnrecognized => new SmtpServerReply(500, "Syntax error, command unrecognized");
new SmtpServerReply(500, "Syntax error, command unrecognized");
/// <summary> /// <summary>
/// Gets the syntax error arguments reply. /// Gets the syntax error arguments reply.
/// </summary> /// </summary>
public static SmtpServerReply SyntaxErrorArguments => public static SmtpServerReply SyntaxErrorArguments => new SmtpServerReply(501, "Syntax error in parameters or arguments");
new SmtpServerReply(501, "Syntax error in parameters or arguments");
/// <summary> /// <summary>
/// Gets the command not implemented reply. /// Gets the command not implemented reply.
@ -102,14 +108,12 @@
/// <summary> /// <summary>
/// Gets the protocol violation reply. /// Gets the protocol violation reply.
/// </summary>= /// </summary>=
public static SmtpServerReply ProtocolViolation => public static SmtpServerReply ProtocolViolation => new SmtpServerReply(451, "Requested action aborted: error in processing");
new SmtpServerReply(451, "Requested action aborted: error in processing");
/// <summary> /// <summary>
/// Gets the system status bye reply. /// Gets the system status bye reply.
/// </summary> /// </summary>
public static SmtpServerReply SystemStatusBye => public static SmtpServerReply SystemStatusBye => new SmtpServerReply(221, "Service closing transmission channel");
new SmtpServerReply(221, "Service closing transmission channel");
/// <summary> /// <summary>
/// Gets the system status help reply. /// Gets the system status help reply.
@ -138,37 +142,49 @@
/// <summary> /// <summary>
/// Gets the response severity. /// Gets the response severity.
/// </summary> /// </summary>
public SmtpReplyCodeSeverities ReplyCodeSeverity { get; } public SmtpReplyCodeSeverities ReplyCodeSeverity {
get;
}
/// <summary> /// <summary>
/// Gets the response category. /// Gets the response category.
/// </summary> /// </summary>
public SmtpReplyCodeCategories ReplyCodeCategory { get; } public SmtpReplyCodeCategories ReplyCodeCategory {
get;
}
/// <summary> /// <summary>
/// Gets the numeric response code. /// Gets the numeric response code.
/// </summary> /// </summary>
public int ReplyCode { get; } public Int32 ReplyCode {
get;
}
/// <summary> /// <summary>
/// Gets the enhanced status code. /// Gets the enhanced status code.
/// </summary> /// </summary>
public string EnhancedStatusCode { get; } public String EnhancedStatusCode {
get;
}
/// <summary> /// <summary>
/// Gets the content. /// Gets the content.
/// </summary> /// </summary>
public List<string> Content { get; } public List<String> Content {
get;
}
/// <summary> /// <summary>
/// Returns true if the response code is between 200 and 599. /// Returns true if the response code is between 200 and 599.
/// </summary> /// </summary>
public bool IsValid { get; } public Boolean IsValid {
get;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this instance is positive. /// Gets a value indicating whether this instance is positive.
/// </summary> /// </summary>
public bool IsPositive => ReplyCode >= 200 && ReplyCode <= 399; public Boolean IsPositive => this.ReplyCode >= 200 && this.ReplyCode <= 399;
#endregion #endregion
@ -179,30 +195,31 @@
/// </summary> /// </summary>
/// <param name="text">The text.</param> /// <param name="text">The text.</param>
/// <returns>A new instance of SMTP server response object.</returns> /// <returns>A new instance of SMTP server response object.</returns>
public static SmtpServerReply Parse(string text) public static SmtpServerReply Parse(String text) {
{ String[] lines = text.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries);
var lines = text.Split(new[] {"\r\n"}, StringSplitOptions.RemoveEmptyEntries); if(lines.Length == 0) {
if (lines.Length == 0) return new SmtpServerReply(); return new SmtpServerReply();
var lastLineParts = lines.Last().Split(new[] {" "}, StringSplitOptions.RemoveEmptyEntries);
var enhancedStatusCode = string.Empty;
int.TryParse(lastLineParts[0], out var responseCode);
if (lastLineParts.Length > 1)
{
if (lastLineParts[1].Split('.').Length == 3)
enhancedStatusCode = lastLineParts[1];
} }
var content = new List<string>(); String[] lastLineParts = lines.Last().Split(new[] { " " }, StringSplitOptions.RemoveEmptyEntries);
String enhancedStatusCode = String.Empty;
_ = Int32.TryParse(lastLineParts[0], out Int32 responseCode);
if(lastLineParts.Length > 1) {
if(lastLineParts[1].Split('.').Length == 3) {
enhancedStatusCode = lastLineParts[1];
}
}
for (var i = 0; i < lines.Length; i++) List<String> content = new List<String>();
{
var splitChar = i == lines.Length - 1 ? " " : "-";
var lineParts = lines[i].Split(new[] {splitChar}, 2, StringSplitOptions.None); for(Int32 i = 0; i < lines.Length; i++) {
var lineContent = lineParts.Last(); String splitChar = i == lines.Length - 1 ? " " : "-";
if (string.IsNullOrWhiteSpace(enhancedStatusCode) == false)
lineContent = lineContent.Replace(enhancedStatusCode, string.Empty).Trim(); String[] lineParts = lines[i].Split(new[] { splitChar }, 2, StringSplitOptions.None);
String lineContent = lineParts.Last();
if(String.IsNullOrWhiteSpace(enhancedStatusCode) == false) {
lineContent = lineContent.Replace(enhancedStatusCode, String.Empty).Trim();
}
content.Add(lineContent); content.Add(lineContent);
} }
@ -216,23 +233,19 @@
/// <returns> /// <returns>
/// A <see cref="System.String" /> that represents this instance. /// A <see cref="System.String" /> that represents this instance.
/// </returns> /// </returns>
public override string ToString() public override String ToString() {
{ String responseCodeText = this.ReplyCode.ToString(CultureInfo.InvariantCulture);
var responseCodeText = ReplyCode.ToString(CultureInfo.InvariantCulture); String statusCodeText = String.IsNullOrWhiteSpace(this.EnhancedStatusCode) ? String.Empty : $" {this.EnhancedStatusCode.Trim()}";
var statusCodeText = string.IsNullOrWhiteSpace(EnhancedStatusCode) if(this.Content.Count == 0) {
? string.Empty return $"{responseCodeText}{statusCodeText}";
: $" {EnhancedStatusCode.Trim()}"; }
if (Content.Count == 0) return $"{responseCodeText}{statusCodeText}";
var builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
for (var i = 0; i < Content.Count; i++) for(Int32 i = 0; i < this.Content.Count; i++) {
{ Boolean isLastLine = i == this.Content.Count - 1;
var isLastLine = i == Content.Count - 1;
builder.Append(isLastLine _ = builder.Append(isLastLine ? $"{responseCodeText}{statusCodeText} {this.Content[i]}" : $"{responseCodeText}-{this.Content[i]}\r\n");
? $"{responseCodeText}{statusCodeText} {Content[i]}"
: $"{responseCodeText}-{Content[i]}\r\n");
} }
return builder.ToString(); return builder.ToString();

View File

@ -1,20 +1,18 @@
namespace Swan.Net.Smtp using System.Collections.Generic;
{ using System;
using System.Collections.Generic;
namespace Swan.Net.Smtp {
/// <summary> /// <summary>
/// Represents the state of an SMTP session associated with a client. /// Represents the state of an SMTP session associated with a client.
/// </summary> /// </summary>
public class SmtpSessionState public class SmtpSessionState {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="SmtpSessionState"/> class. /// Initializes a new instance of the <see cref="SmtpSessionState"/> class.
/// </summary> /// </summary>
public SmtpSessionState() public SmtpSessionState() {
{ this.DataBuffer = new List<Byte>();
DataBuffer = new List<byte>(); this.Reset(true);
Reset(true); this.ResetAuthentication();
ResetAuthentication();
} }
#region Properties #region Properties
@ -22,42 +20,56 @@
/// <summary> /// <summary>
/// Gets the contents of the data buffer. /// Gets the contents of the data buffer.
/// </summary> /// </summary>
public List<byte> DataBuffer { get; protected set; } public List<Byte> DataBuffer {
get; protected set;
}
/// <summary> /// <summary>
/// Gets or sets a value indicating whether this instance has initiated. /// Gets or sets a value indicating whether this instance has initiated.
/// </summary> /// </summary>
public bool HasInitiated { get; set; } public Boolean HasInitiated {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the current session supports extensions. /// Gets or sets a value indicating whether the current session supports extensions.
/// </summary> /// </summary>
public bool SupportsExtensions { get; set; } public Boolean SupportsExtensions {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the client hostname. /// Gets or sets the client hostname.
/// </summary> /// </summary>
public string ClientHostname { get; set; } public String ClientHostname {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the session is currently receiving DATA. /// Gets or sets a value indicating whether the session is currently receiving DATA.
/// </summary> /// </summary>
public bool IsInDataMode { get; set; } public Boolean IsInDataMode {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the sender address. /// Gets or sets the sender address.
/// </summary> /// </summary>
public string SenderAddress { get; set; } public String SenderAddress {
get; set;
}
/// <summary> /// <summary>
/// Gets the recipients. /// Gets the recipients.
/// </summary> /// </summary>
public List<string> Recipients { get; } = new List<string>(); public List<String> Recipients { get; } = new List<String>();
/// <summary> /// <summary>
/// Gets or sets the extended data supporting any additional field for storage by a responder implementation. /// Gets or sets the extended data supporting any additional field for storage by a responder implementation.
/// </summary> /// </summary>
public object ExtendedData { get; set; } public Object ExtendedData {
get; set;
}
#endregion #endregion
@ -66,48 +78,59 @@
/// <summary> /// <summary>
/// Gets or sets a value indicating whether this instance is in authentication mode. /// Gets or sets a value indicating whether this instance is in authentication mode.
/// </summary> /// </summary>
public bool IsInAuthMode { get; set; } public Boolean IsInAuthMode {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the username. /// Gets or sets the username.
/// </summary> /// </summary>
public string Username { get; set; } public String Username {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the password. /// Gets or sets the password.
/// </summary> /// </summary>
public string Password { get; set; } public String Password {
get; set;
}
/// <summary> /// <summary>
/// Gets a value indicating whether this instance has provided username. /// Gets a value indicating whether this instance has provided username.
/// </summary> /// </summary>
public bool HasProvidedUsername => string.IsNullOrWhiteSpace(Username) == false; public Boolean HasProvidedUsername => String.IsNullOrWhiteSpace(this.Username) == false;
/// <summary> /// <summary>
/// Gets or sets a value indicating whether this instance is authenticated. /// Gets or sets a value indicating whether this instance is authenticated.
/// </summary> /// </summary>
public bool IsAuthenticated { get; set; } public Boolean IsAuthenticated {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the authentication mode. /// Gets or sets the authentication mode.
/// </summary> /// </summary>
public string AuthMode { get; set; } public String AuthMode {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets a value indicating whether this instance is channel secure. /// Gets or sets a value indicating whether this instance is channel secure.
/// </summary> /// </summary>
public bool IsChannelSecure { get; set; } public Boolean IsChannelSecure {
get; set;
}
/// <summary> /// <summary>
/// Resets the authentication state. /// Resets the authentication state.
/// </summary> /// </summary>
public void ResetAuthentication() public void ResetAuthentication() {
{ this.Username = String.Empty;
Username = string.Empty; this.Password = String.Empty;
Password = string.Empty; this.AuthMode = String.Empty;
AuthMode = string.Empty; this.IsInAuthMode = false;
IsInAuthMode = false; this.IsAuthenticated = false;
IsAuthenticated = false;
} }
#endregion #endregion
@ -117,38 +140,36 @@
/// <summary> /// <summary>
/// Resets the data mode to false, clears the recipients, the sender address and the data buffer. /// Resets the data mode to false, clears the recipients, the sender address and the data buffer.
/// </summary> /// </summary>
public void ResetEmail() public void ResetEmail() {
{ this.IsInDataMode = false;
IsInDataMode = false; this.Recipients.Clear();
Recipients.Clear(); this.SenderAddress = String.Empty;
SenderAddress = string.Empty; this.DataBuffer.Clear();
DataBuffer.Clear();
} }
/// <summary> /// <summary>
/// Resets the state table entirely. /// Resets the state table entirely.
/// </summary> /// </summary>
/// <param name="clearExtensionData">if set to <c>true</c> [clear extension data].</param> /// <param name="clearExtensionData">if set to <c>true</c> [clear extension data].</param>
public void Reset(bool clearExtensionData) public void Reset(Boolean clearExtensionData) {
{ this.HasInitiated = false;
HasInitiated = false; this.SupportsExtensions = false;
SupportsExtensions = false; this.ClientHostname = String.Empty;
ClientHostname = string.Empty; this.ResetEmail();
ResetEmail();
if (clearExtensionData) if(clearExtensionData) {
ExtendedData = null; this.ExtendedData = null;
}
} }
/// <summary> /// <summary>
/// Creates a new object that is a copy of the current instance. /// Creates a new object that is a copy of the current instance.
/// </summary> /// </summary>
/// <returns>A clone.</returns> /// <returns>A clone.</returns>
public virtual SmtpSessionState Clone() public virtual SmtpSessionState Clone() {
{ SmtpSessionState clonedState = this.CopyPropertiesToNew<SmtpSessionState>(new[] { nameof(this.DataBuffer) });
var clonedState = this.CopyPropertiesToNew<SmtpSessionState>(new[] {nameof(DataBuffer)}); clonedState.DataBuffer.AddRange(this.DataBuffer);
clonedState.DataBuffer.AddRange(DataBuffer); clonedState.Recipients.AddRange(this.Recipients);
clonedState.Recipients.AddRange(Recipients);
return clonedState; return clonedState;
} }

View File

@ -1,22 +1,21 @@
namespace Swan using System;
{
namespace Swan {
/// <summary> /// <summary>
/// Represents the text of the standard output and standard error /// Represents the text of the standard output and standard error
/// of a process, including its exit code. /// of a process, including its exit code.
/// </summary> /// </summary>
public class ProcessResult public class ProcessResult {
{
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ProcessResult" /> class. /// Initializes a new instance of the <see cref="ProcessResult" /> class.
/// </summary> /// </summary>
/// <param name="exitCode">The exit code.</param> /// <param name="exitCode">The exit code.</param>
/// <param name="standardOutput">The standard output.</param> /// <param name="standardOutput">The standard output.</param>
/// <param name="standardError">The standard error.</param> /// <param name="standardError">The standard error.</param>
public ProcessResult(int exitCode, string standardOutput, string standardError) public ProcessResult(Int32 exitCode, String standardOutput, String standardError) {
{ this.ExitCode = exitCode;
ExitCode = exitCode; this.StandardOutput = standardOutput;
StandardOutput = standardOutput; this.StandardError = standardError;
StandardError = standardError;
} }
/// <summary> /// <summary>
@ -25,7 +24,9 @@
/// <value> /// <value>
/// The exit code. /// The exit code.
/// </value> /// </value>
public int ExitCode { get; } public Int32 ExitCode {
get;
}
/// <summary> /// <summary>
/// Gets the text of the standard output. /// Gets the text of the standard output.
@ -33,7 +34,9 @@
/// <value> /// <value>
/// The standard output. /// The standard output.
/// </value> /// </value>
public string StandardOutput { get; } public String StandardOutput {
get;
}
/// <summary> /// <summary>
/// Gets the text of the standard error. /// Gets the text of the standard error.
@ -41,6 +44,8 @@
/// <value> /// <value>
/// The standard error. /// The standard error.
/// </value> /// </value>
public string StandardError { get; } public String StandardError {
get;
}
} }
} }

View File

@ -1,5 +1,4 @@
namespace Swan #nullable enable
{
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
@ -8,19 +7,19 @@
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan {
/// <summary> /// <summary>
/// Provides methods to help create external processes, and efficiently capture the /// Provides methods to help create external processes, and efficiently capture the
/// standard error and standard output streams. /// standard error and standard output streams.
/// </summary> /// </summary>
public static class ProcessRunner public static class ProcessRunner {
{
/// <summary> /// <summary>
/// Defines a delegate to handle binary data reception from the standard /// Defines a delegate to handle binary data reception from the standard
/// output or standard error streams from a process. /// output or standard error streams from a process.
/// </summary> /// </summary>
/// <param name="processData">The process data.</param> /// <param name="processData">The process data.</param>
/// <param name="process">The process.</param> /// <param name="process">The process.</param>
public delegate void ProcessDataReceivedCallback(byte[] processData, Process process); public delegate void ProcessDataReceivedCallback(Byte[] processData, Process process);
/// <summary> /// <summary>
/// Runs the process asynchronously and if the exit code is 0, /// Runs the process asynchronously and if the exit code is 0,
@ -35,7 +34,7 @@
/// <returns>The type of the result produced by this Task.</returns> /// <returns>The type of the result produced by this Task.</returns>
/// <example> /// <example>
/// The following code explains how to run an external process using the /// The following code explains how to run an external process using the
/// <see cref="GetProcessOutputAsync(string, string, CancellationToken)"/> method. /// <see cref="GetProcessOutputAsync(String, String, CancellationToken)"/> method.
/// <code> /// <code>
/// class Example /// class Example
/// { /// {
@ -54,13 +53,8 @@
/// } /// }
/// </code> /// </code>
/// </example> /// </example>
public static async Task<string> GetProcessOutputAsync( public static async Task<String> GetProcessOutputAsync(String filename, String arguments = "", String? workingDirectory = null, CancellationToken cancellationToken = default) {
string filename, ProcessResult result = await GetProcessResultAsync(filename, arguments, workingDirectory, cancellationToken: cancellationToken).ConfigureAwait(false);
string arguments = "",
string? workingDirectory = null,
CancellationToken cancellationToken = default)
{
var result = await GetProcessResultAsync(filename, arguments, workingDirectory, cancellationToken: cancellationToken).ConfigureAwait(false);
return result.ExitCode == 0 ? result.StandardOutput : result.StandardError; return result.ExitCode == 0 ? result.StandardOutput : result.StandardError;
} }
@ -78,13 +72,8 @@
/// <returns> /// <returns>
/// The type of the result produced by this Task. /// The type of the result produced by this Task.
/// </returns> /// </returns>
public static async Task<string> GetProcessEncodedOutputAsync( public static async Task<String> GetProcessEncodedOutputAsync(String filename, String arguments = "", Encoding? encoding = null, CancellationToken cancellationToken = default) {
string filename, ProcessResult result = await GetProcessResultAsync(filename, arguments, null, encoding, cancellationToken).ConfigureAwait(false);
string arguments = "",
Encoding? encoding = null,
CancellationToken cancellationToken = default)
{
var result = await GetProcessResultAsync(filename, arguments, null, encoding, cancellationToken).ConfigureAwait(false);
return result.ExitCode == 0 ? result.StandardOutput : result.StandardError; return result.ExitCode == 0 ? result.StandardOutput : result.StandardError;
} }
@ -100,11 +89,7 @@
/// Text of the standard output and standard error streams along with the exit code as a <see cref="ProcessResult" /> instance. /// Text of the standard output and standard error streams along with the exit code as a <see cref="ProcessResult" /> instance.
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">filename.</exception> /// <exception cref="ArgumentNullException">filename.</exception>
public static Task<ProcessResult> GetProcessResultAsync( public static Task<ProcessResult> GetProcessResultAsync(String filename, String arguments = "", CancellationToken cancellationToken = default) => GetProcessResultAsync(filename, arguments, null, Definitions.CurrentAnsiEncoding, cancellationToken);
string filename,
string arguments = "",
CancellationToken cancellationToken = default) =>
GetProcessResultAsync(filename, arguments, null, Definitions.CurrentAnsiEncoding, cancellationToken);
/// <summary> /// <summary>
/// Executes a process asynchronously and returns the text of the standard output and standard error streams /// Executes a process asynchronously and returns the text of the standard output and standard error streams
@ -121,7 +106,7 @@
/// </returns> /// </returns>
/// <exception cref="ArgumentNullException">filename.</exception> /// <exception cref="ArgumentNullException">filename.</exception>
/// <example> /// <example>
/// The following code describes how to run an external process using the <see cref="GetProcessResultAsync(string, string, string, Encoding, CancellationToken)" /> method. /// The following code describes how to run an external process using the <see cref="GetProcessResultAsync(String, String, String, Encoding, CancellationToken)" /> method.
/// <code> /// <code>
/// class Example /// class Example
/// { /// {
@ -143,32 +128,19 @@
/// } /// }
/// } /// }
/// </code></example> /// </code></example>
public static async Task<ProcessResult> GetProcessResultAsync( public static async Task<ProcessResult> GetProcessResultAsync(String filename, String arguments, String? workingDirectory, Encoding? encoding = null, CancellationToken cancellationToken = default) {
string filename, if(filename == null) {
string arguments,
string? workingDirectory,
Encoding? encoding = null,
CancellationToken cancellationToken = default)
{
if (filename == null)
throw new ArgumentNullException(nameof(filename)); throw new ArgumentNullException(nameof(filename));
}
if (encoding == null) if(encoding == null) {
encoding = Definitions.CurrentAnsiEncoding; encoding = Definitions.CurrentAnsiEncoding;
}
var standardOutputBuilder = new StringBuilder(); StringBuilder standardOutputBuilder = new StringBuilder();
var standardErrorBuilder = new StringBuilder(); StringBuilder standardErrorBuilder = new StringBuilder();
var processReturn = await RunProcessAsync( Int32 processReturn = await RunProcessAsync(filename, arguments, workingDirectory, (data, proc) => standardOutputBuilder.Append(encoding.GetString(data)), (data, proc) => standardErrorBuilder.Append(encoding.GetString(data)), encoding, true, cancellationToken).ConfigureAwait(false);
filename,
arguments,
workingDirectory,
(data, proc) => standardOutputBuilder.Append(encoding.GetString(data)),
(data, proc) => standardErrorBuilder.Append(encoding.GetString(data)),
encoding,
true,
cancellationToken)
.ConfigureAwait(false);
return new ProcessResult(processReturn, standardOutputBuilder.ToString(), standardErrorBuilder.ToString()); return new ProcessResult(processReturn, standardOutputBuilder.ToString(), standardErrorBuilder.ToString());
} }
@ -191,27 +163,16 @@
/// <returns> /// <returns>
/// Value type will be -1 for forceful termination of the process. /// Value type will be -1 for forceful termination of the process.
/// </returns> /// </returns>
public static Task<int> RunProcessAsync( public static Task<Int32> RunProcessAsync(String filename, String arguments, String? workingDirectory, ProcessDataReceivedCallback onOutputData, ProcessDataReceivedCallback? onErrorData, Encoding encoding, Boolean syncEvents = true, CancellationToken cancellationToken = default) {
string filename, if(filename == null) {
string arguments,
string? workingDirectory,
ProcessDataReceivedCallback onOutputData,
ProcessDataReceivedCallback onErrorData,
Encoding encoding,
bool syncEvents = true,
CancellationToken cancellationToken = default)
{
if (filename == null)
throw new ArgumentNullException(nameof(filename)); throw new ArgumentNullException(nameof(filename));
}
return Task.Run(() => return Task.Run(() => {
{
// Setup the process and its corresponding start info // Setup the process and its corresponding start info
var process = new Process Process process = new Process {
{
EnableRaisingEvents = false, EnableRaisingEvents = false,
StartInfo = new ProcessStartInfo StartInfo = new ProcessStartInfo {
{
Arguments = arguments, Arguments = arguments,
CreateNoWindow = true, CreateNoWindow = true,
FileName = filename, FileName = filename,
@ -220,73 +181,51 @@
RedirectStandardOutput = true, RedirectStandardOutput = true,
StandardOutputEncoding = encoding, StandardOutputEncoding = encoding,
UseShellExecute = false, UseShellExecute = false,
#if NET461
WindowStyle = ProcessWindowStyle.Hidden,
#endif
}, },
}; };
if (!string.IsNullOrWhiteSpace(workingDirectory)) if(!String.IsNullOrWhiteSpace(workingDirectory)) {
process.StartInfo.WorkingDirectory = workingDirectory; process.StartInfo.WorkingDirectory = workingDirectory;
}
// Launch the process and discard any buffered data for standard error and standard output // Launch the process and discard any buffered data for standard error and standard output
process.Start(); _ = process.Start();
process.StandardError.DiscardBufferedData(); process.StandardError.DiscardBufferedData();
process.StandardOutput.DiscardBufferedData(); process.StandardOutput.DiscardBufferedData();
// Launch the asynchronous stream reading tasks // Launch the asynchronous stream reading tasks
var readTasks = new Task[2]; Task[] readTasks = new Task[2];
readTasks[0] = CopyStreamAsync( readTasks[0] = CopyStreamAsync(process, process.StandardOutput.BaseStream, onOutputData, syncEvents, cancellationToken);
process, readTasks[1] = CopyStreamAsync(process, process.StandardError.BaseStream, onErrorData, syncEvents, cancellationToken);
process.StandardOutput.BaseStream,
onOutputData,
syncEvents,
cancellationToken);
readTasks[1] = CopyStreamAsync(
process,
process.StandardError.BaseStream,
onErrorData,
syncEvents,
cancellationToken);
try try {
{
// Wait for all tasks to complete // Wait for all tasks to complete
Task.WaitAll(readTasks, cancellationToken); Task.WaitAll(readTasks, cancellationToken);
} } catch(TaskCanceledException) {
catch (TaskCanceledException)
{
// ignore // ignore
} } finally {
finally
{
// Wait for the process to exit // Wait for the process to exit
while (cancellationToken.IsCancellationRequested == false) while(cancellationToken.IsCancellationRequested == false) {
{ if(process.HasExited || process.WaitForExit(5)) {
if (process.HasExited || process.WaitForExit(5))
break; break;
} }
}
// Forcefully kill the process if it do not exit // Forcefully kill the process if it do not exit
try try {
{ if(process.HasExited == false) {
if (process.HasExited == false)
process.Kill(); process.Kill();
} }
catch } catch {
{
// swallow // swallow
} }
} }
try try {
{
// Retrieve and return the exit code. // Retrieve and return the exit code.
// -1 signals error // -1 signals error
return process.HasExited ? process.ExitCode : -1; return process.HasExited ? process.ExitCode : -1;
} } catch {
catch
{
return -1; return -1;
} }
}, cancellationToken); }, cancellationToken);
@ -308,7 +247,7 @@
/// <returns>Value type will be -1 for forceful termination of the process.</returns> /// <returns>Value type will be -1 for forceful termination of the process.</returns>
/// <example> /// <example>
/// The following example illustrates how to run an external process using the /// The following example illustrates how to run an external process using the
/// <see cref="RunProcessAsync(string, string, ProcessDataReceivedCallback, ProcessDataReceivedCallback, bool, CancellationToken)"/> /// <see cref="RunProcessAsync(String, String, ProcessDataReceivedCallback, ProcessDataReceivedCallback, Boolean, CancellationToken)"/>
/// method. /// method.
/// <code> /// <code>
/// class Example /// class Example
@ -334,22 +273,7 @@
/// } /// }
/// </code> /// </code>
/// </example> /// </example>
public static Task<int> RunProcessAsync( public static Task<Int32> RunProcessAsync(String filename, String arguments, ProcessDataReceivedCallback onOutputData, ProcessDataReceivedCallback? onErrorData, Boolean syncEvents = true, CancellationToken cancellationToken = default) => RunProcessAsync(filename, arguments, null, onOutputData, onErrorData, Definitions.CurrentAnsiEncoding, syncEvents, cancellationToken);
string filename,
string arguments,
ProcessDataReceivedCallback onOutputData,
ProcessDataReceivedCallback onErrorData,
bool syncEvents = true,
CancellationToken cancellationToken = default)
=> RunProcessAsync(
filename,
arguments,
null,
onOutputData,
onErrorData,
Definitions.CurrentAnsiEncoding,
syncEvents,
cancellationToken);
/// <summary> /// <summary>
/// Copies the stream asynchronously. /// Copies the stream asynchronously.
@ -360,79 +284,65 @@
/// <param name="syncEvents">if set to <c>true</c> [synchronize events].</param> /// <param name="syncEvents">if set to <c>true</c> [synchronize events].</param>
/// <param name="ct">The cancellation token.</param> /// <param name="ct">The cancellation token.</param>
/// <returns>Total copies stream.</returns> /// <returns>Total copies stream.</returns>
private static Task<ulong> CopyStreamAsync( private static Task<UInt64> CopyStreamAsync(Process process, Stream baseStream, ProcessDataReceivedCallback? onDataCallback, Boolean syncEvents, CancellationToken ct) => Task.Run(async () => {
Process process,
Stream baseStream,
ProcessDataReceivedCallback onDataCallback,
bool syncEvents,
CancellationToken ct) =>
Task.Run(async () =>
{
// define some state variables // define some state variables
var swapBuffer = new byte[2048]; // the buffer to copy data from one stream to the next Byte[] swapBuffer = new Byte[2048]; // the buffer to copy data from one stream to the next
ulong totalCount = 0; // the total amount of bytes read UInt64 totalCount = 0; // the total amount of bytes read
var hasExited = false; Boolean hasExited = false;
while (ct.IsCancellationRequested == false) while(ct.IsCancellationRequested == false) {
{ try {
try
{
// Check if process is no longer valid // Check if process is no longer valid
// if this condition holds, simply read the last bits of data available. // if this condition holds, simply read the last bits of data available.
int readCount; // the bytes read in any given event Int32 readCount; // the bytes read in any given event
if (process.HasExited || process.WaitForExit(1)) if(process.HasExited || process.WaitForExit(1)) {
{ while(true) {
while (true) try {
{
try
{
readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct); readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct);
if (readCount > 0) if(readCount > 0) {
{ totalCount += (UInt64)readCount;
totalCount += (ulong) readCount;
onDataCallback?.Invoke(swapBuffer.Skip(0).Take(readCount).ToArray(), process); onDataCallback?.Invoke(swapBuffer.Skip(0).Take(readCount).ToArray(), process);
} } else {
else
{
hasExited = true; hasExited = true;
break; break;
} }
} } catch {
catch
{
hasExited = true; hasExited = true;
break; break;
} }
} }
} }
if (hasExited) break; if(hasExited) {
break;
}
// Try reading from the stream. < 0 means no read occurred. // Try reading from the stream. < 0 means no read occurred.
readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct).ConfigureAwait(false); readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct).ConfigureAwait(false);
// When no read is done, we need to let is rest for a bit // When no read is done, we need to let is rest for a bit
if (readCount <= 0) if(readCount <= 0) {
{
await Task.Delay(1, ct).ConfigureAwait(false); // do not hog CPU cycles doing nothing. await Task.Delay(1, ct).ConfigureAwait(false); // do not hog CPU cycles doing nothing.
continue; continue;
} }
totalCount += (ulong) readCount; totalCount += (UInt64)readCount;
if (onDataCallback == null) continue; if(onDataCallback == null) {
continue;
}
// Create the buffer to pass to the callback // Create the buffer to pass to the callback
var eventBuffer = swapBuffer.Skip(0).Take(readCount).ToArray(); Byte[] eventBuffer = swapBuffer.Skip(0).Take(readCount).ToArray();
// Create the data processing callback invocation // Create the data processing callback invocation
var eventTask = Task.Run(() => onDataCallback.Invoke(eventBuffer, process), ct); Task eventTask = Task.Run(() => onDataCallback.Invoke(eventBuffer, process), ct);
// wait for the event to process before the next read occurs // wait for the event to process before the next read occurs
if (syncEvents) eventTask.Wait(ct); if(syncEvents) {
eventTask.Wait(ct);
} }
catch } catch {
{
break; break;
} }
} }

View File

@ -1,22 +1,20 @@
using System; using System;
#if !NET461
namespace Swan.Services namespace Swan.Services {
{
/// <summary> /// <summary>
/// Mimic a Windows ServiceBase class. Useful to keep compatibility with applications /// Mimic a Windows ServiceBase class. Useful to keep compatibility with applications
/// running as services in OS different to Windows. /// running as services in OS different to Windows.
/// </summary> /// </summary>
[Obsolete("This abstract class will be removed in version 3.0")] [Obsolete("This abstract class will be removed in version 3.0")]
public abstract class ServiceBase public abstract class ServiceBase {
{
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the service can be stopped once it has started. /// Gets or sets a value indicating whether the service can be stopped once it has started.
/// </summary> /// </summary>
/// <value> /// <value>
/// <c>true</c> if this instance can stop; otherwise, <c>false</c>. /// <c>true</c> if this instance can stop; otherwise, <c>false</c>.
/// </value> /// </value>
public bool CanStop { get; set; } = true; public Boolean CanStop { get; set; } = true;
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the service should be notified when the system is shutting down. /// Gets or sets a value indicating whether the service should be notified when the system is shutting down.
@ -24,7 +22,9 @@ namespace Swan.Services
/// <value> /// <value>
/// <c>true</c> if this instance can shutdown; otherwise, <c>false</c>. /// <c>true</c> if this instance can shutdown; otherwise, <c>false</c>.
/// </value> /// </value>
public bool CanShutdown { get; set; } public Boolean CanShutdown {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the service can be paused and resumed. /// Gets or sets a value indicating whether the service can be paused and resumed.
@ -32,7 +32,9 @@ namespace Swan.Services
/// <value> /// <value>
/// <c>true</c> if this instance can pause and continue; otherwise, <c>false</c>. /// <c>true</c> if this instance can pause and continue; otherwise, <c>false</c>.
/// </value> /// </value>
public bool CanPauseAndContinue { get; set; } public Boolean CanPauseAndContinue {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the exit code. /// Gets or sets the exit code.
@ -40,7 +42,9 @@ namespace Swan.Services
/// <value> /// <value>
/// The exit code. /// The exit code.
/// </value> /// </value>
public int ExitCode { get; set; } public Int32 ExitCode {
get; set;
}
/// <summary> /// <summary>
/// Indicates whether to report Start, Stop, Pause, and Continue commands in the event log. /// Indicates whether to report Start, Stop, Pause, and Continue commands in the event log.
@ -48,7 +52,9 @@ namespace Swan.Services
/// <value> /// <value>
/// <c>true</c> if [automatic log]; otherwise, <c>false</c>. /// <c>true</c> if [automatic log]; otherwise, <c>false</c>.
/// </value> /// </value>
public bool AutoLog { get; set; } public Boolean AutoLog {
get; set;
}
/// <summary> /// <summary>
/// Gets or sets the name of the service. /// Gets or sets the name of the service.
@ -56,17 +62,20 @@ namespace Swan.Services
/// <value> /// <value>
/// The name of the service. /// The name of the service.
/// </value> /// </value>
public string ServiceName { get; set; } public String ServiceName {
get; set;
}
/// <summary> /// <summary>
/// Stops the executing service. /// Stops the executing service.
/// </summary> /// </summary>
public void Stop() public void Stop() {
{ if(!this.CanStop) {
if (!CanStop) return; return;
}
CanStop = false; this.CanStop = false;
OnStop(); this.OnStop();
} }
/// <summary> /// <summary>
@ -74,8 +83,7 @@ namespace Swan.Services
/// or when the operating system starts (for a service that starts automatically). Specifies actions to take when the service starts. /// or when the operating system starts (for a service that starts automatically). Specifies actions to take when the service starts.
/// </summary> /// </summary>
/// <param name="args">The arguments.</param> /// <param name="args">The arguments.</param>
protected virtual void OnStart(string[] args) protected virtual void OnStart(String[] args) {
{
// do nothing // do nothing
} }
@ -83,10 +91,8 @@ namespace Swan.Services
/// When implemented in a derived class, executes when a Stop command is sent to the service by the Service Control Manager (SCM). /// When implemented in a derived class, executes when a Stop command is sent to the service by the Service Control Manager (SCM).
/// Specifies actions to take when a service stops running. /// Specifies actions to take when a service stops running.
/// </summary> /// </summary>
protected virtual void OnStop() protected virtual void OnStop() {
{
// do nothing // do nothing
} }
} }
} }
#endif

View File

@ -1,10 +1,9 @@
namespace Swan.Threading using System;
{
using System;
using System.Diagnostics; using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Threading {
/// <summary> /// <summary>
/// Represents logic providing several delay mechanisms. /// Represents logic providing several delay mechanisms.
/// </summary> /// </summary>
@ -27,28 +26,23 @@
/// } /// }
/// </code> /// </code>
/// </example> /// </example>
public sealed class DelayProvider : IDisposable public sealed class DelayProvider : IDisposable {
{ private readonly Object _syncRoot = new Object();
private readonly object _syncRoot = new object();
private readonly Stopwatch _delayStopwatch = new Stopwatch(); private readonly Stopwatch _delayStopwatch = new Stopwatch();
private bool _isDisposed; private Boolean _isDisposed;
private IWaitEvent _delayEvent; private IWaitEvent _delayEvent;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="DelayProvider"/> class. /// Initializes a new instance of the <see cref="DelayProvider"/> class.
/// </summary> /// </summary>
/// <param name="strategy">The strategy.</param> /// <param name="strategy">The strategy.</param>
public DelayProvider(DelayStrategy strategy = DelayStrategy.TaskDelay) public DelayProvider(DelayStrategy strategy = DelayStrategy.TaskDelay) => this.Strategy = strategy;
{
Strategy = strategy;
}
/// <summary> /// <summary>
/// Enumerates the different ways of providing delays. /// Enumerates the different ways of providing delays.
/// </summary> /// </summary>
public enum DelayStrategy public enum DelayStrategy {
{
/// <summary> /// <summary>
/// Using the Thread.Sleep(15) mechanism. /// Using the Thread.Sleep(15) mechanism.
/// </summary> /// </summary>
@ -68,22 +62,23 @@
/// <summary> /// <summary>
/// Gets the selected delay strategy. /// Gets the selected delay strategy.
/// </summary> /// </summary>
public DelayStrategy Strategy { get; } public DelayStrategy Strategy {
get;
}
/// <summary> /// <summary>
/// Creates the smallest possible, synchronous delay based on the selected strategy. /// Creates the smallest possible, synchronous delay based on the selected strategy.
/// </summary> /// </summary>
/// <returns>The elapsed time of the delay.</returns> /// <returns>The elapsed time of the delay.</returns>
public TimeSpan WaitOne() public TimeSpan WaitOne() {
{ lock(this._syncRoot) {
lock (_syncRoot) if(this._isDisposed) {
{ return TimeSpan.Zero;
if (_isDisposed) return TimeSpan.Zero; }
_delayStopwatch.Restart(); this._delayStopwatch.Restart();
switch (Strategy) switch(this.Strategy) {
{
case DelayStrategy.ThreadSleep: case DelayStrategy.ThreadSleep:
DelaySleep(); DelaySleep();
break; break;
@ -91,25 +86,26 @@
DelayTask(); DelayTask();
break; break;
case DelayStrategy.ThreadPool: case DelayStrategy.ThreadPool:
DelayThreadPool(); this.DelayThreadPool();
break; break;
} }
return _delayStopwatch.Elapsed; return this._delayStopwatch.Elapsed;
} }
} }
#region Dispose Pattern #region Dispose Pattern
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ lock(this._syncRoot) {
lock (_syncRoot) if(this._isDisposed) {
{ return;
if (_isDisposed) return; }
_isDisposed = true;
_delayEvent?.Dispose(); this._isDisposed = true;
this._delayEvent?.Dispose();
} }
} }
@ -121,19 +117,18 @@
private static void DelayTask() => Task.Delay(1).Wait(); private static void DelayTask() => Task.Delay(1).Wait();
private void DelayThreadPool() private void DelayThreadPool() {
{ if(this._delayEvent == null) {
if (_delayEvent == null) this._delayEvent = WaitEventFactory.Create(isCompleted: true, useSlim: true);
_delayEvent = WaitEventFactory.Create(isCompleted: true, useSlim: true); }
_delayEvent.Begin(); this._delayEvent.Begin();
ThreadPool.QueueUserWorkItem(s => _ = ThreadPool.QueueUserWorkItem(s => {
{
DelaySleep(); DelaySleep();
_delayEvent.Complete(); this._delayEvent.Complete();
}); });
_delayEvent.Wait(); this._delayEvent.Wait();
} }
#endregion #endregion

View File

@ -1,5 +1,4 @@
namespace Swan.Threading namespace Swan.Threading {
{
using System; using System;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -10,9 +9,8 @@
/// provides the ability to perform fine-grained control on these tasks. /// provides the ability to perform fine-grained control on these tasks.
/// </summary> /// </summary>
/// <seealso cref="IWorker" /> /// <seealso cref="IWorker" />
public abstract class ThreadWorkerBase : WorkerBase public abstract class ThreadWorkerBase : WorkerBase {
{ private readonly Object _syncLock = new Object();
private readonly object _syncLock = new object();
private readonly Thread _thread; private readonly Thread _thread;
/// <summary> /// <summary>
@ -22,12 +20,9 @@
/// <param name="priority">The thread priority.</param> /// <param name="priority">The thread priority.</param>
/// <param name="period">The interval of cycle execution.</param> /// <param name="period">The interval of cycle execution.</param>
/// <param name="delayProvider">The cycle delay provide implementation.</param> /// <param name="delayProvider">The cycle delay provide implementation.</param>
protected ThreadWorkerBase(string name, ThreadPriority priority, TimeSpan period, IWorkerDelayProvider delayProvider) protected ThreadWorkerBase(String name, ThreadPriority priority, TimeSpan period, IWorkerDelayProvider delayProvider) : base(name, period) {
: base(name, period) this.DelayProvider = delayProvider;
{ this._thread = new Thread(this.RunWorkerLoop) {
DelayProvider = delayProvider;
_thread = new Thread(RunWorkerLoop)
{
IsBackground = true, IsBackground = true,
Priority = priority, Priority = priority,
Name = name, Name = name,
@ -39,76 +34,61 @@
/// </summary> /// </summary>
/// <param name="name">The name.</param> /// <param name="name">The name.</param>
/// <param name="period">The execution interval.</param> /// <param name="period">The execution interval.</param>
protected ThreadWorkerBase(string name, TimeSpan period) protected ThreadWorkerBase(String name, TimeSpan period) : this(name, ThreadPriority.Normal, period, WorkerDelayProvider.Default) {
: this(name, ThreadPriority.Normal, period, WorkerDelayProvider.Default)
{
// placeholder // placeholder
} }
/// <summary> /// <summary>
/// Provides an implementation on a cycle delay provider. /// Provides an implementation on a cycle delay provider.
/// </summary> /// </summary>
protected IWorkerDelayProvider DelayProvider { get; } protected IWorkerDelayProvider DelayProvider {
get;
}
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> StartAsync() public override Task<WorkerState> StartAsync() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState == WorkerState.Paused || this.WorkerState == WorkerState.Waiting) {
{ return this.ResumeAsync();
if (WorkerState == WorkerState.Paused || WorkerState == WorkerState.Waiting) }
return ResumeAsync();
if (WorkerState != WorkerState.Created) if(this.WorkerState != WorkerState.Created) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
if (IsStopRequested) if(this.IsStopRequested) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
var task = QueueStateChange(StateChangeRequest.Start); Task<WorkerState> task = this.QueueStateChange(StateChangeRequest.Start);
_thread.Start(); this._thread.Start();
return task; return task;
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> PauseAsync() public override Task<WorkerState> PauseAsync() {
{ lock(this._syncLock) {
lock (_syncLock) return this.WorkerState != WorkerState.Running && this.WorkerState != WorkerState.Waiting ? Task.FromResult(this.WorkerState) : this.IsStopRequested ? Task.FromResult(this.WorkerState) : this.QueueStateChange(StateChangeRequest.Pause);
{
if (WorkerState != WorkerState.Running && WorkerState != WorkerState.Waiting)
return Task.FromResult(WorkerState);
return IsStopRequested ? Task.FromResult(WorkerState) : QueueStateChange(StateChangeRequest.Pause);
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> ResumeAsync() public override Task<WorkerState> ResumeAsync() {
{ lock(this._syncLock) {
lock (_syncLock) return this.WorkerState == WorkerState.Created ? this.StartAsync() : this.WorkerState != WorkerState.Paused && this.WorkerState != WorkerState.Waiting ? Task.FromResult(this.WorkerState) : this.IsStopRequested ? Task.FromResult(this.WorkerState) : this.QueueStateChange(StateChangeRequest.Resume);
{
if (WorkerState == WorkerState.Created)
return StartAsync();
if (WorkerState != WorkerState.Paused && WorkerState != WorkerState.Waiting)
return Task.FromResult(WorkerState);
return IsStopRequested ? Task.FromResult(WorkerState) : QueueStateChange(StateChangeRequest.Resume);
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> StopAsync() public override Task<WorkerState> StopAsync() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState == WorkerState.Stopped || this.WorkerState == WorkerState.Created) {
{ this.WorkerState = WorkerState.Stopped;
if (WorkerState == WorkerState.Stopped || WorkerState == WorkerState.Created) return Task.FromResult(this.WorkerState);
{
WorkerState = WorkerState.Stopped;
return Task.FromResult(WorkerState);
} }
return QueueStateChange(StateChangeRequest.Stop); return this.QueueStateChange(StateChangeRequest.Stop);
} }
} }
@ -120,81 +100,73 @@
/// <param name="wantedDelay">The remaining delay to wait for in the cycle.</param> /// <param name="wantedDelay">The remaining delay to wait for in the cycle.</param>
/// <param name="delayTask">Contains a reference to a task with the scheduled period delay.</param> /// <param name="delayTask">Contains a reference to a task with the scheduled period delay.</param>
/// <param name="token">The cancellation token to cancel waiting.</param> /// <param name="token">The cancellation token to cancel waiting.</param>
protected virtual void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) => protected virtual void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) =>
DelayProvider?.ExecuteCycleDelay(wantedDelay, delayTask, token); this.DelayProvider?.ExecuteCycleDelay(wantedDelay, delayTask, token);
/// <inheritdoc /> /// <inheritdoc />
protected override void OnDisposing() protected override void OnDisposing() {
{ lock(this._syncLock) {
lock (_syncLock) if((this._thread.ThreadState & ThreadState.Unstarted) != ThreadState.Unstarted) {
{ this._thread.Join();
if ((_thread.ThreadState & ThreadState.Unstarted) != ThreadState.Unstarted) }
_thread.Join();
} }
} }
/// <summary> /// <summary>
/// Implements worker control, execution and delay logic in a loop. /// Implements worker control, execution and delay logic in a loop.
/// </summary> /// </summary>
private void RunWorkerLoop() private void RunWorkerLoop() {
{ while(this.WorkerState != WorkerState.Stopped && !this.IsDisposing && !this.IsDisposed) {
while (WorkerState != WorkerState.Stopped && !IsDisposing && !IsDisposed) this.CycleStopwatch.Restart();
{ CancellationToken interruptToken = this.CycleCancellation.Token;
CycleStopwatch.Restart(); Int32 period = this.Period.TotalMilliseconds >= Int32.MaxValue ? -1 : Convert.ToInt32(Math.Floor(this.Period.TotalMilliseconds));
var interruptToken = CycleCancellation.Token; Task delayTask = Task.Delay(period, interruptToken);
var period = Period.TotalMilliseconds >= int.MaxValue ? -1 : Convert.ToInt32(Math.Floor(Period.TotalMilliseconds)); WorkerState initialWorkerState = this.WorkerState;
var delayTask = Task.Delay(period, interruptToken);
var initialWorkerState = WorkerState;
// Lock the cycle and capture relevant state valid for this cycle // Lock the cycle and capture relevant state valid for this cycle
CycleCompletedEvent.Reset(); this.CycleCompletedEvent.Reset();
// Process the tasks that are awaiting // Process the tasks that are awaiting
if (ProcessStateChangeRequests()) if(this.ProcessStateChangeRequests()) {
continue; continue;
}
try try {
{
if(initialWorkerState == WorkerState.Waiting && if(initialWorkerState == WorkerState.Waiting &&
!interruptToken.IsCancellationRequested) !interruptToken.IsCancellationRequested) {
{
// Mark the state as Running // Mark the state as Running
WorkerState = WorkerState.Running; this.WorkerState = WorkerState.Running;
// Call the execution logic // Call the execution logic
ExecuteCycleLogic(interruptToken); this.ExecuteCycleLogic(interruptToken);
} }
} } catch(Exception ex) {
catch (Exception ex) this.OnCycleException(ex);
{ } finally {
OnCycleException(ex);
}
finally
{
// Update the state // Update the state
WorkerState = initialWorkerState == WorkerState.Paused this.WorkerState = initialWorkerState == WorkerState.Paused
? WorkerState.Paused ? WorkerState.Paused
: WorkerState.Waiting; : WorkerState.Waiting;
// Signal the cycle has been completed so new cycles can be executed // Signal the cycle has been completed so new cycles can be executed
CycleCompletedEvent.Set(); this.CycleCompletedEvent.Set();
if (!interruptToken.IsCancellationRequested) if(!interruptToken.IsCancellationRequested) {
{ Int32 cycleDelay = this.ComputeCycleDelay(initialWorkerState);
var cycleDelay = ComputeCycleDelay(initialWorkerState); if(cycleDelay == Timeout.Infinite) {
if (cycleDelay == Timeout.Infinite)
delayTask = Task.Delay(Timeout.Infinite, interruptToken); delayTask = Task.Delay(Timeout.Infinite, interruptToken);
}
ExecuteCycleDelay( this.ExecuteCycleDelay(
cycleDelay, cycleDelay,
delayTask, delayTask,
CycleCancellation.Token); this.CycleCancellation.Token);
} }
} }
} }
ClearStateChangeRequests(); this.ClearStateChangeRequests();
WorkerState = WorkerState.Stopped; this.WorkerState = WorkerState.Stopped;
} }
/// <summary> /// <summary>
@ -203,28 +175,25 @@
/// </summary> /// </summary>
/// <param name="request">The request.</param> /// <param name="request">The request.</param>
/// <returns>The awaitable task.</returns> /// <returns>The awaitable task.</returns>
private Task<WorkerState> QueueStateChange(StateChangeRequest request) private Task<WorkerState> QueueStateChange(StateChangeRequest request) {
{ lock(this._syncLock) {
lock (_syncLock) if(this.StateChangeTask != null) {
{ return this.StateChangeTask;
if (StateChangeTask != null) }
return StateChangeTask;
var waitingTask = new Task<WorkerState>(() => Task<WorkerState> waitingTask = new Task<WorkerState>(() => {
{ this.StateChangedEvent.Wait();
StateChangedEvent.Wait(); lock(this._syncLock) {
lock (_syncLock) this.StateChangeTask = null;
{ return this.WorkerState;
StateChangeTask = null;
return WorkerState;
} }
}); });
StateChangeTask = waitingTask; this.StateChangeTask = waitingTask;
StateChangedEvent.Reset(); this.StateChangedEvent.Reset();
StateChangeRequests[request] = true; this.StateChangeRequests[request] = true;
waitingTask.Start(); waitingTask.Start();
CycleCancellation.Cancel(); this.CycleCancellation.Cancel();
return waitingTask; return waitingTask;
} }
@ -235,36 +204,28 @@
/// cycle execution accordingly. The <see cref="WorkerState"/> is also updated. /// cycle execution accordingly. The <see cref="WorkerState"/> is also updated.
/// </summary> /// </summary>
/// <returns>Returns <c>true</c> if the execution should be terminated. <c>false</c> otherwise.</returns> /// <returns>Returns <c>true</c> if the execution should be terminated. <c>false</c> otherwise.</returns>
private bool ProcessStateChangeRequests() private Boolean ProcessStateChangeRequests() {
{ lock(this._syncLock) {
lock (_syncLock) Boolean hasRequest = false;
{ WorkerState currentState = this.WorkerState;
var hasRequest = false;
var currentState = WorkerState;
// Update the state in the given priority // Update the state in the given priority
if (StateChangeRequests[StateChangeRequest.Stop] || IsDisposing || IsDisposed) if(this.StateChangeRequests[StateChangeRequest.Stop] || this.IsDisposing || this.IsDisposed) {
{
hasRequest = true; hasRequest = true;
WorkerState = WorkerState.Stopped; this.WorkerState = WorkerState.Stopped;
} } else if(this.StateChangeRequests[StateChangeRequest.Pause]) {
else if (StateChangeRequests[StateChangeRequest.Pause])
{
hasRequest = true; hasRequest = true;
WorkerState = WorkerState.Paused; this.WorkerState = WorkerState.Paused;
} } else if(this.StateChangeRequests[StateChangeRequest.Start] || this.StateChangeRequests[StateChangeRequest.Resume]) {
else if (StateChangeRequests[StateChangeRequest.Start] || StateChangeRequests[StateChangeRequest.Resume])
{
hasRequest = true; hasRequest = true;
WorkerState = WorkerState.Waiting; this.WorkerState = WorkerState.Waiting;
} }
// Signals all state changes to continue // Signals all state changes to continue
// as a command has been handled. // as a command has been handled.
if (hasRequest) if(hasRequest) {
{ this.ClearStateChangeRequests();
ClearStateChangeRequests(); this.OnStateChangeProcessed(currentState, this.WorkerState);
OnStateChangeProcessed(currentState, WorkerState);
} }
return hasRequest; return hasRequest;
@ -274,18 +235,16 @@
/// <summary> /// <summary>
/// Signals all state change requests to set. /// Signals all state change requests to set.
/// </summary> /// </summary>
private void ClearStateChangeRequests() private void ClearStateChangeRequests() {
{ lock(this._syncLock) {
lock (_syncLock)
{
// Mark all events as completed // Mark all events as completed
StateChangeRequests[StateChangeRequest.Start] = false; this.StateChangeRequests[StateChangeRequest.Start] = false;
StateChangeRequests[StateChangeRequest.Pause] = false; this.StateChangeRequests[StateChangeRequest.Pause] = false;
StateChangeRequests[StateChangeRequest.Resume] = false; this.StateChangeRequests[StateChangeRequest.Resume] = false;
StateChangeRequests[StateChangeRequest.Stop] = false; this.StateChangeRequests[StateChangeRequest.Stop] = false;
StateChangedEvent.Set(); this.StateChangedEvent.Set();
CycleCompletedEvent.Set(); this.CycleCompletedEvent.Set();
} }
} }
} }

View File

@ -1,105 +1,95 @@
namespace Swan.Threading using System;
{
using System;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Threading {
/// <summary> /// <summary>
/// Provides a base implementation for application workers. /// Provides a base implementation for application workers.
/// </summary> /// </summary>
/// <seealso cref="IWorker" /> /// <seealso cref="IWorker" />
public abstract class TimerWorkerBase : WorkerBase public abstract class TimerWorkerBase : WorkerBase {
{ private readonly Object _syncLock = new Object();
private readonly object _syncLock = new object();
private readonly Timer _timer; private readonly Timer _timer;
private bool _isTimerAlive = true; private Boolean _isTimerAlive = true;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="TimerWorkerBase"/> class. /// Initializes a new instance of the <see cref="TimerWorkerBase"/> class.
/// </summary> /// </summary>
/// <param name="name">The name.</param> /// <param name="name">The name.</param>
/// <param name="period">The execution interval.</param> /// <param name="period">The execution interval.</param>
protected TimerWorkerBase(string name, TimeSpan period) protected TimerWorkerBase(String name, TimeSpan period) : base(name, period) =>
: base(name, period)
{
// Instantiate the timer that will be used to schedule cycles // Instantiate the timer that will be used to schedule cycles
_timer = new Timer( this._timer = new Timer(this.ExecuteTimerCallback, this, Timeout.Infinite, Timeout.Infinite);
ExecuteTimerCallback,
this,
Timeout.Infinite,
Timeout.Infinite);
}
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> StartAsync() public override Task<WorkerState> StartAsync() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState == WorkerState.Paused || this.WorkerState == WorkerState.Waiting) {
{ return this.ResumeAsync();
if (WorkerState == WorkerState.Paused || WorkerState == WorkerState.Waiting) }
return ResumeAsync();
if (WorkerState != WorkerState.Created) if(this.WorkerState != WorkerState.Created) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
if (IsStopRequested) if(this.IsStopRequested) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
var task = QueueStateChange(StateChangeRequest.Start); Task<WorkerState> task = this.QueueStateChange(StateChangeRequest.Start);
Interrupt(); this.Interrupt();
return task; return task;
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> PauseAsync() public override Task<WorkerState> PauseAsync() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState != WorkerState.Running && this.WorkerState != WorkerState.Waiting) {
{ return Task.FromResult(this.WorkerState);
if (WorkerState != WorkerState.Running && WorkerState != WorkerState.Waiting) }
return Task.FromResult(WorkerState);
if (IsStopRequested) if(this.IsStopRequested) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
var task = QueueStateChange(StateChangeRequest.Pause); Task<WorkerState> task = this.QueueStateChange(StateChangeRequest.Pause);
Interrupt(); this.Interrupt();
return task; return task;
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> ResumeAsync() public override Task<WorkerState> ResumeAsync() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState == WorkerState.Created) {
{ return this.StartAsync();
if (WorkerState == WorkerState.Created) }
return StartAsync();
if (WorkerState != WorkerState.Paused && WorkerState != WorkerState.Waiting) if(this.WorkerState != WorkerState.Paused && this.WorkerState != WorkerState.Waiting) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
if (IsStopRequested) if(this.IsStopRequested) {
return Task.FromResult(WorkerState); return Task.FromResult(this.WorkerState);
}
var task = QueueStateChange(StateChangeRequest.Resume); Task<WorkerState> task = this.QueueStateChange(StateChangeRequest.Resume);
Interrupt(); this.Interrupt();
return task; return task;
} }
} }
/// <inheritdoc /> /// <inheritdoc />
public override Task<WorkerState> StopAsync() public override Task<WorkerState> StopAsync() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState == WorkerState.Stopped || this.WorkerState == WorkerState.Created) {
{ this.WorkerState = WorkerState.Stopped;
if (WorkerState == WorkerState.Stopped || WorkerState == WorkerState.Created) return Task.FromResult(this.WorkerState);
{
WorkerState = WorkerState.Stopped;
return Task.FromResult(WorkerState);
} }
var task = QueueStateChange(StateChangeRequest.Stop); Task<WorkerState> task = this.QueueStateChange(StateChangeRequest.Stop);
Interrupt(); this.Interrupt();
return task; return task;
} }
} }
@ -110,40 +100,41 @@
/// immediately. /// immediately.
/// </summary> /// </summary>
/// <param name="delay">The delay.</param> /// <param name="delay">The delay.</param>
protected void ScheduleCycle(int delay) protected void ScheduleCycle(Int32 delay) {
{ lock(this._syncLock) {
lock (_syncLock) if(!this._isTimerAlive) {
{ return;
if (!_isTimerAlive) return; }
_timer.Change(delay, Timeout.Infinite);
_ = this._timer.Change(delay, Timeout.Infinite);
} }
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Dispose(bool disposing) protected override void Dispose(Boolean disposing) {
{
base.Dispose(disposing); base.Dispose(disposing);
lock (_syncLock) lock(this._syncLock) {
{ if(!this._isTimerAlive) {
if (!_isTimerAlive) return; return;
_isTimerAlive = false; }
_timer.Dispose();
this._isTimerAlive = false;
this._timer.Dispose();
} }
} }
/// <summary> /// <summary>
/// Cancels the current token and schedules a new cycle immediately. /// Cancels the current token and schedules a new cycle immediately.
/// </summary> /// </summary>
private void Interrupt() private void Interrupt() {
{ lock(this._syncLock) {
lock (_syncLock) if(this.WorkerState == WorkerState.Stopped) {
{
if (WorkerState == WorkerState.Stopped)
return; return;
}
CycleCancellation.Cancel(); this.CycleCancellation.Cancel();
ScheduleCycle(0); this.ScheduleCycle(0);
} }
} }
@ -153,68 +144,62 @@
/// the execution of use cycle code, /// the execution of use cycle code,
/// and the scheduling of new cycles. /// and the scheduling of new cycles.
/// </summary> /// </summary>
private void ExecuteWorkerCycle() private void ExecuteWorkerCycle() {
{ this.CycleStopwatch.Restart();
CycleStopwatch.Restart();
lock (_syncLock) lock(this._syncLock) {
{ if(this.IsDisposing || this.IsDisposed) {
if (IsDisposing || IsDisposed) this.WorkerState = WorkerState.Stopped;
{
WorkerState = WorkerState.Stopped;
// Cancel any awaiters // Cancel any awaiters
try { StateChangedEvent.Set(); } try {
catch { /* Ignore */ } this.StateChangedEvent.Set();
} catch { /* Ignore */ }
return; return;
} }
// Prevent running another instance of the cycle // Prevent running another instance of the cycle
if (CycleCompletedEvent.IsSet == false) return; if(this.CycleCompletedEvent.IsSet == false) {
return;
}
// Lock the cycle and capture relevant state valid for this cycle // Lock the cycle and capture relevant state valid for this cycle
CycleCompletedEvent.Reset(); this.CycleCompletedEvent.Reset();
} }
var interruptToken = CycleCancellation.Token; CancellationToken interruptToken = this.CycleCancellation.Token;
var initialWorkerState = WorkerState; WorkerState initialWorkerState = this.WorkerState;
// Process the tasks that are awaiting // Process the tasks that are awaiting
if (ProcessStateChangeRequests()) if(this.ProcessStateChangeRequests()) {
return; return;
}
try try {
{
if(initialWorkerState == WorkerState.Waiting && if(initialWorkerState == WorkerState.Waiting &&
!interruptToken.IsCancellationRequested) !interruptToken.IsCancellationRequested) {
{
// Mark the state as Running // Mark the state as Running
WorkerState = WorkerState.Running; this.WorkerState = WorkerState.Running;
// Call the execution logic // Call the execution logic
ExecuteCycleLogic(interruptToken); this.ExecuteCycleLogic(interruptToken);
} }
} } catch(Exception ex) {
catch (Exception ex) this.OnCycleException(ex);
{ } finally {
OnCycleException(ex);
}
finally
{
// Update the state // Update the state
WorkerState = initialWorkerState == WorkerState.Paused this.WorkerState = initialWorkerState == WorkerState.Paused
? WorkerState.Paused ? WorkerState.Paused
: WorkerState.Waiting; : WorkerState.Waiting;
lock (_syncLock) lock(this._syncLock) {
{
// Signal the cycle has been completed so new cycles can be executed // Signal the cycle has been completed so new cycles can be executed
CycleCompletedEvent.Set(); this.CycleCompletedEvent.Set();
// Schedule a new cycle // Schedule a new cycle
ScheduleCycle(!interruptToken.IsCancellationRequested this.ScheduleCycle(!interruptToken.IsCancellationRequested
? ComputeCycleDelay(initialWorkerState) ? this.ComputeCycleDelay(initialWorkerState)
: 0); : 0);
} }
} }
@ -224,7 +209,7 @@
/// Represents the callback that is executed when the <see cref="_timer"/> ticks. /// Represents the callback that is executed when the <see cref="_timer"/> ticks.
/// </summary> /// </summary>
/// <param name="state">The state -- this contains the worker.</param> /// <param name="state">The state -- this contains the worker.</param>
private void ExecuteTimerCallback(object state) => ExecuteWorkerCycle(); private void ExecuteTimerCallback(Object state) => this.ExecuteWorkerCycle();
/// <summary> /// <summary>
/// Queues a transition in worker state for processing. Returns a task that can be awaited /// Queues a transition in worker state for processing. Returns a task that can be awaited
@ -232,28 +217,25 @@
/// </summary> /// </summary>
/// <param name="request">The request.</param> /// <param name="request">The request.</param>
/// <returns>The awaitable task.</returns> /// <returns>The awaitable task.</returns>
private Task<WorkerState> QueueStateChange(StateChangeRequest request) private Task<WorkerState> QueueStateChange(StateChangeRequest request) {
{ lock(this._syncLock) {
lock (_syncLock) if(this.StateChangeTask != null) {
{ return this.StateChangeTask;
if (StateChangeTask != null) }
return StateChangeTask;
var waitingTask = new Task<WorkerState>(() => Task<WorkerState> waitingTask = new Task<WorkerState>(() => {
{ this.StateChangedEvent.Wait();
StateChangedEvent.Wait(); lock(this._syncLock) {
lock (_syncLock) this.StateChangeTask = null;
{ return this.WorkerState;
StateChangeTask = null;
return WorkerState;
} }
}); });
StateChangeTask = waitingTask; this.StateChangeTask = waitingTask;
StateChangedEvent.Reset(); this.StateChangedEvent.Reset();
StateChangeRequests[request] = true; this.StateChangeRequests[request] = true;
waitingTask.Start(); waitingTask.Start();
CycleCancellation.Cancel(); this.CycleCancellation.Cancel();
return waitingTask; return waitingTask;
} }
@ -264,38 +246,30 @@
/// cycle execution accordingly. The <see cref="WorkerState"/> is also updated. /// cycle execution accordingly. The <see cref="WorkerState"/> is also updated.
/// </summary> /// </summary>
/// <returns>Returns <c>true</c> if the execution should be terminated. <c>false</c> otherwise.</returns> /// <returns>Returns <c>true</c> if the execution should be terminated. <c>false</c> otherwise.</returns>
private bool ProcessStateChangeRequests() private Boolean ProcessStateChangeRequests() {
{ lock(this._syncLock) {
lock (_syncLock) WorkerState currentState = this.WorkerState;
{ Boolean hasRequest = false;
var currentState = WorkerState; Int32 schedule = 0;
var hasRequest = false;
var schedule = 0;
// Update the state according to request priority // Update the state according to request priority
if (StateChangeRequests[StateChangeRequest.Stop] || IsDisposing || IsDisposed) if(this.StateChangeRequests[StateChangeRequest.Stop] || this.IsDisposing || this.IsDisposed) {
{
hasRequest = true; hasRequest = true;
WorkerState = WorkerState.Stopped; this.WorkerState = WorkerState.Stopped;
schedule = StateChangeRequests[StateChangeRequest.Stop] ? Timeout.Infinite : 0; schedule = this.StateChangeRequests[StateChangeRequest.Stop] ? Timeout.Infinite : 0;
} } else if(this.StateChangeRequests[StateChangeRequest.Pause]) {
else if (StateChangeRequests[StateChangeRequest.Pause])
{
hasRequest = true; hasRequest = true;
WorkerState = WorkerState.Paused; this.WorkerState = WorkerState.Paused;
schedule = Timeout.Infinite; schedule = Timeout.Infinite;
} } else if(this.StateChangeRequests[StateChangeRequest.Start] || this.StateChangeRequests[StateChangeRequest.Resume]) {
else if (StateChangeRequests[StateChangeRequest.Start] || StateChangeRequests[StateChangeRequest.Resume])
{
hasRequest = true; hasRequest = true;
WorkerState = WorkerState.Waiting; this.WorkerState = WorkerState.Waiting;
} }
// Signals all state changes to continue // Signals all state changes to continue
// as a command has been handled. // as a command has been handled.
if (hasRequest) if(hasRequest) {
{ this.ClearStateChangeRequests(schedule, currentState, this.WorkerState);
ClearStateChangeRequests(schedule, currentState, WorkerState);
} }
return hasRequest; return hasRequest;
@ -308,20 +282,18 @@
/// <param name="schedule">The cycle schedule.</param> /// <param name="schedule">The cycle schedule.</param>
/// <param name="oldState">The previous worker state.</param> /// <param name="oldState">The previous worker state.</param>
/// <param name="newState">The new worker state.</param> /// <param name="newState">The new worker state.</param>
private void ClearStateChangeRequests(int schedule, WorkerState oldState, WorkerState newState) private void ClearStateChangeRequests(Int32 schedule, WorkerState oldState, WorkerState newState) {
{ lock(this._syncLock) {
lock (_syncLock)
{
// Mark all events as completed // Mark all events as completed
StateChangeRequests[StateChangeRequest.Start] = false; this.StateChangeRequests[StateChangeRequest.Start] = false;
StateChangeRequests[StateChangeRequest.Pause] = false; this.StateChangeRequests[StateChangeRequest.Pause] = false;
StateChangeRequests[StateChangeRequest.Resume] = false; this.StateChangeRequests[StateChangeRequest.Resume] = false;
StateChangeRequests[StateChangeRequest.Stop] = false; this.StateChangeRequests[StateChangeRequest.Stop] = false;
StateChangedEvent.Set(); this.StateChangedEvent.Set();
CycleCompletedEvent.Set(); this.CycleCompletedEvent.Set();
OnStateChangeProcessed(oldState, newState); this.OnStateChangeProcessed(oldState, newState);
ScheduleCycle(schedule); this.ScheduleCycle(schedule);
} }
} }
} }

View File

@ -1,20 +1,19 @@
namespace Swan.Threading #nullable enable
{
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Threading {
/// <summary> /// <summary>
/// Provides base infrastructure for Timer and Thread workers. /// Provides base infrastructure for Timer and Thread workers.
/// </summary> /// </summary>
/// <seealso cref="IWorker" /> /// <seealso cref="IWorker" />
public abstract class WorkerBase : IWorker, IDisposable public abstract class WorkerBase : IWorker, IDisposable {
{
// Since these are API property backers, we use interlocked to read from them // Since these are API property backers, we use interlocked to read from them
// to avoid deadlocked reads // to avoid deadlocked reads
private readonly object _syncLock = new object(); private readonly Object _syncLock = new Object();
private readonly AtomicBoolean _isDisposed = new AtomicBoolean(); private readonly AtomicBoolean _isDisposed = new AtomicBoolean();
private readonly AtomicBoolean _isDisposing = new AtomicBoolean(); private readonly AtomicBoolean _isDisposing = new AtomicBoolean();
@ -26,13 +25,11 @@
/// </summary> /// </summary>
/// <param name="name">The name.</param> /// <param name="name">The name.</param>
/// <param name="period">The execution interval.</param> /// <param name="period">The execution interval.</param>
protected WorkerBase(string name, TimeSpan period) protected WorkerBase(String name, TimeSpan period) {
{ this.Name = name;
Name = name; this._timeSpan = new AtomicTimeSpan(period);
_timeSpan = new AtomicTimeSpan(period);
StateChangeRequests = new Dictionary<StateChangeRequest, bool>(5) this.StateChangeRequests = new Dictionary<StateChangeRequest, Boolean>(5) {
{
[StateChangeRequest.Start] = false, [StateChangeRequest.Start] = false,
[StateChangeRequest.Pause] = false, [StateChangeRequest.Pause] = false,
[StateChangeRequest.Resume] = false, [StateChangeRequest.Resume] = false,
@ -43,8 +40,7 @@
/// <summary> /// <summary>
/// Enumerates all the different state change requests. /// Enumerates all the different state change requests.
/// </summary> /// </summary>
protected enum StateChangeRequest protected enum StateChangeRequest {
{
/// <summary> /// <summary>
/// No state change request. /// No state change request.
/// </summary> /// </summary>
@ -72,34 +68,32 @@
} }
/// <inheritdoc /> /// <inheritdoc />
public string Name { get; } public String Name {
get;
/// <inheritdoc />
public TimeSpan Period
{
get => _timeSpan.Value;
set => _timeSpan.Value = value;
} }
/// <inheritdoc /> /// <inheritdoc />
public WorkerState WorkerState public TimeSpan Period {
{ get => this._timeSpan.Value;
get => _workerState.Value; set => this._timeSpan.Value = value;
protected set => _workerState.Value = value;
} }
/// <inheritdoc /> /// <inheritdoc />
public bool IsDisposed public WorkerState WorkerState {
{ get => this._workerState.Value;
get => _isDisposed.Value; protected set => this._workerState.Value = value;
protected set => _isDisposed.Value = value;
} }
/// <inheritdoc /> /// <inheritdoc />
public bool IsDisposing public Boolean IsDisposed {
{ get => this._isDisposed.Value;
get => _isDisposing.Value; protected set => this._isDisposed.Value = value;
protected set => _isDisposing.Value = value; }
/// <inheritdoc />
public Boolean IsDisposing {
get => this._isDisposing.Value;
protected set => this._isDisposing.Value = value;
} }
/// <summary> /// <summary>
@ -111,7 +105,7 @@
/// Gets a value indicating whether stop has been requested. /// Gets a value indicating whether stop has been requested.
/// This is useful to prevent more requests from being issued. /// This is useful to prevent more requests from being issued.
/// </summary> /// </summary>
protected bool IsStopRequested => StateChangeRequests[StateChangeRequest.Stop]; protected Boolean IsStopRequested => this.StateChangeRequests[StateChangeRequest.Stop];
/// <summary> /// <summary>
/// Gets the cycle stopwatch. /// Gets the cycle stopwatch.
@ -121,7 +115,9 @@
/// <summary> /// <summary>
/// Gets the state change requests. /// Gets the state change requests.
/// </summary> /// </summary>
protected Dictionary<StateChangeRequest, bool> StateChangeRequests { get; } protected Dictionary<StateChangeRequest, Boolean> StateChangeRequests {
get;
}
/// <summary> /// <summary>
/// Gets the cycle completed event. /// Gets the cycle completed event.
@ -141,7 +137,9 @@
/// <summary> /// <summary>
/// Gets or sets the state change task. /// Gets or sets the state change task.
/// </summary> /// </summary>
protected Task<WorkerState>? StateChangeTask { get; set; } protected Task<WorkerState>? StateChangeTask {
get; set;
}
/// <inheritdoc /> /// <inheritdoc />
public abstract Task<WorkerState> StartAsync(); public abstract Task<WorkerState> StartAsync();
@ -156,9 +154,8 @@
public abstract Task<WorkerState> StopAsync(); public abstract Task<WorkerState> StopAsync();
/// <inheritdoc /> /// <inheritdoc />
public void Dispose() public void Dispose() {
{ this.Dispose(true);
Dispose(true);
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
@ -166,28 +163,29 @@
/// Releases unmanaged and - optionally - managed resources. /// Releases unmanaged and - optionally - managed resources.
/// </summary> /// </summary>
/// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param> /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
protected virtual void Dispose(bool disposing) protected virtual void Dispose(Boolean disposing) {
{ lock(this._syncLock) {
lock (_syncLock) if(this.IsDisposed || this.IsDisposing) {
{ return;
if (IsDisposed || IsDisposing) return; }
IsDisposing = true;
this.IsDisposing = true;
} }
// This also ensures the state change queue gets cleared // This also ensures the state change queue gets cleared
StopAsync().Wait(); this.StopAsync().Wait();
StateChangedEvent.Set(); this.StateChangedEvent.Set();
CycleCompletedEvent.Set(); this.CycleCompletedEvent.Set();
OnDisposing(); this.OnDisposing();
CycleStopwatch.Stop(); this.CycleStopwatch.Stop();
StateChangedEvent.Dispose(); this.StateChangedEvent.Dispose();
CycleCompletedEvent.Dispose(); this.CycleCompletedEvent.Dispose();
CycleCancellation.Dispose(); this.CycleCancellation.Dispose();
IsDisposed = true; this.IsDisposed = true;
IsDisposing = false; this.IsDisposing = false;
} }
/// <summary> /// <summary>
@ -214,8 +212,7 @@
/// </summary> /// </summary>
/// <param name="previousState">The state before the change.</param> /// <param name="previousState">The state before the change.</param>
/// <param name="newState">The new state.</param> /// <param name="newState">The new state.</param>
protected virtual void OnStateChangeProcessed(WorkerState previousState, WorkerState newState) protected virtual void OnStateChangeProcessed(WorkerState previousState, WorkerState newState) {
{
// placeholder // placeholder
} }
@ -224,17 +221,13 @@
/// </summary> /// </summary>
/// <param name="initialWorkerState">Initial state of the worker.</param> /// <param name="initialWorkerState">Initial state of the worker.</param>
/// <returns>The number of milliseconds to delay for.</returns> /// <returns>The number of milliseconds to delay for.</returns>
protected int ComputeCycleDelay(WorkerState initialWorkerState) protected Int32 ComputeCycleDelay(WorkerState initialWorkerState) {
{ Int64 elapsedMillis = this.CycleStopwatch.ElapsedMilliseconds;
var elapsedMillis = CycleStopwatch.ElapsedMilliseconds; TimeSpan period = this.Period;
var period = Period; Double periodMillis = period.TotalMilliseconds;
var periodMillis = period.TotalMilliseconds; Double delayMillis = periodMillis - elapsedMillis;
var delayMillis = periodMillis - elapsedMillis;
if (initialWorkerState == WorkerState.Paused || period == TimeSpan.MaxValue || delayMillis >= int.MaxValue) return initialWorkerState == WorkerState.Paused || period == TimeSpan.MaxValue || delayMillis >= Int32.MaxValue ? Timeout.Infinite : elapsedMillis >= periodMillis ? 0 : Convert.ToInt32(Math.Floor(delayMillis));
return Timeout.Infinite;
return elapsedMillis >= periodMillis ? 0 : Convert.ToInt32(Math.Floor(delayMillis));
} }
} }
} }

View File

@ -1,15 +1,13 @@
namespace Swan.Threading using System;
{
using System;
using System.Diagnostics; using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan.Threading {
/// <summary> /// <summary>
/// Represents a class that implements delay logic for thread workers. /// Represents a class that implements delay logic for thread workers.
/// </summary> /// </summary>
public static class WorkerDelayProvider public static class WorkerDelayProvider {
{
/// <summary> /// <summary>
/// Gets the default delay provider. /// Gets the default delay provider.
/// </summary> /// </summary>
@ -38,114 +36,111 @@
/// </summary> /// </summary>
public static IWorkerDelayProvider SteppedToken => new SteppedTokenDelay(); public static IWorkerDelayProvider SteppedToken => new SteppedTokenDelay();
private class TokenCancellableDelay : IWorkerDelayProvider private class TokenCancellableDelay : IWorkerDelayProvider {
{ public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) {
public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) if(wantedDelay == 0 || wantedDelay < -1) {
{
if (wantedDelay == 0 || wantedDelay < -1)
return; return;
}
// for wanted delays of less than 30ms it is not worth // for wanted delays of less than 30ms it is not worth
// passing a timeout or a token as it only adds unnecessary // passing a timeout or a token as it only adds unnecessary
// overhead. // overhead.
if (wantedDelay <= 30) if(wantedDelay <= 30) {
{ try {
try { delayTask.Wait(token); } delayTask.Wait(token);
catch { /* ignore */ } } catch { /* ignore */ }
return; return;
} }
// only wait on the cancellation token // only wait on the cancellation token
// or until the task completes normally // or until the task completes normally
try { delayTask.Wait(token); } try {
catch { /* ignore */ } delayTask.Wait(token);
} catch { /* ignore */ }
} }
} }
private class TokenTimeoutCancellableDelay : IWorkerDelayProvider private class TokenTimeoutCancellableDelay : IWorkerDelayProvider {
{ public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) {
public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) if(wantedDelay == 0 || wantedDelay < -1) {
{
if (wantedDelay == 0 || wantedDelay < -1)
return; return;
}
// for wanted delays of less than 30ms it is not worth // for wanted delays of less than 30ms it is not worth
// passing a timeout or a token as it only adds unnecessary // passing a timeout or a token as it only adds unnecessary
// overhead. // overhead.
if (wantedDelay <= 30) if(wantedDelay <= 30) {
{ try {
try { delayTask.Wait(token); } delayTask.Wait(token);
catch { /* ignore */ } } catch { /* ignore */ }
return; return;
} }
try { delayTask.Wait(wantedDelay, token); } try {
catch { /* ignore */ } _ = delayTask.Wait(wantedDelay, token);
} catch { /* ignore */ }
} }
} }
private class TokenSleepDelay : IWorkerDelayProvider private class TokenSleepDelay : IWorkerDelayProvider {
{
private readonly Stopwatch _elapsedWait = new Stopwatch(); private readonly Stopwatch _elapsedWait = new Stopwatch();
public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) {
{ this._elapsedWait.Restart();
_elapsedWait.Restart();
if (wantedDelay == 0 || wantedDelay < -1) if(wantedDelay == 0 || wantedDelay < -1) {
return; return;
}
while (!token.IsCancellationRequested) while(!token.IsCancellationRequested) {
{
Thread.Sleep(5); Thread.Sleep(5);
if (wantedDelay != Timeout.Infinite && _elapsedWait.ElapsedMilliseconds >= wantedDelay) if(wantedDelay != Timeout.Infinite && this._elapsedWait.ElapsedMilliseconds >= wantedDelay) {
break; break;
} }
} }
} }
}
private class SteppedTokenDelay : IWorkerDelayProvider private class SteppedTokenDelay : IWorkerDelayProvider {
{ private const Int32 StepMilliseconds = 15;
private const int StepMilliseconds = 15;
private readonly Stopwatch _elapsedWait = new Stopwatch(); private readonly Stopwatch _elapsedWait = new Stopwatch();
public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) {
{ this._elapsedWait.Restart();
_elapsedWait.Restart();
if (wantedDelay == 0 || wantedDelay < -1) if(wantedDelay == 0 || wantedDelay < -1) {
return;
if (wantedDelay == Timeout.Infinite)
{
try { delayTask.Wait(wantedDelay, token); }
catch { /* Ignore cancelled tasks */ }
return; return;
} }
while (!token.IsCancellationRequested) if(wantedDelay == Timeout.Infinite) {
{ try {
var remainingWaitTime = wantedDelay - Convert.ToInt32(_elapsedWait.ElapsedMilliseconds); _ = delayTask.Wait(wantedDelay, token);
} catch { /* Ignore cancelled tasks */ }
return;
}
while(!token.IsCancellationRequested) {
Int32 remainingWaitTime = wantedDelay - Convert.ToInt32(this._elapsedWait.ElapsedMilliseconds);
// Exit for no remaining wait time // Exit for no remaining wait time
if (remainingWaitTime <= 0) if(remainingWaitTime <= 0) {
break; break;
}
if (remainingWaitTime >= StepMilliseconds) if(remainingWaitTime >= StepMilliseconds) {
{
Task.Delay(StepMilliseconds, token).Wait(token); Task.Delay(StepMilliseconds, token).Wait(token);
} } else {
else try {
{ _ = delayTask.Wait(remainingWaitTime);
try { delayTask.Wait(remainingWaitTime); } } catch { /* ignore cancellation of task exception */ }
catch { /* ignore cancellation of task exception */ }
} }
if (_elapsedWait.ElapsedMilliseconds >= wantedDelay) if(this._elapsedWait.ElapsedMilliseconds >= wantedDelay) {
break; break;
} }
} }
} }
} }
} }
}

View File

@ -1,37 +1,31 @@
using System.Collections.Concurrent; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Swan namespace Swan {
{
/// <summary> /// <summary>
/// A base class for implementing models that fire notifications when their properties change. /// A base class for implementing models that fire notifications when their properties change.
/// This class is ideal for implementing MVVM driven UIs. /// This class is ideal for implementing MVVM driven UIs.
/// </summary> /// </summary>
/// <seealso cref="INotifyPropertyChanged" /> /// <seealso cref="INotifyPropertyChanged" />
public abstract class ViewModelBase : INotifyPropertyChanged public abstract class ViewModelBase : INotifyPropertyChanged {
{ private readonly ConcurrentDictionary<String, Boolean> _queuedNotifications = new ConcurrentDictionary<String, Boolean>();
private readonly ConcurrentDictionary<string, bool> _queuedNotifications = new ConcurrentDictionary<string, bool>(); private readonly Boolean _useDeferredNotifications;
private readonly bool _useDeferredNotifications;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ViewModelBase"/> class. /// Initializes a new instance of the <see cref="ViewModelBase"/> class.
/// </summary> /// </summary>
/// <param name="useDeferredNotifications">Set to <c>true</c> to use deferred notifications in the background.</param> /// <param name="useDeferredNotifications">Set to <c>true</c> to use deferred notifications in the background.</param>
protected ViewModelBase(bool useDeferredNotifications) protected ViewModelBase(Boolean useDeferredNotifications) => this._useDeferredNotifications = useDeferredNotifications;
{
_useDeferredNotifications = useDeferredNotifications;
}
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ViewModelBase"/> class. /// Initializes a new instance of the <see cref="ViewModelBase"/> class.
/// </summary> /// </summary>
protected ViewModelBase() protected ViewModelBase() : this(false) {
: this(false)
{
// placeholder // placeholder
} }
@ -49,13 +43,13 @@ namespace Swan
/// <param name="notifyAlso">An array of property names to notify in addition to notifying the changes on the current property name.</param> /// <param name="notifyAlso">An array of property names to notify in addition to notifying the changes on the current property name.</param>
/// <returns>True if the value was changed, false if the existing value matched the /// <returns>True if the value was changed, false if the existing value matched the
/// desired value.</returns> /// desired value.</returns>
protected bool SetProperty<T>(ref T storage, T value, [CallerMemberName] string propertyName = "", string[] notifyAlso = null) protected Boolean SetProperty<T>(ref T storage, T value, [CallerMemberName] String propertyName = "", String[] notifyAlso = null) {
{ if(EqualityComparer<T>.Default.Equals(storage, value)) {
if (EqualityComparer<T>.Default.Equals(storage, value))
return false; return false;
}
storage = value; storage = value;
NotifyPropertyChanged(propertyName, notifyAlso); this.NotifyPropertyChanged(propertyName, notifyAlso);
return true; return true;
} }
@ -63,54 +57,55 @@ namespace Swan
/// Notifies one or more properties changed. /// Notifies one or more properties changed.
/// </summary> /// </summary>
/// <param name="propertyNames">The property names.</param> /// <param name="propertyNames">The property names.</param>
protected void NotifyPropertyChanged(params string[] propertyNames) => NotifyPropertyChanged(null, propertyNames); protected void NotifyPropertyChanged(params String[] propertyNames) => this.NotifyPropertyChanged(null, propertyNames);
/// <summary> /// <summary>
/// Notifies one or more properties changed. /// Notifies one or more properties changed.
/// </summary> /// </summary>
/// <param name="mainProperty">The main property.</param> /// <param name="mainProperty">The main property.</param>
/// <param name="auxiliaryProperties">The auxiliary properties.</param> /// <param name="auxiliaryProperties">The auxiliary properties.</param>
private void NotifyPropertyChanged(string mainProperty, string[] auxiliaryProperties) private void NotifyPropertyChanged(String mainProperty, String[] auxiliaryProperties) {
{
// Queue property notification // Queue property notification
if (string.IsNullOrWhiteSpace(mainProperty) == false) if(String.IsNullOrWhiteSpace(mainProperty) == false) {
_queuedNotifications[mainProperty] = true; this._queuedNotifications[mainProperty] = true;
}
// Set the state for notification properties // Set the state for notification properties
if (auxiliaryProperties != null) if(auxiliaryProperties != null) {
{ foreach(String property in auxiliaryProperties) {
foreach (var property in auxiliaryProperties) if(String.IsNullOrWhiteSpace(property) == false) {
{ this._queuedNotifications[property] = true;
if (string.IsNullOrWhiteSpace(property) == false) }
_queuedNotifications[property] = true;
} }
} }
// Depending on operation mode, either fire the notifications in the background // Depending on operation mode, either fire the notifications in the background
// or fire them immediately // or fire them immediately
if (_useDeferredNotifications) if(this._useDeferredNotifications) {
Task.Run(NotifyQueuedProperties); _ = Task.Run(this.NotifyQueuedProperties);
else } else {
NotifyQueuedProperties(); this.NotifyQueuedProperties();
}
} }
/// <summary> /// <summary>
/// Notifies the queued properties and resets the property name to a non-queued stated. /// Notifies the queued properties and resets the property name to a non-queued stated.
/// </summary> /// </summary>
private void NotifyQueuedProperties() private void NotifyQueuedProperties() {
{
// get a snapshot of property names. // get a snapshot of property names.
var propertyNames = _queuedNotifications.Keys.ToArray(); String[] propertyNames = this._queuedNotifications.Keys.ToArray();
// Iterate through the properties // Iterate through the properties
foreach (var property in propertyNames) foreach(String property in propertyNames) {
{
// don't notify if we don't have a change // don't notify if we don't have a change
if (!_queuedNotifications[property]) continue; if(!this._queuedNotifications[property]) {
continue;
}
// notify and reset queued state to false // notify and reset queued state to false
try { OnPropertyChanged(property); } try {
finally { _queuedNotifications[property] = false; } this.OnPropertyChanged(property);
} finally { this._queuedNotifications[property] = false; }
} }
} }
@ -118,7 +113,6 @@ namespace Swan
/// Called when a property changes its backing value. /// Called when a property changes its backing value.
/// </summary> /// </summary>
/// <param name="propertyName">Name of the property.</param> /// <param name="propertyName">Name of the property.</param>
private void OnPropertyChanged(string propertyName) => private void OnPropertyChanged(String propertyName) => PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName ?? String.Empty));
PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName ?? string.Empty));
} }
} }