diff --git a/Swan/DependencyInjection/DependencyContainer.cs b/Swan/DependencyInjection/DependencyContainer.cs index 48d6461..efbe114 100644 --- a/Swan/DependencyInjection/DependencyContainer.cs +++ b/Swan/DependencyInjection/DependencyContainer.cs @@ -1,705 +1,541 @@ -namespace Swan.DependencyInjection -{ - using System; - using System.Collections.Generic; - using System.Linq; - using System.Reflection; - +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Swan.DependencyInjection { + /// + /// The concrete implementation of a simple IoC container + /// based largely on TinyIoC (https://github.com/grumpydev/TinyIoC). + /// + /// + public partial class DependencyContainer : IDisposable { + private readonly Object _autoRegisterLock = new Object(); + + private Boolean _disposed; + + static DependencyContainer() { + } + /// - /// The concrete implementation of a simple IoC container - /// based largely on TinyIoC (https://github.com/grumpydev/TinyIoC). + /// Initializes a new instance of the class. /// - /// - public partial class DependencyContainer : IDisposable - { - private readonly object _autoRegisterLock = new object(); - - private bool _disposed; - - static DependencyContainer() - { - } - - /// - /// Initializes a new instance of the class. - /// - public DependencyContainer() - { - RegisteredTypes = new TypesConcurrentDictionary(this); - Register(this); - } - - private DependencyContainer(DependencyContainer parent) - : this() - { - Parent = parent; - } - - /// - /// Lazy created Singleton instance of the container for simple scenarios. - /// - public static DependencyContainer Current { get; } = new DependencyContainer(); - - internal DependencyContainer Parent { get; } - - internal TypesConcurrentDictionary RegisteredTypes { get; } - - /// - public void Dispose() - { - if (_disposed) return; - - _disposed = true; - - foreach (var disposable in RegisteredTypes.Values.Select(item => item as IDisposable)) - { - disposable?.Dispose(); - } - - GC.SuppressFinalize(this); - } - - /// - /// Gets the child container. - /// - /// A new instance of the class. - public DependencyContainer GetChildContainer() => new DependencyContainer(this); - - #region Registration - - /// - /// Attempt to automatically register all non-generic classes and interfaces in the current app domain. - /// Types will only be registered if they pass the supplied registration predicate. - /// - /// What action to take when encountering duplicate implementations of an interface/base class. - /// Predicate to determine if a particular type should be registered. - public void AutoRegister( - DependencyContainerDuplicateImplementationAction duplicateAction = - DependencyContainerDuplicateImplementationAction.RegisterSingle, - Func registrationPredicate = null) - { - AutoRegister( - AppDomain.CurrentDomain.GetAssemblies().Where(a => !IsIgnoredAssembly(a)), - duplicateAction, - registrationPredicate); - } - - /// - /// Attempt to automatically register all non-generic classes and interfaces in the specified assemblies - /// Types will only be registered if they pass the supplied registration predicate. - /// - /// Assemblies to process. - /// What action to take when encountering duplicate implementations of an interface/base class. - /// Predicate to determine if a particular type should be registered. - public void AutoRegister( - IEnumerable assemblies, - DependencyContainerDuplicateImplementationAction duplicateAction = - DependencyContainerDuplicateImplementationAction.RegisterSingle, - Func registrationPredicate = null) - { - lock (_autoRegisterLock) - { - var types = assemblies - .SelectMany(a => a.GetAllTypes()) - .Where(t => !IsIgnoredType(t, registrationPredicate)) - .ToList(); - - var concreteTypes = types - .Where(type => - type.IsClass && !type.IsAbstract && - (type != GetType() && (type.DeclaringType != GetType()) && !type.IsGenericTypeDefinition)) - .ToList(); - - foreach (var type in concreteTypes) - { - try - { - RegisteredTypes.Register(type, string.Empty, GetDefaultObjectFactory(type, type)); - } - catch (MethodAccessException) - { - // Ignore methods we can't access - added for Silverlight - } - } - - var abstractInterfaceTypes = types.Where( - type => - ((type.IsInterface || type.IsAbstract) && (type.DeclaringType != GetType()) && - (!type.IsGenericTypeDefinition))); - - foreach (var type in abstractInterfaceTypes) - { - var localType = type; - var implementations = concreteTypes - .Where(implementationType => localType.IsAssignableFrom(implementationType)).ToList(); - - if (implementations.Skip(1).Any()) - { - if (duplicateAction == DependencyContainerDuplicateImplementationAction.Fail) - throw new DependencyContainerRegistrationException(type, implementations); - - if (duplicateAction == DependencyContainerDuplicateImplementationAction.RegisterMultiple) - { - RegisterMultiple(type, implementations); - } - } - - var firstImplementation = implementations.FirstOrDefault(); - - if (firstImplementation == null) continue; - - try - { - RegisteredTypes.Register(type, string.Empty, GetDefaultObjectFactory(type, firstImplementation)); - } - catch (MethodAccessException) - { - // Ignore methods we can't access - added for Silverlight - } - } - } - } - - /// - /// Creates/replaces a named container class registration with default options. - /// - /// Type to register. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(Type registerType, string name = "") - => RegisteredTypes.Register( - registerType, - name, - GetDefaultObjectFactory(registerType, registerType)); - - /// - /// Creates/replaces a named container class registration with a given implementation and default options. - /// - /// Type to register. - /// Type to instantiate that implements RegisterType. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(Type registerType, Type registerImplementation, string name = "") => - RegisteredTypes.Register(registerType, name, GetDefaultObjectFactory(registerType, registerImplementation)); - - /// - /// Creates/replaces a named container class registration with a specific, strong referenced, instance. - /// - /// Type to register. - /// Instance of RegisterType to register. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(Type registerType, object instance, string name = "") => - RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerType, instance)); - - /// - /// Creates/replaces a named container class registration with a specific, strong referenced, instance. - /// - /// Type to register. - /// Type of instance to register that implements RegisterType. - /// Instance of RegisterImplementation to register. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register( - Type registerType, - Type registerImplementation, - object instance, - string name = "") - => RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerImplementation, instance)); - - /// - /// Creates/replaces a container class registration with a user specified factory. - /// - /// Type to register. - /// Factory/lambda that returns an instance of RegisterType. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register( - Type registerType, - Func, object> factory, - string name = "") - => RegisteredTypes.Register(registerType, name, new DelegateFactory(registerType, factory)); - - /// - /// Creates/replaces a named container class registration with default options. - /// - /// Type to register. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(string name = "") - where TRegister : class - { - return Register(typeof(TRegister), name); - } - - /// - /// Creates/replaces a named container class registration with a given implementation and default options. - /// - /// Type to register. - /// Type to instantiate that implements RegisterType. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(string name = "") - where TRegister : class - where TRegisterImplementation : class, TRegister - { - return Register(typeof(TRegister), typeof(TRegisterImplementation), name); - } - - /// - /// Creates/replaces a named container class registration with a specific, strong referenced, instance. - /// - /// Type to register. - /// Instance of RegisterType to register. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(TRegister instance, string name = "") - where TRegister : class - { - return Register(typeof(TRegister), instance, name); - } - - /// - /// Creates/replaces a named container class registration with a specific, strong referenced, instance. - /// - /// Type to register. - /// Type of instance to register that implements RegisterType. - /// Instance of RegisterImplementation to register. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register(TRegisterImplementation instance, - string name = "") - where TRegister : class - where TRegisterImplementation : class, TRegister - { - return Register(typeof(TRegister), typeof(TRegisterImplementation), instance, name); - } - - /// - /// Creates/replaces a named container class registration with a user specified factory. - /// - /// Type to register. - /// Factory/lambda that returns an instance of RegisterType. - /// Name of registration. - /// RegisterOptions for fluent API. - public RegisterOptions Register( - Func, TRegister> factory, string name = "") - where TRegister : class - { - if (factory == null) - throw new ArgumentNullException(nameof(factory)); - - return Register(typeof(TRegister), factory, name); - } - - /// - /// Register multiple implementations of a type. - /// - /// Internally this registers each implementation using the full name of the class as its registration name. - /// - /// Type that each implementation implements. - /// Types that implement RegisterType. - /// MultiRegisterOptions for the fluent API. - public MultiRegisterOptions RegisterMultiple(IEnumerable implementationTypes) => - RegisterMultiple(typeof(TRegister), implementationTypes); - - /// - /// Register multiple implementations of a type. - /// - /// Internally this registers each implementation using the full name of the class as its registration name. - /// - /// Type that each implementation implements. - /// Types that implement RegisterType. - /// MultiRegisterOptions for the fluent API. - public MultiRegisterOptions RegisterMultiple(Type registrationType, IEnumerable implementationTypes) - { - if (implementationTypes == null) - throw new ArgumentNullException(nameof(implementationTypes), "types is null."); - - foreach (var type in implementationTypes.Where(type => !registrationType.IsAssignableFrom(type))) - { - throw new ArgumentException( - $"types: The type {registrationType.FullName} is not assignable from {type.FullName}"); - } - - if (implementationTypes.Count() != implementationTypes.Distinct().Count()) - { - var queryForDuplicatedTypes = implementationTypes - .GroupBy(i => i) - .Where(j => j.Count() > 1) - .Select(j => j.Key.FullName); - - var fullNamesOfDuplicatedTypes = string.Join(",\n", queryForDuplicatedTypes.ToArray()); - - throw new ArgumentException( - $"types: The same implementation type cannot be specified multiple times for {registrationType.FullName}\n\n{fullNamesOfDuplicatedTypes}"); - } - - var registerOptions = implementationTypes - .Select(type => Register(registrationType, type, type.FullName)) - .ToList(); - - return new MultiRegisterOptions(registerOptions); - } - - #endregion - - #region Unregistration - - /// - /// Remove a named container class registration. - /// - /// Type to unregister. - /// Name of registration. - /// true if the registration is successfully found and removed; otherwise, false. - public bool Unregister(string name = "") => Unregister(typeof(TRegister), name); - - /// - /// Remove a named container class registration. - /// - /// Type to unregister. - /// Name of registration. - /// true if the registration is successfully found and removed; otherwise, false. - public bool Unregister(Type registerType, string name = "") => - RegisteredTypes.RemoveRegistration(new TypeRegistration(registerType, name)); - - #endregion - - #region Resolution - - /// - /// Attempts to resolve a named type using specified options and the supplied constructor parameters. - /// - /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). - /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. - /// - /// Type to resolve. - /// Name of registration. - /// Resolution options. - /// Instance of type. - /// Unable to resolve the type. - public object Resolve( - Type resolveType, - string name = null, - DependencyContainerResolveOptions options = null) - => RegisteredTypes.ResolveInternal(new TypeRegistration(resolveType, name), options ?? DependencyContainerResolveOptions.Default); - - /// - /// Attempts to resolve a named type using specified options and the supplied constructor parameters. - /// - /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). - /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. - /// - /// Type to resolve. - /// Name of registration. - /// Resolution options. - /// Instance of type. - /// Unable to resolve the type. - public TResolveType Resolve( - string name = null, - DependencyContainerResolveOptions options = null) - where TResolveType : class - { - return (TResolveType)Resolve(typeof(TResolveType), name, options); - } - - /// - /// Attempts to predict whether a given type can be resolved with the supplied constructor parameters options. - /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). - /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. - /// Note: Resolution may still fail if user defined factory registrations fail to construct objects when called. - /// - /// Type to resolve. - /// The name. - /// Resolution options. - /// - /// Bool indicating whether the type can be resolved. - /// - public bool CanResolve( - Type resolveType, - string name = null, - DependencyContainerResolveOptions options = null) => - RegisteredTypes.CanResolve(new TypeRegistration(resolveType, name), options); - - /// - /// Attempts to predict whether a given named type can be resolved with the supplied constructor parameters options. - /// - /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). - /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. - /// - /// Note: Resolution may still fail if user defined factory registrations fail to construct objects when called. - /// - /// Type to resolve. - /// Name of registration. - /// Resolution options. - /// Bool indicating whether the type can be resolved. - public bool CanResolve( - string name = null, - DependencyContainerResolveOptions options = null) - where TResolveType : class - { - return CanResolve(typeof(TResolveType), name, options); - } - - /// - /// Attempts to resolve a type using the default options. - /// - /// Type to resolve. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve(Type resolveType, out object resolvedType) - { - try - { - resolvedType = Resolve(resolveType); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = null; - return false; - } - } - - /// - /// Attempts to resolve a type using the given options. - /// - /// Type to resolve. - /// Resolution options. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve(Type resolveType, DependencyContainerResolveOptions options, out object resolvedType) - { - try - { - resolvedType = Resolve(resolveType, options: options); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = null; - return false; - } - } - - /// - /// Attempts to resolve a type using the default options and given name. - /// - /// Type to resolve. - /// Name of registration. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve(Type resolveType, string name, out object resolvedType) - { - try - { - resolvedType = Resolve(resolveType, name); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = null; - return false; - } - } - - /// - /// Attempts to resolve a type using the given options and name. - /// - /// Type to resolve. - /// Name of registration. - /// Resolution options. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve( - Type resolveType, - string name, - DependencyContainerResolveOptions options, - out object resolvedType) - { - try - { - resolvedType = Resolve(resolveType, name, options); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = null; - return false; - } - } - - /// - /// Attempts to resolve a type using the default options. - /// - /// Type to resolve. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve(out TResolveType resolvedType) - where TResolveType : class - { - try - { - resolvedType = Resolve(); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = default; - return false; - } - } - - /// - /// Attempts to resolve a type using the given options. - /// - /// Type to resolve. - /// Resolution options. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve(DependencyContainerResolveOptions options, out TResolveType resolvedType) - where TResolveType : class - { - try - { - resolvedType = Resolve(options: options); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = default; - return false; - } - } - - /// - /// Attempts to resolve a type using the default options and given name. - /// - /// Type to resolve. - /// Name of registration. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve(string name, out TResolveType resolvedType) - where TResolveType : class - { - try - { - resolvedType = Resolve(name); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = default; - return false; - } - } - - /// - /// Attempts to resolve a type using the given options and name. - /// - /// Type to resolve. - /// Name of registration. - /// Resolution options. - /// Resolved type or default if resolve fails. - /// true if resolved successfully, false otherwise. - public bool TryResolve( - string name, - DependencyContainerResolveOptions options, - out TResolveType resolvedType) - where TResolveType : class - { - try - { - resolvedType = Resolve(name, options); - return true; - } - catch (DependencyContainerResolutionException) - { - resolvedType = default; - return false; - } - } - - /// - /// Returns all registrations of a type. - /// - /// Type to resolveAll. - /// Whether to include un-named (default) registrations. - /// IEnumerable. - public IEnumerable ResolveAll(Type resolveType, bool includeUnnamed = false) - => RegisteredTypes.Resolve(resolveType, includeUnnamed); - - /// - /// Returns all registrations of a type. - /// - /// Type to resolveAll. - /// Whether to include un-named (default) registrations. - /// IEnumerable. - public IEnumerable ResolveAll(bool includeUnnamed = true) - where TResolveType : class - { - return ResolveAll(typeof(TResolveType), includeUnnamed).Cast(); - } - - /// - /// Attempts to resolve all public property dependencies on the given object using the given resolve options. - /// - /// Object to "build up". - /// Resolve options to use. - public void BuildUp(object input, DependencyContainerResolveOptions resolveOptions = null) - { - if (resolveOptions == null) - resolveOptions = DependencyContainerResolveOptions.Default; - - var properties = input.GetType() - .GetProperties() - .Where(property => property.GetCacheGetMethod() != null && property.GetCacheSetMethod() != null && - !property.PropertyType.IsValueType); - - foreach (var property in properties.Where(property => property.GetValue(input, null) == null)) - { - try - { - property.SetValue( - input, - RegisteredTypes.ResolveInternal(new TypeRegistration(property.PropertyType), resolveOptions), - null); - } - catch (DependencyContainerResolutionException) - { - // Catch any resolution errors and ignore them - } - } - } - - #endregion - - #region Internal Methods - - internal static bool IsValidAssignment(Type registerType, Type registerImplementation) - { - if (!registerType.IsGenericTypeDefinition) - { - if (!registerType.IsAssignableFrom(registerImplementation)) - return false; - } - else - { - if (registerType.IsInterface && registerImplementation.GetInterfaces().All(t => t.Name != registerType.Name)) - return false; - - if (registerType.IsAbstract && registerImplementation.BaseType != registerType) - return false; - } - - return true; - } - - private static bool IsIgnoredAssembly(Assembly assembly) - { - // TODO - find a better way to remove "system" assemblies from the auto registration - var ignoreChecks = new List> - { + public DependencyContainer() { + this.RegisteredTypes = new TypesConcurrentDictionary(this); + _ = this.Register(this); + } + + private DependencyContainer(DependencyContainer parent) : this() => this.Parent = parent; + + /// + /// Lazy created Singleton instance of the container for simple scenarios. + /// + public static DependencyContainer Current { get; } = new DependencyContainer(); + + internal DependencyContainer Parent { + get; + } + + internal TypesConcurrentDictionary RegisteredTypes { + get; + } + + /// + public void Dispose() { + if(this._disposed) { + return; + } + + this._disposed = true; + + foreach(IDisposable disposable in this.RegisteredTypes.Values.Select(item => item as IDisposable)) { + disposable?.Dispose(); + } + + GC.SuppressFinalize(this); + } + + /// + /// Gets the child container. + /// + /// A new instance of the class. + public DependencyContainer GetChildContainer() => new DependencyContainer(this); + + #region Registration + + /// + /// Attempt to automatically register all non-generic classes and interfaces in the current app domain. + /// Types will only be registered if they pass the supplied registration predicate. + /// + /// What action to take when encountering duplicate implementations of an interface/base class. + /// Predicate to determine if a particular type should be registered. + public void AutoRegister(DependencyContainerDuplicateImplementationAction duplicateAction = DependencyContainerDuplicateImplementationAction.RegisterSingle, Func registrationPredicate = null) => this.AutoRegister(AppDomain.CurrentDomain.GetAssemblies().Where(a => !IsIgnoredAssembly(a)), duplicateAction, registrationPredicate); + + /// + /// Attempt to automatically register all non-generic classes and interfaces in the specified assemblies + /// Types will only be registered if they pass the supplied registration predicate. + /// + /// Assemblies to process. + /// What action to take when encountering duplicate implementations of an interface/base class. + /// Predicate to determine if a particular type should be registered. + public void AutoRegister(IEnumerable assemblies, DependencyContainerDuplicateImplementationAction duplicateAction = DependencyContainerDuplicateImplementationAction.RegisterSingle, Func registrationPredicate = null) { + lock(this._autoRegisterLock) { + List types = assemblies.SelectMany(a => a.GetAllTypes()).Where(t => !IsIgnoredType(t, registrationPredicate)).ToList(); + + List concreteTypes = types.Where(type => type.IsClass && !type.IsAbstract && type != this.GetType() && type.DeclaringType != this.GetType() && !type.IsGenericTypeDefinition).ToList(); + + foreach(Type type in concreteTypes) { + try { + _ = this.RegisteredTypes.Register(type, String.Empty, GetDefaultObjectFactory(type, type)); + } catch(MethodAccessException) { + // Ignore methods we can't access - added for Silverlight + } + } + + IEnumerable abstractInterfaceTypes = types.Where(type => (type.IsInterface || type.IsAbstract) && type.DeclaringType != this.GetType() && !type.IsGenericTypeDefinition); + + foreach(Type type in abstractInterfaceTypes) { + Type localType = type; + List implementations = concreteTypes.Where(implementationType => localType.IsAssignableFrom(implementationType)).ToList(); + + if(implementations.Skip(1).Any()) { + if(duplicateAction == DependencyContainerDuplicateImplementationAction.Fail) { + throw new DependencyContainerRegistrationException(type, implementations); + } + + if(duplicateAction == DependencyContainerDuplicateImplementationAction.RegisterMultiple) { + _ = this.RegisterMultiple(type, implementations); + } + } + + Type firstImplementation = implementations.FirstOrDefault(); + + if(firstImplementation == null) { + continue; + } + + try { + _ = this.RegisteredTypes.Register(type, String.Empty, GetDefaultObjectFactory(type, firstImplementation)); + } catch(MethodAccessException) { + // Ignore methods we can't access - added for Silverlight + } + } + } + } + + /// + /// Creates/replaces a named container class registration with default options. + /// + /// Type to register. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(Type registerType, String name = "") => this.RegisteredTypes.Register(registerType, name, GetDefaultObjectFactory(registerType, registerType)); + + /// + /// Creates/replaces a named container class registration with a given implementation and default options. + /// + /// Type to register. + /// Type to instantiate that implements RegisterType. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(Type registerType, Type registerImplementation, String name = "") => this.RegisteredTypes.Register(registerType, name, GetDefaultObjectFactory(registerType, registerImplementation)); + + /// + /// Creates/replaces a named container class registration with a specific, strong referenced, instance. + /// + /// Type to register. + /// Instance of RegisterType to register. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(Type registerType, Object instance, String name = "") => this.RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerType, instance)); + + /// + /// Creates/replaces a named container class registration with a specific, strong referenced, instance. + /// + /// Type to register. + /// Type of instance to register that implements RegisterType. + /// Instance of RegisterImplementation to register. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(Type registerType, Type registerImplementation, Object instance, String name = "") => this.RegisteredTypes.Register(registerType, name, new InstanceFactory(registerType, registerImplementation, instance)); + + /// + /// Creates/replaces a container class registration with a user specified factory. + /// + /// Type to register. + /// Factory/lambda that returns an instance of RegisterType. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(Type registerType, Func, Object> factory, String name = "") => this.RegisteredTypes.Register(registerType, name, new DelegateFactory(registerType, factory)); + + /// + /// Creates/replaces a named container class registration with default options. + /// + /// Type to register. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(String name = "") where TRegister : class => this.Register(typeof(TRegister), name); + + /// + /// Creates/replaces a named container class registration with a given implementation and default options. + /// + /// Type to register. + /// Type to instantiate that implements RegisterType. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(String name = "") where TRegister : class where TRegisterImplementation : class, TRegister => this.Register(typeof(TRegister), typeof(TRegisterImplementation), name); + + /// + /// Creates/replaces a named container class registration with a specific, strong referenced, instance. + /// + /// Type to register. + /// Instance of RegisterType to register. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(TRegister instance, String name = "") where TRegister : class => this.Register(typeof(TRegister), instance, name); + + /// + /// Creates/replaces a named container class registration with a specific, strong referenced, instance. + /// + /// Type to register. + /// Type of instance to register that implements RegisterType. + /// Instance of RegisterImplementation to register. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(TRegisterImplementation instance, String name = "") where TRegister : class where TRegisterImplementation : class, TRegister => this.Register(typeof(TRegister), typeof(TRegisterImplementation), instance, name); + + /// + /// Creates/replaces a named container class registration with a user specified factory. + /// + /// Type to register. + /// Factory/lambda that returns an instance of RegisterType. + /// Name of registration. + /// RegisterOptions for fluent API. + public RegisterOptions Register(Func, TRegister> factory, String name = "") where TRegister : class { + if(factory == null) { + throw new ArgumentNullException(nameof(factory)); + } + + return this.Register(typeof(TRegister), factory, name); + } + + /// + /// Register multiple implementations of a type. + /// + /// Internally this registers each implementation using the full name of the class as its registration name. + /// + /// Type that each implementation implements. + /// Types that implement RegisterType. + /// MultiRegisterOptions for the fluent API. + public MultiRegisterOptions RegisterMultiple(IEnumerable implementationTypes) => this.RegisterMultiple(typeof(TRegister), implementationTypes); + + /// + /// Register multiple implementations of a type. + /// + /// Internally this registers each implementation using the full name of the class as its registration name. + /// + /// Type that each implementation implements. + /// Types that implement RegisterType. + /// MultiRegisterOptions for the fluent API. + public MultiRegisterOptions RegisterMultiple(Type registrationType, IEnumerable implementationTypes) { + if(implementationTypes == null) { + throw new ArgumentNullException(nameof(implementationTypes), "types is null."); + } + + foreach(Type type in implementationTypes.Where(type => !registrationType.IsAssignableFrom(type))) { + throw new ArgumentException($"types: The type {registrationType.FullName} is not assignable from {type.FullName}"); + } + + if(implementationTypes.Count() != implementationTypes.Distinct().Count()) { + IEnumerable queryForDuplicatedTypes = implementationTypes.GroupBy(i => i).Where(j => j.Count() > 1).Select(j => j.Key.FullName); + + String fullNamesOfDuplicatedTypes = String.Join(",\n", queryForDuplicatedTypes.ToArray()); + + throw new ArgumentException($"types: The same implementation type cannot be specified multiple times for {registrationType.FullName}\n\n{fullNamesOfDuplicatedTypes}"); + } + + List registerOptions = implementationTypes.Select(type => this.Register(registrationType, type, type.FullName)).ToList(); + + return new MultiRegisterOptions(registerOptions); + } + + #endregion + + #region Unregistration + + /// + /// Remove a named container class registration. + /// + /// Type to unregister. + /// Name of registration. + /// true if the registration is successfully found and removed; otherwise, false. + public Boolean Unregister(String name = "") => this.Unregister(typeof(TRegister), name); + + /// + /// Remove a named container class registration. + /// + /// Type to unregister. + /// Name of registration. + /// true if the registration is successfully found and removed; otherwise, false. + public Boolean Unregister(Type registerType, String name = "") => this.RegisteredTypes.RemoveRegistration(new TypeRegistration(registerType, name)); + + #endregion + + #region Resolution + + /// + /// Attempts to resolve a named type using specified options and the supplied constructor parameters. + /// + /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). + /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. + /// + /// Type to resolve. + /// Name of registration. + /// Resolution options. + /// Instance of type. + /// Unable to resolve the type. + public Object Resolve(Type resolveType, String name = null, DependencyContainerResolveOptions options = null) => this.RegisteredTypes.ResolveInternal(new TypeRegistration(resolveType, name), options ?? DependencyContainerResolveOptions.Default); + + /// + /// Attempts to resolve a named type using specified options and the supplied constructor parameters. + /// + /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). + /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. + /// + /// Type to resolve. + /// Name of registration. + /// Resolution options. + /// Instance of type. + /// Unable to resolve the type. + public TResolveType Resolve(String name = null, DependencyContainerResolveOptions options = null) where TResolveType : class => (TResolveType)this.Resolve(typeof(TResolveType), name, options); + + /// + /// Attempts to predict whether a given type can be resolved with the supplied constructor parameters options. + /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). + /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. + /// Note: Resolution may still fail if user defined factory registrations fail to construct objects when called. + /// + /// Type to resolve. + /// The name. + /// Resolution options. + /// + /// Bool indicating whether the type can be resolved. + /// + public Boolean CanResolve(Type resolveType, String name = null, DependencyContainerResolveOptions options = null) => this.RegisteredTypes.CanResolve(new TypeRegistration(resolveType, name), options); + + /// + /// Attempts to predict whether a given named type can be resolved with the supplied constructor parameters options. + /// + /// Parameters are used in conjunction with normal container resolution to find the most suitable constructor (if one exists). + /// All user supplied parameters must exist in at least one resolvable constructor of RegisterType or resolution will fail. + /// + /// Note: Resolution may still fail if user defined factory registrations fail to construct objects when called. + /// + /// Type to resolve. + /// Name of registration. + /// Resolution options. + /// Bool indicating whether the type can be resolved. + public Boolean CanResolve(String name = null, DependencyContainerResolveOptions options = null) where TResolveType : class => this.CanResolve(typeof(TResolveType), name, options); + + /// + /// Attempts to resolve a type using the default options. + /// + /// Type to resolve. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(Type resolveType, out Object resolvedType) { + try { + resolvedType = this.Resolve(resolveType); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = null; + return false; + } + } + + /// + /// Attempts to resolve a type using the given options. + /// + /// Type to resolve. + /// Resolution options. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(Type resolveType, DependencyContainerResolveOptions options, out Object resolvedType) { + try { + resolvedType = this.Resolve(resolveType, options: options); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = null; + return false; + } + } + + /// + /// Attempts to resolve a type using the default options and given name. + /// + /// Type to resolve. + /// Name of registration. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(Type resolveType, String name, out Object resolvedType) { + try { + resolvedType = this.Resolve(resolveType, name); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = null; + return false; + } + } + + /// + /// Attempts to resolve a type using the given options and name. + /// + /// Type to resolve. + /// Name of registration. + /// Resolution options. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(Type resolveType, String name, DependencyContainerResolveOptions options, out Object resolvedType) { + try { + resolvedType = this.Resolve(resolveType, name, options); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = null; + return false; + } + } + + /// + /// Attempts to resolve a type using the default options. + /// + /// Type to resolve. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(out TResolveType resolvedType) where TResolveType : class { + try { + resolvedType = this.Resolve(); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = default; + return false; + } + } + + /// + /// Attempts to resolve a type using the given options. + /// + /// Type to resolve. + /// Resolution options. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(DependencyContainerResolveOptions options, out TResolveType resolvedType) where TResolveType : class { + try { + resolvedType = this.Resolve(options: options); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = default; + return false; + } + } + + /// + /// Attempts to resolve a type using the default options and given name. + /// + /// Type to resolve. + /// Name of registration. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(String name, out TResolveType resolvedType) where TResolveType : class { + try { + resolvedType = this.Resolve(name); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = default; + return false; + } + } + + /// + /// Attempts to resolve a type using the given options and name. + /// + /// Type to resolve. + /// Name of registration. + /// Resolution options. + /// Resolved type or default if resolve fails. + /// true if resolved successfully, false otherwise. + public Boolean TryResolve(String name, DependencyContainerResolveOptions options, out TResolveType resolvedType) where TResolveType : class { + try { + resolvedType = this.Resolve(name, options); + return true; + } catch(DependencyContainerResolutionException) { + resolvedType = default; + return false; + } + } + + /// + /// Returns all registrations of a type. + /// + /// Type to resolveAll. + /// Whether to include un-named (default) registrations. + /// IEnumerable. + public IEnumerable ResolveAll(Type resolveType, Boolean includeUnnamed = false) => this.RegisteredTypes.Resolve(resolveType, includeUnnamed); + + /// + /// Returns all registrations of a type. + /// + /// Type to resolveAll. + /// Whether to include un-named (default) registrations. + /// IEnumerable. + public IEnumerable ResolveAll(Boolean includeUnnamed = true) where TResolveType : class => this.ResolveAll(typeof(TResolveType), includeUnnamed).Cast(); + + /// + /// Attempts to resolve all public property dependencies on the given object using the given resolve options. + /// + /// Object to "build up". + /// Resolve options to use. + public void BuildUp(Object input, DependencyContainerResolveOptions resolveOptions = null) { + if(resolveOptions == null) { + resolveOptions = DependencyContainerResolveOptions.Default; + } + + IEnumerable properties = input.GetType().GetProperties().Where(property => property.GetCacheGetMethod() != null && property.GetCacheSetMethod() != null && !property.PropertyType.IsValueType); + + foreach(PropertyInfo property in properties.Where(property => property.GetValue(input, null) == null)) { + try { + property.SetValue(input, this.RegisteredTypes.ResolveInternal(new TypeRegistration(property.PropertyType), resolveOptions), null); + } catch(DependencyContainerResolutionException) { + // Catch any resolution errors and ignore them + } + } + } + + #endregion + + #region Internal Methods + + internal static Boolean IsValidAssignment(Type registerType, Type registerImplementation) { + if(!registerType.IsGenericTypeDefinition) { + if(!registerType.IsAssignableFrom(registerImplementation)) { + return false; + } + } else { + if(registerType.IsInterface && registerImplementation.GetInterfaces().All(t => t.Name != registerType.Name)) { + return false; + } + + if(registerType.IsAbstract && registerImplementation.BaseType != registerType) { + return false; + } + } + + return true; + } + + private static Boolean IsIgnoredAssembly(Assembly assembly) { + // TODO - find a better way to remove "system" assemblies from the auto registration + List> ignoreChecks = new List> + { asm => asm.FullName.StartsWith("Microsoft.", StringComparison.Ordinal), asm => asm.FullName.StartsWith("System.", StringComparison.Ordinal), asm => asm.FullName.StartsWith("System,", StringComparison.Ordinal), @@ -708,36 +544,32 @@ asm => asm.FullName.StartsWith("CR_VSTest", StringComparison.Ordinal), asm => asm.FullName.StartsWith("DevExpress.CodeRush", StringComparison.Ordinal), asm => asm.FullName.StartsWith("xunit.", StringComparison.Ordinal), - }; - - return ignoreChecks.Any(check => check(assembly)); - } - - private static bool IsIgnoredType(Type type, Func registrationPredicate) - { - // TODO - find a better way to remove "system" types from the auto registration - var ignoreChecks = new List>() - { + }; + + return ignoreChecks.Any(check => check(assembly)); + } + + private static Boolean IsIgnoredType(Type type, Func registrationPredicate) { + // TODO - find a better way to remove "system" types from the auto registration + List> ignoreChecks = new List>() + { t => t.FullName?.StartsWith("System.", StringComparison.Ordinal) ?? false, t => t.FullName?.StartsWith("Microsoft.", StringComparison.Ordinal) ?? false, t => t.IsPrimitive, t => t.IsGenericTypeDefinition, - t => (t.GetConstructors(BindingFlags.Instance | BindingFlags.Public).Length == 0) && + t => t.GetConstructors(BindingFlags.Instance | BindingFlags.Public).Length == 0 && !(t.IsInterface || t.IsAbstract), - }; - - if (registrationPredicate != null) - { - ignoreChecks.Add(t => !registrationPredicate(t)); - } - - return ignoreChecks.Any(check => check(type)); - } - - private static ObjectFactoryBase GetDefaultObjectFactory(Type registerType, Type registerImplementation) => registerType.IsInterface || registerType.IsAbstract - ? (ObjectFactoryBase)new SingletonFactory(registerType, registerImplementation) - : new MultiInstanceFactory(registerType, registerImplementation); - - #endregion - } + }; + + if(registrationPredicate != null) { + ignoreChecks.Add(t => !registrationPredicate(t)); + } + + return ignoreChecks.Any(check => check(type)); + } + + private static ObjectFactoryBase GetDefaultObjectFactory(Type registerType, Type registerImplementation) => registerType.IsInterface || registerType.IsAbstract ? (ObjectFactoryBase)new SingletonFactory(registerType, registerImplementation) : new MultiInstanceFactory(registerType, registerImplementation); + + #endregion + } } diff --git a/Swan/DependencyInjection/DependencyContainerRegistrationException.cs b/Swan/DependencyInjection/DependencyContainerRegistrationException.cs index 3890744..38dfae6 100644 --- a/Swan/DependencyInjection/DependencyContainerRegistrationException.cs +++ b/Swan/DependencyInjection/DependencyContainerRegistrationException.cs @@ -1,46 +1,34 @@ -namespace Swan.DependencyInjection -{ - using System; - using System.Collections.Generic; - using System.Linq; - +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Swan.DependencyInjection { + /// + /// Generic Constraint Registration Exception. + /// + /// + public class DependencyContainerRegistrationException : Exception { + private const String ConvertErrorText = "Cannot convert current registration of {0} to {1}"; + private const String RegisterErrorText = "Cannot register type {0} - abstract classes or interfaces are not valid implementation types for {1}."; + private const String ErrorText = "Duplicate implementation of type {0} found ({1})."; + /// - /// Generic Constraint Registration Exception. + /// Initializes a new instance of the class. /// - /// - public class DependencyContainerRegistrationException : Exception - { - private const string ConvertErrorText = "Cannot convert current registration of {0} to {1}"; - private const string RegisterErrorText = - "Cannot register type {0} - abstract classes or interfaces are not valid implementation types for {1}."; - private const string ErrorText = "Duplicate implementation of type {0} found ({1})."; - - /// - /// Initializes a new instance of the class. - /// - /// Type of the register. - /// The types. - public DependencyContainerRegistrationException(Type registerType, IEnumerable types) - : base(string.Format(ErrorText, registerType, GetTypesString(types))) - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The type. - /// The method. - /// if set to true [is type factory]. - public DependencyContainerRegistrationException(Type type, string method, bool isTypeFactory = false) - : base(isTypeFactory - ? string.Format(RegisterErrorText, type.FullName, method) - : string.Format(ConvertErrorText, type.FullName, method)) - { - } - - private static string GetTypesString(IEnumerable types) - { - return string.Join(",", types.Select(type => type.FullName)); - } - } + /// Type of the register. + /// The types. + public DependencyContainerRegistrationException(Type registerType, IEnumerable types) : base(String.Format(ErrorText, registerType, GetTypesString(types))) { + } + + /// + /// Initializes a new instance of the class. + /// + /// The type. + /// The method. + /// if set to true [is type factory]. + public DependencyContainerRegistrationException(Type type, String method, Boolean isTypeFactory = false) : base(isTypeFactory ? String.Format(RegisterErrorText, type.FullName, method) : String.Format(ConvertErrorText, type.FullName, method)) { + } + + private static String GetTypesString(IEnumerable types) => String.Join(",", types.Select(type => type.FullName)); + } } \ No newline at end of file diff --git a/Swan/DependencyInjection/DependencyContainerResolutionException.cs b/Swan/DependencyInjection/DependencyContainerResolutionException.cs index da98665..63d6795 100644 --- a/Swan/DependencyInjection/DependencyContainerResolutionException.cs +++ b/Swan/DependencyInjection/DependencyContainerResolutionException.cs @@ -1,31 +1,25 @@ -namespace Swan.DependencyInjection -{ - using System; - +using System; + +namespace Swan.DependencyInjection { + /// + /// An exception for dependency resolutions. + /// + /// + [Serializable] + public class DependencyContainerResolutionException : Exception { /// - /// An exception for dependency resolutions. + /// Initializes a new instance of the class. /// - /// - [Serializable] - public class DependencyContainerResolutionException : Exception - { - /// - /// Initializes a new instance of the class. - /// - /// The type. - public DependencyContainerResolutionException(Type type) - : base($"Unable to resolve type: {type.FullName}") - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The type. - /// The inner exception. - public DependencyContainerResolutionException(Type type, Exception innerException) - : base($"Unable to resolve type: {type.FullName}", innerException) - { - } - } + /// The type. + public DependencyContainerResolutionException(Type type) : base($"Unable to resolve type: {type.FullName}") { + } + + /// + /// Initializes a new instance of the class. + /// + /// The type. + /// The inner exception. + public DependencyContainerResolutionException(Type type, Exception innerException) : base($"Unable to resolve type: {type.FullName}", innerException) { + } + } } diff --git a/Swan/DependencyInjection/DependencyContainerResolveOptions.cs b/Swan/DependencyInjection/DependencyContainerResolveOptions.cs index 28d5cb0..4bf2546 100644 --- a/Swan/DependencyInjection/DependencyContainerResolveOptions.cs +++ b/Swan/DependencyInjection/DependencyContainerResolveOptions.cs @@ -1,114 +1,106 @@ -namespace Swan.DependencyInjection -{ - using System.Collections.Generic; - +using System.Collections.Generic; + +namespace Swan.DependencyInjection { + /// + /// Resolution settings. + /// + public class DependencyContainerResolveOptions { /// - /// Resolution settings. + /// Gets the default options (attempt resolution of unregistered types, fail on named resolution if name not found). /// - public class DependencyContainerResolveOptions - { - /// - /// Gets the default options (attempt resolution of unregistered types, fail on named resolution if name not found). - /// - public static DependencyContainerResolveOptions Default { get; } = new DependencyContainerResolveOptions(); - - /// - /// Gets or sets the unregistered resolution action. - /// - /// - /// The unregistered resolution action. - /// - public DependencyContainerUnregisteredResolutionAction UnregisteredResolutionAction { get; set; } = - DependencyContainerUnregisteredResolutionAction.AttemptResolve; - - /// - /// Gets or sets the named resolution failure action. - /// - /// - /// The named resolution failure action. - /// - public DependencyContainerNamedResolutionFailureAction NamedResolutionFailureAction { get; set; } = - DependencyContainerNamedResolutionFailureAction.Fail; - - /// - /// Gets the constructor parameters. - /// - /// - /// The constructor parameters. - /// - public Dictionary ConstructorParameters { get; } = new Dictionary(); - - /// - /// Clones this instance. - /// - /// - public DependencyContainerResolveOptions Clone() => new DependencyContainerResolveOptions - { - NamedResolutionFailureAction = NamedResolutionFailureAction, - UnregisteredResolutionAction = UnregisteredResolutionAction, - }; - } - + public static DependencyContainerResolveOptions Default { get; } = new DependencyContainerResolveOptions(); + /// - /// Defines Resolution actions. + /// Gets or sets the unregistered resolution action. /// - public enum DependencyContainerUnregisteredResolutionAction - { - /// - /// Attempt to resolve type, even if the type isn't registered. - /// - /// Registered types/options will always take precedence. - /// - AttemptResolve, - - /// - /// Fail resolution if type not explicitly registered - /// - Fail, - - /// - /// Attempt to resolve unregistered type if requested type is generic - /// and no registration exists for the specific generic parameters used. - /// - /// Registered types/options will always take precedence. - /// - GenericsOnly, - } - + /// + /// The unregistered resolution action. + /// + public DependencyContainerUnregisteredResolutionAction UnregisteredResolutionAction { get; set; } = DependencyContainerUnregisteredResolutionAction.AttemptResolve; + /// - /// Enumerates failure actions. + /// Gets or sets the named resolution failure action. /// - public enum DependencyContainerNamedResolutionFailureAction - { - /// - /// The attempt unnamed resolution - /// - AttemptUnnamedResolution, - - /// - /// The fail - /// - Fail, - } - + /// + /// The named resolution failure action. + /// + public DependencyContainerNamedResolutionFailureAction NamedResolutionFailureAction { get; set; } = DependencyContainerNamedResolutionFailureAction.Fail; + /// - /// Enumerates duplicate definition actions. + /// Gets the constructor parameters. /// - public enum DependencyContainerDuplicateImplementationAction - { - /// - /// The register single - /// - RegisterSingle, - - /// - /// The register multiple - /// - RegisterMultiple, - - /// - /// The fail - /// - Fail, - } + /// + /// The constructor parameters. + /// + public Dictionary ConstructorParameters { get; } = new Dictionary(); + + /// + /// Clones this instance. + /// + /// + public DependencyContainerResolveOptions Clone() => new DependencyContainerResolveOptions { + NamedResolutionFailureAction = NamedResolutionFailureAction, + UnregisteredResolutionAction = UnregisteredResolutionAction, + }; + } + + /// + /// Defines Resolution actions. + /// + public enum DependencyContainerUnregisteredResolutionAction { + /// + /// Attempt to resolve type, even if the type isn't registered. + /// + /// Registered types/options will always take precedence. + /// + AttemptResolve, + + /// + /// Fail resolution if type not explicitly registered + /// + Fail, + + /// + /// Attempt to resolve unregistered type if requested type is generic + /// and no registration exists for the specific generic parameters used. + /// + /// Registered types/options will always take precedence. + /// + GenericsOnly, + } + + /// + /// Enumerates failure actions. + /// + public enum DependencyContainerNamedResolutionFailureAction { + /// + /// The attempt unnamed resolution + /// + AttemptUnnamedResolution, + + /// + /// The fail + /// + Fail, + } + + /// + /// Enumerates duplicate definition actions. + /// + public enum DependencyContainerDuplicateImplementationAction { + /// + /// The register single + /// + RegisterSingle, + + /// + /// The register multiple + /// + RegisterMultiple, + + /// + /// The fail + /// + Fail, + } } \ No newline at end of file diff --git a/Swan/DependencyInjection/DependencyContainerWeakReferenceException.cs b/Swan/DependencyInjection/DependencyContainerWeakReferenceException.cs index eb1e085..22953f4 100644 --- a/Swan/DependencyInjection/DependencyContainerWeakReferenceException.cs +++ b/Swan/DependencyInjection/DependencyContainerWeakReferenceException.cs @@ -1,22 +1,18 @@ -namespace Swan.DependencyInjection -{ - using System; - +using System; + +namespace Swan.DependencyInjection { + /// + /// Weak Reference Exception. + /// + /// + public class DependencyContainerWeakReferenceException : Exception { + private const String ErrorText = "Unable to instantiate {0} - referenced object has been reclaimed"; + /// - /// Weak Reference Exception. + /// Initializes a new instance of the class. /// - /// - public class DependencyContainerWeakReferenceException : Exception - { - private const string ErrorText = "Unable to instantiate {0} - referenced object has been reclaimed"; - - /// - /// Initializes a new instance of the class. - /// - /// The type. - public DependencyContainerWeakReferenceException(Type type) - : base(string.Format(ErrorText, type.FullName)) - { - } - } + /// The type. + public DependencyContainerWeakReferenceException(Type type) : base(String.Format(ErrorText, type.FullName)) { + } + } } diff --git a/Swan/DependencyInjection/ObjectFactoryBase.cs b/Swan/DependencyInjection/ObjectFactoryBase.cs index 51ff96d..d2189ee 100644 --- a/Swan/DependencyInjection/ObjectFactoryBase.cs +++ b/Swan/DependencyInjection/ObjectFactoryBase.cs @@ -1,423 +1,352 @@ -namespace Swan.DependencyInjection -{ - using System; - using System.Collections.Generic; - using System.Reflection; - +using System; +using System.Collections.Generic; +using System.Reflection; + +namespace Swan.DependencyInjection { + /// + /// Represents an abstract class for Object Factory. + /// + public abstract class ObjectFactoryBase { /// - /// Represents an abstract class for Object Factory. + /// Whether to assume this factory successfully constructs its objects + /// + /// Generally set to true for delegate style factories as CanResolve cannot delve + /// into the delegates they contain. /// - public abstract class ObjectFactoryBase - { - /// - /// Whether to assume this factory successfully constructs its objects - /// - /// Generally set to true for delegate style factories as CanResolve cannot delve - /// into the delegates they contain. - /// - public virtual bool AssumeConstruction => false; - - /// - /// The type the factory instantiates. - /// - public abstract Type CreatesType { get; } - - /// - /// Constructor to use, if specified. - /// - public ConstructorInfo Constructor { get; private set; } - - /// - /// Gets the singleton variant. - /// - /// - /// The singleton variant. - /// - /// singleton. - public virtual ObjectFactoryBase SingletonVariant => - throw new DependencyContainerRegistrationException(GetType(), "singleton"); - - /// - /// Gets the multi instance variant. - /// - /// - /// The multi instance variant. - /// - /// multi-instance. - public virtual ObjectFactoryBase MultiInstanceVariant => - throw new DependencyContainerRegistrationException(GetType(), "multi-instance"); - - /// - /// Gets the strong reference variant. - /// - /// - /// The strong reference variant. - /// - /// strong reference. - public virtual ObjectFactoryBase StrongReferenceVariant => - throw new DependencyContainerRegistrationException(GetType(), "strong reference"); - - /// - /// Gets the weak reference variant. - /// - /// - /// The weak reference variant. - /// - /// weak reference. - public virtual ObjectFactoryBase WeakReferenceVariant => - throw new DependencyContainerRegistrationException(GetType(), "weak reference"); - - /// - /// Create the type. - /// - /// Type user requested to be resolved. - /// Container that requested the creation. - /// The options. - /// Instance of type. - public abstract object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options); - - /// - /// Gets the factory for child container. - /// - /// The type. - /// The parent. - /// The child. - /// - public virtual ObjectFactoryBase GetFactoryForChildContainer( - Type type, - DependencyContainer parent, - DependencyContainer child) - { - return this; - } - } - - /// + public virtual Boolean AssumeConstruction => false; + /// - /// IObjectFactory that creates new instances of types for each resolution. + /// The type the factory instantiates. /// - internal class MultiInstanceFactory : ObjectFactoryBase - { - private readonly Type _registerType; - private readonly Type _registerImplementation; - - public MultiInstanceFactory(Type registerType, Type registerImplementation) - { - if (registerImplementation.IsAbstract || registerImplementation.IsInterface) - { - throw new DependencyContainerRegistrationException(registerImplementation, - "MultiInstanceFactory", - true); - } - - if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) - { - throw new DependencyContainerRegistrationException(registerImplementation, - "MultiInstanceFactory", - true); - } - - _registerType = registerType; - _registerImplementation = registerImplementation; - } - - public override Type CreatesType => _registerImplementation; - - public override ObjectFactoryBase SingletonVariant => - new SingletonFactory(_registerType, _registerImplementation); - - public override ObjectFactoryBase MultiInstanceVariant => this; - - public override object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options) - { - try - { - return container.RegisteredTypes.ConstructType(_registerImplementation, Constructor, options); - } - catch (DependencyContainerResolutionException ex) - { - throw new DependencyContainerResolutionException(_registerType, ex); - } - } - } - - /// + public abstract Type CreatesType { + get; + } + /// - /// IObjectFactory that invokes a specified delegate to construct the object. + /// Constructor to use, if specified. /// - internal class DelegateFactory : ObjectFactoryBase - { - private readonly Type _registerType; - - private readonly Func, object> _factory; - - public DelegateFactory( - Type registerType, - Func, object> factory) - { - _factory = factory ?? throw new ArgumentNullException(nameof(factory)); - - _registerType = registerType; - } - - public override bool AssumeConstruction => true; - - public override Type CreatesType => _registerType; - - public override ObjectFactoryBase WeakReferenceVariant => new WeakDelegateFactory(_registerType, _factory); - - public override ObjectFactoryBase StrongReferenceVariant => this; - - public override object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options) - { - try - { - return _factory.Invoke(container, options.ConstructorParameters); - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(_registerType, ex); - } - } - } - - /// + public ConstructorInfo Constructor { + get; private set; + } + /// - /// IObjectFactory that invokes a specified delegate to construct the object - /// Holds the delegate using a weak reference. + /// Gets the singleton variant. /// - internal class WeakDelegateFactory : ObjectFactoryBase - { - private readonly Type _registerType; - - private readonly WeakReference _factory; - - public WeakDelegateFactory( - Type registerType, - Func, object> factory) - { - if (factory == null) - throw new ArgumentNullException(nameof(factory)); - - _factory = new WeakReference(factory); - - _registerType = registerType; - } - - public override bool AssumeConstruction => true; - - public override Type CreatesType => _registerType; - - public override ObjectFactoryBase StrongReferenceVariant - { - get - { - if (!(_factory.Target is Func, object> factory)) - throw new DependencyContainerWeakReferenceException(_registerType); - - return new DelegateFactory(_registerType, factory); - } - } - - public override ObjectFactoryBase WeakReferenceVariant => this; - - public override object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options) - { - if (!(_factory.Target is Func, object> factory)) - throw new DependencyContainerWeakReferenceException(_registerType); - - try - { - return factory.Invoke(container, options.ConstructorParameters); - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(_registerType, ex); - } - } - } - + /// + /// The singleton variant. + /// + /// singleton. + public virtual ObjectFactoryBase SingletonVariant => throw new DependencyContainerRegistrationException(this.GetType(), "singleton"); + /// - /// Stores an particular instance to return for a type. + /// Gets the multi instance variant. /// - internal class InstanceFactory : ObjectFactoryBase, IDisposable - { - private readonly Type _registerType; - private readonly Type _registerImplementation; - private readonly object _instance; - - public InstanceFactory(Type registerType, Type registerImplementation, object instance) - { - if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) - throw new DependencyContainerRegistrationException(registerImplementation, "InstanceFactory", true); - - _registerType = registerType; - _registerImplementation = registerImplementation; - _instance = instance; - } - - public override bool AssumeConstruction => true; - - public override Type CreatesType => _registerImplementation; - - public override ObjectFactoryBase MultiInstanceVariant => - new MultiInstanceFactory(_registerType, _registerImplementation); - - public override ObjectFactoryBase WeakReferenceVariant => - new WeakInstanceFactory(_registerType, _registerImplementation, _instance); - - public override ObjectFactoryBase StrongReferenceVariant => this; - - public override object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options) - { - return _instance; - } - - public void Dispose() - { - var disposable = _instance as IDisposable; - - disposable?.Dispose(); - } - } - + /// + /// The multi instance variant. + /// + /// multi-instance. + public virtual ObjectFactoryBase MultiInstanceVariant => throw new DependencyContainerRegistrationException(this.GetType(), "multi-instance"); + /// - /// Stores the instance with a weak reference. + /// Gets the strong reference variant. /// - internal class WeakInstanceFactory : ObjectFactoryBase, IDisposable - { - private readonly Type _registerType; - private readonly Type _registerImplementation; - private readonly WeakReference _instance; - - public WeakInstanceFactory(Type registerType, Type registerImplementation, object instance) - { - if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) - { - throw new DependencyContainerRegistrationException( - registerImplementation, - "WeakInstanceFactory", - true); - } - - _registerType = registerType; - _registerImplementation = registerImplementation; - _instance = new WeakReference(instance); - } - - public override Type CreatesType => _registerImplementation; - - public override ObjectFactoryBase MultiInstanceVariant => - new MultiInstanceFactory(_registerType, _registerImplementation); - - public override ObjectFactoryBase WeakReferenceVariant => this; - - public override ObjectFactoryBase StrongReferenceVariant - { - get - { - var instance = _instance.Target; - - if (instance == null) - throw new DependencyContainerWeakReferenceException(_registerType); - - return new InstanceFactory(_registerType, _registerImplementation, instance); - } - } - - public override object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options) - { - var instance = _instance.Target; - - if (instance == null) - throw new DependencyContainerWeakReferenceException(_registerType); - - return instance; - } - - public void Dispose() => (_instance.Target as IDisposable)?.Dispose(); - } - + /// + /// The strong reference variant. + /// + /// strong reference. + public virtual ObjectFactoryBase StrongReferenceVariant => throw new DependencyContainerRegistrationException(this.GetType(), "strong reference"); + /// - /// A factory that lazy instantiates a type and always returns the same instance. + /// Gets the weak reference variant. /// - internal class SingletonFactory : ObjectFactoryBase, IDisposable - { - private readonly Type _registerType; - private readonly Type _registerImplementation; - private readonly object _singletonLock = new object(); - private object _current; - - public SingletonFactory(Type registerType, Type registerImplementation) - { - if (registerImplementation.IsAbstract || registerImplementation.IsInterface) - { - throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true); - } - - if (!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) - { - throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true); - } - - _registerType = registerType; - _registerImplementation = registerImplementation; - } - - public override Type CreatesType => _registerImplementation; - - public override ObjectFactoryBase SingletonVariant => this; - - public override ObjectFactoryBase MultiInstanceVariant => - new MultiInstanceFactory(_registerType, _registerImplementation); - - public override object GetObject( - Type requestedType, - DependencyContainer container, - DependencyContainerResolveOptions options) - { - if (options.ConstructorParameters.Count != 0) - throw new ArgumentException("Cannot specify parameters for singleton types"); - - lock (_singletonLock) - { - if (_current == null) - _current = container.RegisteredTypes.ConstructType(_registerImplementation, Constructor, options); - } - - return _current; - } - - public override ObjectFactoryBase GetFactoryForChildContainer( - Type type, - DependencyContainer parent, - DependencyContainer child) - { - // We make sure that the singleton is constructed before the child container takes the factory. - // Otherwise the results would vary depending on whether or not the parent container had resolved - // the type before the child container does. - GetObject(type, parent, DependencyContainerResolveOptions.Default); - return this; - } - - public void Dispose() => (_current as IDisposable)?.Dispose(); - } + /// + /// The weak reference variant. + /// + /// weak reference. + public virtual ObjectFactoryBase WeakReferenceVariant => throw new DependencyContainerRegistrationException(this.GetType(), "weak reference"); + + /// + /// Create the type. + /// + /// Type user requested to be resolved. + /// Container that requested the creation. + /// The options. + /// Instance of type. + public abstract Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options); + + /// + /// Gets the factory for child container. + /// + /// The type. + /// The parent. + /// The child. + /// + public virtual ObjectFactoryBase GetFactoryForChildContainer(Type type, DependencyContainer parent, DependencyContainer child) => this; + } + + /// + /// + /// IObjectFactory that creates new instances of types for each resolution. + /// + internal class MultiInstanceFactory : ObjectFactoryBase { + private readonly Type _registerType; + private readonly Type _registerImplementation; + + public MultiInstanceFactory(Type registerType, Type registerImplementation) { + if(registerImplementation.IsAbstract || registerImplementation.IsInterface) { + throw new DependencyContainerRegistrationException(registerImplementation, "MultiInstanceFactory", true); + } + + if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) { + throw new DependencyContainerRegistrationException(registerImplementation, "MultiInstanceFactory", true); + } + + this._registerType = registerType; + this._registerImplementation = registerImplementation; + } + + public override Type CreatesType => this._registerImplementation; + + public override ObjectFactoryBase SingletonVariant => + new SingletonFactory(this._registerType, this._registerImplementation); + + public override ObjectFactoryBase MultiInstanceVariant => this; + + public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) { + try { + return container.RegisteredTypes.ConstructType(this._registerImplementation, this.Constructor, options); + } catch(DependencyContainerResolutionException ex) { + throw new DependencyContainerResolutionException(this._registerType, ex); + } + } + } + + /// + /// + /// IObjectFactory that invokes a specified delegate to construct the object. + /// + internal class DelegateFactory : ObjectFactoryBase { + private readonly Type _registerType; + + private readonly Func, Object> _factory; + + public DelegateFactory( + Type registerType, + Func, Object> factory) { + this._factory = factory ?? throw new ArgumentNullException(nameof(factory)); + + this._registerType = registerType; + } + + public override Boolean AssumeConstruction => true; + + public override Type CreatesType => this._registerType; + + public override ObjectFactoryBase WeakReferenceVariant => new WeakDelegateFactory(this._registerType, this._factory); + + public override ObjectFactoryBase StrongReferenceVariant => this; + + public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) { + try { + return this._factory.Invoke(container, options.ConstructorParameters); + } catch(Exception ex) { + throw new DependencyContainerResolutionException(this._registerType, ex); + } + } + } + + /// + /// + /// IObjectFactory that invokes a specified delegate to construct the object + /// Holds the delegate using a weak reference. + /// + internal class WeakDelegateFactory : ObjectFactoryBase { + private readonly Type _registerType; + + private readonly WeakReference _factory; + + public WeakDelegateFactory(Type registerType, Func, Object> factory) { + if(factory == null) { + throw new ArgumentNullException(nameof(factory)); + } + + this._factory = new WeakReference(factory); + + this._registerType = registerType; + } + + public override Boolean AssumeConstruction => true; + + public override Type CreatesType => this._registerType; + + public override ObjectFactoryBase StrongReferenceVariant { + get { + if(!(this._factory.Target is Func, global::System.Object> factory)) { + throw new DependencyContainerWeakReferenceException(this._registerType); + } + + return new DelegateFactory(this._registerType, factory); + } + } + + public override ObjectFactoryBase WeakReferenceVariant => this; + + public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) { + if(!(this._factory.Target is Func, global::System.Object> factory)) { + throw new DependencyContainerWeakReferenceException(this._registerType); + } + + try { + return factory.Invoke(container, options.ConstructorParameters); + } catch(Exception ex) { + throw new DependencyContainerResolutionException(this._registerType, ex); + } + } + } + + /// + /// Stores an particular instance to return for a type. + /// + internal class InstanceFactory : ObjectFactoryBase, IDisposable { + private readonly Type _registerType; + private readonly Type _registerImplementation; + private readonly Object _instance; + + public InstanceFactory(Type registerType, Type registerImplementation, Object instance) { + if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) { + throw new DependencyContainerRegistrationException(registerImplementation, "InstanceFactory", true); + } + + this._registerType = registerType; + this._registerImplementation = registerImplementation; + this._instance = instance; + } + + public override Boolean AssumeConstruction => true; + + public override Type CreatesType => this._registerImplementation; + + public override ObjectFactoryBase MultiInstanceVariant => new MultiInstanceFactory(this._registerType, this._registerImplementation); + + public override ObjectFactoryBase WeakReferenceVariant => new WeakInstanceFactory(this._registerType, this._registerImplementation, this._instance); + + public override ObjectFactoryBase StrongReferenceVariant => this; + + public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) => this._instance; + + public void Dispose() { + IDisposable disposable = this._instance as IDisposable; + + disposable?.Dispose(); + } + } + + /// + /// Stores the instance with a weak reference. + /// + internal class WeakInstanceFactory : ObjectFactoryBase, IDisposable { + private readonly Type _registerType; + private readonly Type _registerImplementation; + private readonly WeakReference _instance; + + public WeakInstanceFactory(Type registerType, Type registerImplementation, Object instance) { + if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) { + throw new DependencyContainerRegistrationException(registerImplementation, "WeakInstanceFactory", true); + } + + this._registerType = registerType; + this._registerImplementation = registerImplementation; + this._instance = new WeakReference(instance); + } + + public override Type CreatesType => this._registerImplementation; + + public override ObjectFactoryBase MultiInstanceVariant => new MultiInstanceFactory(this._registerType, this._registerImplementation); + + public override ObjectFactoryBase WeakReferenceVariant => this; + + public override ObjectFactoryBase StrongReferenceVariant { + get { + Object instance = this._instance.Target; + + if(instance == null) { + throw new DependencyContainerWeakReferenceException(this._registerType); + } + + return new InstanceFactory(this._registerType, this._registerImplementation, instance); + } + } + + public override Object GetObject(Type requestedType, DependencyContainer container, DependencyContainerResolveOptions options) { + Object instance = this._instance.Target; + + if(instance == null) { + throw new DependencyContainerWeakReferenceException(this._registerType); + } + + return instance; + } + + public void Dispose() => (this._instance.Target as IDisposable)?.Dispose(); + } + + /// + /// A factory that lazy instantiates a type and always returns the same instance. + /// + internal class SingletonFactory : ObjectFactoryBase, IDisposable { + private readonly Type _registerType; + private readonly Type _registerImplementation; + private readonly Object _singletonLock = new Object(); + private Object _current; + + public SingletonFactory(Type registerType, Type registerImplementation) { + if(registerImplementation.IsAbstract || registerImplementation.IsInterface) { + throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true); + } + + if(!DependencyContainer.IsValidAssignment(registerType, registerImplementation)) { + throw new DependencyContainerRegistrationException(registerImplementation, nameof(SingletonFactory), true); + } + + this._registerType = registerType; + this._registerImplementation = registerImplementation; + } + + public override Type CreatesType => this._registerImplementation; + + public override ObjectFactoryBase SingletonVariant => this; + + public override ObjectFactoryBase MultiInstanceVariant => + new MultiInstanceFactory(this._registerType, this._registerImplementation); + + public override Object GetObject( + Type requestedType, + DependencyContainer container, + DependencyContainerResolveOptions options) { + if(options.ConstructorParameters.Count != 0) { + throw new ArgumentException("Cannot specify parameters for singleton types"); + } + + lock(this._singletonLock) { + if(this._current == null) { + this._current = container.RegisteredTypes.ConstructType(this._registerImplementation, this.Constructor, options); + } + } + + return this._current; + } + + public override ObjectFactoryBase GetFactoryForChildContainer( + Type type, + DependencyContainer parent, + DependencyContainer child) { + // We make sure that the singleton is constructed before the child container takes the factory. + // Otherwise the results would vary depending on whether or not the parent container had resolved + // the type before the child container does. + _ = this.GetObject(type, parent, DependencyContainerResolveOptions.Default); + return this; + } + + public void Dispose() => (this._current as IDisposable)?.Dispose(); + } } diff --git a/Swan/DependencyInjection/RegisterOptions.cs b/Swan/DependencyInjection/RegisterOptions.cs index f8ebfe8..4c83ed2 100644 --- a/Swan/DependencyInjection/RegisterOptions.cs +++ b/Swan/DependencyInjection/RegisterOptions.cs @@ -1,131 +1,119 @@ -namespace Swan.DependencyInjection -{ - using System; - using System.Collections.Generic; - using System.Linq; - +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Swan.DependencyInjection { + /// + /// Registration options for "fluent" API. + /// + public sealed class RegisterOptions { + private readonly TypesConcurrentDictionary _registeredTypes; + private readonly DependencyContainer.TypeRegistration _registration; + /// - /// Registration options for "fluent" API. + /// Initializes a new instance of the class. /// - public sealed class RegisterOptions - { - private readonly TypesConcurrentDictionary _registeredTypes; - private readonly DependencyContainer.TypeRegistration _registration; - - /// - /// Initializes a new instance of the class. - /// - /// The registered types. - /// The registration. - public RegisterOptions(TypesConcurrentDictionary registeredTypes, DependencyContainer.TypeRegistration registration) - { - _registeredTypes = registeredTypes; - _registration = registration; - } - - /// - /// Make registration a singleton (single instance) if possible. - /// - /// A registration options for fluent API. - /// Generic constraint registration exception. - public RegisterOptions AsSingleton() - { - var currentFactory = _registeredTypes.GetCurrentFactory(_registration); - - if (currentFactory == null) - throw new DependencyContainerRegistrationException(_registration.Type, "singleton"); - - return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.SingletonVariant); - } - - /// - /// Make registration multi-instance if possible. - /// - /// A registration options for fluent API. - /// Generic constraint registration exception. - public RegisterOptions AsMultiInstance() - { - var currentFactory = _registeredTypes.GetCurrentFactory(_registration); - - if (currentFactory == null) - throw new DependencyContainerRegistrationException(_registration.Type, "multi-instance"); - - return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.MultiInstanceVariant); - } - - /// - /// Make registration hold a weak reference if possible. - /// - /// A registration options for fluent API. - /// Generic constraint registration exception. - public RegisterOptions WithWeakReference() - { - var currentFactory = _registeredTypes.GetCurrentFactory(_registration); - - if (currentFactory == null) - throw new DependencyContainerRegistrationException(_registration.Type, "weak reference"); - - return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.WeakReferenceVariant); - } - - /// - /// Make registration hold a strong reference if possible. - /// - /// A registration options for fluent API. - /// Generic constraint registration exception. - public RegisterOptions WithStrongReference() - { - var currentFactory = _registeredTypes.GetCurrentFactory(_registration); - - if (currentFactory == null) - throw new DependencyContainerRegistrationException(_registration.Type, "strong reference"); - - return _registeredTypes.AddUpdateRegistration(_registration, currentFactory.StrongReferenceVariant); - } - } - + /// The registered types. + /// The registration. + public RegisterOptions(TypesConcurrentDictionary registeredTypes, DependencyContainer.TypeRegistration registration) { + this._registeredTypes = registeredTypes; + this._registration = registration; + } + /// - /// Registration options for "fluent" API when registering multiple implementations. + /// Make registration a singleton (single instance) if possible. /// - public sealed class MultiRegisterOptions - { - private IEnumerable _registerOptions; - - /// - /// Initializes a new instance of the class. - /// - /// The register options. - public MultiRegisterOptions(IEnumerable registerOptions) - { - _registerOptions = registerOptions; - } - - /// - /// Make registration a singleton (single instance) if possible. - /// - /// A registration multi-instance for fluent API. - /// Generic Constraint Registration Exception. - public MultiRegisterOptions AsSingleton() - { - _registerOptions = ExecuteOnAllRegisterOptions(ro => ro.AsSingleton()); - return this; - } - - /// - /// Make registration multi-instance if possible. - /// - /// A registration multi-instance for fluent API. - /// Generic Constraint Registration Exception. - public MultiRegisterOptions AsMultiInstance() - { - _registerOptions = ExecuteOnAllRegisterOptions(ro => ro.AsMultiInstance()); - return this; - } - - private IEnumerable ExecuteOnAllRegisterOptions( - Func action) - { - return _registerOptions.Select(action).ToList(); - } - } + /// A registration options for fluent API. + /// Generic constraint registration exception. + public RegisterOptions AsSingleton() { + ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration); + + if(currentFactory == null) { + throw new DependencyContainerRegistrationException(this._registration.Type, "singleton"); + } + + return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.SingletonVariant); + } + + /// + /// Make registration multi-instance if possible. + /// + /// A registration options for fluent API. + /// Generic constraint registration exception. + public RegisterOptions AsMultiInstance() { + ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration); + + if(currentFactory == null) { + throw new DependencyContainerRegistrationException(this._registration.Type, "multi-instance"); + } + + return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.MultiInstanceVariant); + } + + /// + /// Make registration hold a weak reference if possible. + /// + /// A registration options for fluent API. + /// Generic constraint registration exception. + public RegisterOptions WithWeakReference() { + ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration); + + if(currentFactory == null) { + throw new DependencyContainerRegistrationException(this._registration.Type, "weak reference"); + } + + return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.WeakReferenceVariant); + } + + /// + /// Make registration hold a strong reference if possible. + /// + /// A registration options for fluent API. + /// Generic constraint registration exception. + public RegisterOptions WithStrongReference() { + ObjectFactoryBase currentFactory = this._registeredTypes.GetCurrentFactory(this._registration); + + if(currentFactory == null) { + throw new DependencyContainerRegistrationException(this._registration.Type, "strong reference"); + } + + return this._registeredTypes.AddUpdateRegistration(this._registration, currentFactory.StrongReferenceVariant); + } + } + + /// + /// Registration options for "fluent" API when registering multiple implementations. + /// + public sealed class MultiRegisterOptions { + private IEnumerable _registerOptions; + + /// + /// Initializes a new instance of the class. + /// + /// The register options. + public MultiRegisterOptions(IEnumerable registerOptions) => this._registerOptions = registerOptions; + + /// + /// Make registration a singleton (single instance) if possible. + /// + /// A registration multi-instance for fluent API. + /// Generic Constraint Registration Exception. + public MultiRegisterOptions AsSingleton() { + this._registerOptions = this.ExecuteOnAllRegisterOptions(ro => ro.AsSingleton()); + return this; + } + + /// + /// Make registration multi-instance if possible. + /// + /// A registration multi-instance for fluent API. + /// Generic Constraint Registration Exception. + public MultiRegisterOptions AsMultiInstance() { + this._registerOptions = this.ExecuteOnAllRegisterOptions(ro => ro.AsMultiInstance()); + return this; + } + + private IEnumerable ExecuteOnAllRegisterOptions( + Func action) => this._registerOptions.Select(action).ToList(); + } } \ No newline at end of file diff --git a/Swan/DependencyInjection/TypeRegistration.cs b/Swan/DependencyInjection/TypeRegistration.cs index 0b196f8..2af00c0 100644 --- a/Swan/DependencyInjection/TypeRegistration.cs +++ b/Swan/DependencyInjection/TypeRegistration.cs @@ -1,67 +1,61 @@ -namespace Swan.DependencyInjection -{ - using System; - - public partial class DependencyContainer - { - /// - /// Represents a Type Registration within the IoC Container. - /// - public sealed class TypeRegistration - { - private readonly int _hashCode; - - /// - /// Initializes a new instance of the class. - /// - /// The type. - /// The name. - public TypeRegistration(Type type, string name = null) - { - Type = type; - Name = name ?? string.Empty; - - _hashCode = string.Concat(Type.FullName, "|", Name).GetHashCode(); - } - - /// - /// Gets the type. - /// - /// - /// The type. - /// - public Type Type { get; } - - /// - /// Gets the name. - /// - /// - /// The name. - /// - public string Name { get; } - - /// - /// Determines whether the specified , is equal to this instance. - /// - /// The to compare with this instance. - /// - /// true if the specified is equal to this instance; otherwise, false. - /// - public override bool Equals(object obj) - { - if (!(obj is TypeRegistration typeRegistration) || typeRegistration.Type != Type) - return false; - - return string.Compare(Name, typeRegistration.Name, StringComparison.Ordinal) == 0; - } - - /// - /// Returns a hash code for this instance. - /// - /// - /// A hash code for this instance, suitable for use in hashing algorithms and data structures like a hash table. - /// - public override int GetHashCode() => _hashCode; - } - } +using System; + +namespace Swan.DependencyInjection { + public partial class DependencyContainer { + /// + /// Represents a Type Registration within the IoC Container. + /// + public sealed class TypeRegistration { + private readonly Int32 _hashCode; + + /// + /// Initializes a new instance of the class. + /// + /// The type. + /// The name. + public TypeRegistration(Type type, String name = null) { + this.Type = type; + this.Name = name ?? String.Empty; + + this._hashCode = String.Concat(this.Type.FullName, "|", this.Name).GetHashCode(); + } + + /// + /// Gets the type. + /// + /// + /// The type. + /// + public Type Type { + get; + } + + /// + /// Gets the name. + /// + /// + /// The name. + /// + public String Name { + get; + } + + /// + /// Determines whether the specified , is equal to this instance. + /// + /// The to compare with this instance. + /// + /// true if the specified is equal to this instance; otherwise, false. + /// + public override Boolean Equals(Object obj) => !(obj is TypeRegistration typeRegistration) || typeRegistration.Type != this.Type ? false : String.Compare(this.Name, typeRegistration.Name, StringComparison.Ordinal) == 0; + + /// + /// Returns a hash code for this instance. + /// + /// + /// A hash code for this instance, suitable for use in hashing algorithms and data structures like a hash table. + /// + public override Int32 GetHashCode() => this._hashCode; + } + } } \ No newline at end of file diff --git a/Swan/DependencyInjection/TypesConcurrentDictionary.cs b/Swan/DependencyInjection/TypesConcurrentDictionary.cs index dd9a590..143ffb1 100644 --- a/Swan/DependencyInjection/TypesConcurrentDictionary.cs +++ b/Swan/DependencyInjection/TypesConcurrentDictionary.cs @@ -1,351 +1,265 @@ -namespace Swan.DependencyInjection -{ - using System; - using System.Linq.Expressions; - using System.Reflection; - using System.Collections.Generic; - using System.Linq; - using System.Collections.Concurrent; - +#nullable enable +using System; +using System.Linq.Expressions; +using System.Reflection; +using System.Collections.Generic; +using System.Linq; +using System.Collections.Concurrent; + +namespace Swan.DependencyInjection { + /// + /// Represents a Concurrent Dictionary for TypeRegistration. + /// + public class TypesConcurrentDictionary : ConcurrentDictionary { + private static readonly ConcurrentDictionary ObjectConstructorCache = new ConcurrentDictionary(); + + private readonly DependencyContainer _dependencyContainer; + + internal TypesConcurrentDictionary(DependencyContainer dependencyContainer) => this._dependencyContainer = dependencyContainer; + /// - /// Represents a Concurrent Dictionary for TypeRegistration. + /// Represents a delegate to build an object with the parameters. /// - public class TypesConcurrentDictionary : ConcurrentDictionary - { - private static readonly ConcurrentDictionary ObjectConstructorCache = - new ConcurrentDictionary(); - - private readonly DependencyContainer _dependencyContainer; - - internal TypesConcurrentDictionary(DependencyContainer dependencyContainer) - { - _dependencyContainer = dependencyContainer; - } - - /// - /// Represents a delegate to build an object with the parameters. - /// - /// The parameters. - /// The built object. - public delegate object ObjectConstructor(params object[] parameters); - - internal IEnumerable Resolve(Type resolveType, bool includeUnnamed) - { - var registrations = Keys.Where(tr => tr.Type == resolveType) - .Concat(GetParentRegistrationsForType(resolveType)).Distinct(); - - if (!includeUnnamed) - registrations = registrations.Where(tr => !string.IsNullOrEmpty(tr.Name)); - - return registrations.Select(registration => - ResolveInternal(registration, DependencyContainerResolveOptions.Default)); - } - - internal ObjectFactoryBase GetCurrentFactory(DependencyContainer.TypeRegistration registration) - { - TryGetValue(registration, out var current); - - return current; - } - - internal RegisterOptions Register(Type registerType, string name, ObjectFactoryBase factory) - => AddUpdateRegistration(new DependencyContainer.TypeRegistration(registerType, name), factory); - - internal RegisterOptions AddUpdateRegistration(DependencyContainer.TypeRegistration typeRegistration, ObjectFactoryBase factory) - { - this[typeRegistration] = factory; - - return new RegisterOptions(this, typeRegistration); - } - - internal bool RemoveRegistration(DependencyContainer.TypeRegistration typeRegistration) - => TryRemove(typeRegistration, out _); - - internal object ResolveInternal( - DependencyContainer.TypeRegistration registration, - DependencyContainerResolveOptions? options = null) - { - if (options == null) - options = DependencyContainerResolveOptions.Default; - - // Attempt container resolution - if (TryGetValue(registration, out var factory)) - { - try - { - return factory.GetObject(registration.Type, _dependencyContainer, options); - } - catch (DependencyContainerResolutionException) - { - throw; - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(registration.Type, ex); - } - } - - // Attempt to get a factory from parent if we can - var bubbledObjectFactory = GetParentObjectFactory(registration); - if (bubbledObjectFactory != null) - { - try - { - return bubbledObjectFactory.GetObject(registration.Type, _dependencyContainer, options); - } - catch (DependencyContainerResolutionException) - { - throw; - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(registration.Type, ex); - } - } - - // Fail if requesting named resolution and settings set to fail if unresolved - if (!string.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == - DependencyContainerNamedResolutionFailureAction.Fail) - throw new DependencyContainerResolutionException(registration.Type); - - // Attempted unnamed fallback container resolution if relevant and requested - if (!string.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == - DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) - { - if (TryGetValue(new DependencyContainer.TypeRegistration(registration.Type, string.Empty), out factory)) - { - try - { - return factory.GetObject(registration.Type, _dependencyContainer, options); - } - catch (DependencyContainerResolutionException) - { - throw; - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(registration.Type, ex); - } - } - } - - // Attempt unregistered construction if possible and requested - var isValid = (options.UnregisteredResolutionAction == - DependencyContainerUnregisteredResolutionAction.AttemptResolve) || - (registration.Type.IsGenericType && options.UnregisteredResolutionAction == - DependencyContainerUnregisteredResolutionAction.GenericsOnly); - - return isValid && !registration.Type.IsAbstract && !registration.Type.IsInterface - ? ConstructType(registration.Type, null, options) - : throw new DependencyContainerResolutionException(registration.Type); - } - - internal bool CanResolve( - DependencyContainer.TypeRegistration registration, - DependencyContainerResolveOptions? options = null) - { - if (options == null) - options = DependencyContainerResolveOptions.Default; - - var checkType = registration.Type; - var name = registration.Name; - - if (TryGetValue(new DependencyContainer.TypeRegistration(checkType, name), out var factory)) - { - if (factory.AssumeConstruction) - return true; - - if (factory.Constructor == null) - return GetBestConstructor(factory.CreatesType, options) != null; - - return CanConstruct(factory.Constructor, options); - } - - // Fail if requesting named resolution and settings set to fail if unresolved - // Or bubble up if we have a parent - if (!string.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == - DependencyContainerNamedResolutionFailureAction.Fail) - return _dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false; - - // Attempted unnamed fallback container resolution if relevant and requested - if (!string.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == - DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) - { - if (TryGetValue(new DependencyContainer.TypeRegistration(checkType), out factory)) - { - if (factory.AssumeConstruction) - return true; - - return GetBestConstructor(factory.CreatesType, options) != null; - } - } - - // Check if type is an automatic lazy factory request or an IEnumerable - if (IsAutomaticLazyFactoryRequest(checkType) || registration.Type.IsIEnumerable()) - return true; - - // Attempt unregistered construction if possible and requested - // If we cant', bubble if we have a parent - if ((options.UnregisteredResolutionAction == - DependencyContainerUnregisteredResolutionAction.AttemptResolve) || - (checkType.IsGenericType && options.UnregisteredResolutionAction == - DependencyContainerUnregisteredResolutionAction.GenericsOnly)) - { - return (GetBestConstructor(checkType, options) != null) || - (_dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false); - } - - // Bubble resolution up the container tree if we have a parent - return _dependencyContainer.Parent != null && _dependencyContainer.Parent.RegisteredTypes.CanResolve(registration, options.Clone()); - } - - internal object ConstructType( - Type implementationType, - ConstructorInfo constructor, - DependencyContainerResolveOptions? options = null) - { - var typeToConstruct = implementationType; - - if (constructor == null) - { - // Try and get the best constructor that we can construct - // if we can't construct any then get the constructor - // with the least number of parameters so we can throw a meaningful - // resolve exception - constructor = GetBestConstructor(typeToConstruct, options) ?? - GetTypeConstructors(typeToConstruct).LastOrDefault(); - } - - if (constructor == null) - throw new DependencyContainerResolutionException(typeToConstruct); - - var ctorParams = constructor.GetParameters(); - var args = new object?[ctorParams.Length]; - - for (var parameterIndex = 0; parameterIndex < ctorParams.Length; parameterIndex++) - { - var currentParam = ctorParams[parameterIndex]; - - try - { - args[parameterIndex] = options?.ConstructorParameters.GetValueOrDefault(currentParam.Name, ResolveInternal(new DependencyContainer.TypeRegistration(currentParam.ParameterType), options.Clone())); - } - catch (DependencyContainerResolutionException ex) - { - // If a constructor parameter can't be resolved - // it will throw, so wrap it and throw that this can't - // be resolved. - throw new DependencyContainerResolutionException(typeToConstruct, ex); - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(typeToConstruct, ex); - } - } - - try - { - return CreateObjectConstructionDelegateWithCache(constructor).Invoke(args); - } - catch (Exception ex) - { - throw new DependencyContainerResolutionException(typeToConstruct, ex); - } - } - - private static ObjectConstructor CreateObjectConstructionDelegateWithCache(ConstructorInfo constructor) - { - if (ObjectConstructorCache.TryGetValue(constructor, out var objectConstructor)) - return objectConstructor; - - // We could lock the cache here, but there's no real side - // effect to two threads creating the same ObjectConstructor - // at the same time, compared to the cost of a lock for - // every creation. - var constructorParams = constructor.GetParameters(); - var lambdaParams = Expression.Parameter(typeof(object[]), "parameters"); - var newParams = new Expression[constructorParams.Length]; - - for (var i = 0; i < constructorParams.Length; i++) - { - var paramsParameter = Expression.ArrayIndex(lambdaParams, Expression.Constant(i)); - - newParams[i] = Expression.Convert(paramsParameter, constructorParams[i].ParameterType); - } - - var newExpression = Expression.New(constructor, newParams); - - var constructionLambda = Expression.Lambda(typeof(ObjectConstructor), newExpression, lambdaParams); - - objectConstructor = (ObjectConstructor)constructionLambda.Compile(); - - ObjectConstructorCache[constructor] = objectConstructor; - return objectConstructor; - } - - private static IEnumerable GetTypeConstructors(Type type) - => type.GetConstructors().OrderByDescending(ctor => ctor.GetParameters().Length); - - private static bool IsAutomaticLazyFactoryRequest(Type type) - { - if (!type.IsGenericType) - return false; - - var genericType = type.GetGenericTypeDefinition(); - - // Just a func - if (genericType == typeof(Func<>)) - return true; - - // 2 parameter func with string as first parameter (name) - if (genericType == typeof(Func<,>) && type.GetGenericArguments()[0] == typeof(string)) - return true; - - // 3 parameter func with string as first parameter (name) and IDictionary as second (parameters) - return genericType == typeof(Func<,,>) && type.GetGenericArguments()[0] == typeof(string) && - type.GetGenericArguments()[1] == typeof(IDictionary); - } - - private ObjectFactoryBase? GetParentObjectFactory(DependencyContainer.TypeRegistration registration) - { - if (_dependencyContainer.Parent == null) - return null; - - return _dependencyContainer.Parent.RegisteredTypes.TryGetValue(registration, out var factory) - ? factory.GetFactoryForChildContainer(registration.Type, _dependencyContainer.Parent, _dependencyContainer) - : _dependencyContainer.Parent.RegisteredTypes.GetParentObjectFactory(registration); - } - - private ConstructorInfo? GetBestConstructor( - Type type, - DependencyContainerResolveOptions options) - => type.IsValueType ? null : GetTypeConstructors(type).FirstOrDefault(ctor => CanConstruct(ctor, options)); - - private bool CanConstruct( - MethodBase ctor, - DependencyContainerResolveOptions? options) - { - foreach (var parameter in ctor.GetParameters()) - { - if (string.IsNullOrEmpty(parameter.Name)) - return false; - - var isParameterOverload = options.ConstructorParameters.ContainsKey(parameter.Name); - - if (parameter.ParameterType.IsPrimitive && !isParameterOverload) - return false; - - if (!isParameterOverload && - !CanResolve(new DependencyContainer.TypeRegistration(parameter.ParameterType), options.Clone())) - return false; - } - - return true; - } - - private IEnumerable GetParentRegistrationsForType(Type resolveType) - => _dependencyContainer.Parent == null - ? Array.Empty() - : _dependencyContainer.Parent.RegisteredTypes.Keys.Where(tr => tr.Type == resolveType).Concat(_dependencyContainer.Parent.RegisteredTypes.GetParentRegistrationsForType(resolveType)); - } + /// The parameters. + /// The built object. + public delegate Object ObjectConstructor(params Object?[] parameters); + + internal IEnumerable Resolve(Type resolveType, Boolean includeUnnamed) { + IEnumerable registrations = this.Keys.Where(tr => tr.Type == resolveType).Concat(this.GetParentRegistrationsForType(resolveType)).Distinct(); + + if(!includeUnnamed) { + registrations = registrations.Where(tr => !String.IsNullOrEmpty(tr.Name)); + } + + return registrations.Select(registration => this.ResolveInternal(registration, DependencyContainerResolveOptions.Default)); + } + + internal ObjectFactoryBase GetCurrentFactory(DependencyContainer.TypeRegistration registration) { + _ = this.TryGetValue(registration, out ObjectFactoryBase? current); + + return current!; + } + + internal RegisterOptions Register(Type registerType, String name, ObjectFactoryBase factory) => this.AddUpdateRegistration(new DependencyContainer.TypeRegistration(registerType, name), factory); + + internal RegisterOptions AddUpdateRegistration(DependencyContainer.TypeRegistration typeRegistration, ObjectFactoryBase factory) { + this[typeRegistration] = factory; + + return new RegisterOptions(this, typeRegistration); + } + + internal Boolean RemoveRegistration(DependencyContainer.TypeRegistration typeRegistration) => this.TryRemove(typeRegistration, out _); + + internal Object ResolveInternal(DependencyContainer.TypeRegistration registration, DependencyContainerResolveOptions? options = null) { + if(options == null) { + options = DependencyContainerResolveOptions.Default; + } + + // Attempt container resolution + if(this.TryGetValue(registration, out ObjectFactoryBase? factory)) { + try { + return factory.GetObject(registration.Type, this._dependencyContainer, options); + } catch(DependencyContainerResolutionException) { + throw; + } catch(Exception ex) { + throw new DependencyContainerResolutionException(registration.Type, ex); + } + } + + // Attempt to get a factory from parent if we can + ObjectFactoryBase? bubbledObjectFactory = this.GetParentObjectFactory(registration); + if(bubbledObjectFactory != null) { + try { + return bubbledObjectFactory.GetObject(registration.Type, this._dependencyContainer, options); + } catch(DependencyContainerResolutionException) { + throw; + } catch(Exception ex) { + throw new DependencyContainerResolutionException(registration.Type, ex); + } + } + + // Fail if requesting named resolution and settings set to fail if unresolved + if(!String.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.Fail) { + throw new DependencyContainerResolutionException(registration.Type); + } + + // Attempted unnamed fallback container resolution if relevant and requested + if(!String.IsNullOrEmpty(registration.Name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) { + if(this.TryGetValue(new DependencyContainer.TypeRegistration(registration.Type, String.Empty), out factory)) { + try { + return factory.GetObject(registration.Type, this._dependencyContainer, options); + } catch(DependencyContainerResolutionException) { + throw; + } catch(Exception ex) { + throw new DependencyContainerResolutionException(registration.Type, ex); + } + } + } + + // Attempt unregistered construction if possible and requested + Boolean isValid = options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.AttemptResolve || registration.Type.IsGenericType && options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.GenericsOnly; + + return isValid && !registration.Type.IsAbstract && !registration.Type.IsInterface ? this.ConstructType(registration.Type, null, options) : throw new DependencyContainerResolutionException(registration.Type); + } + + internal Boolean CanResolve(DependencyContainer.TypeRegistration registration, DependencyContainerResolveOptions? options = null) { + if(options == null) { + options = DependencyContainerResolveOptions.Default; + } + + Type checkType = registration.Type; + String name = registration.Name; + + if(this.TryGetValue(new DependencyContainer.TypeRegistration(checkType, name), out ObjectFactoryBase? factory)) { + return factory.AssumeConstruction ? true : factory.Constructor == null ? this.GetBestConstructor(factory.CreatesType, options) != null : this.CanConstruct(factory.Constructor, options); + } + + // Fail if requesting named resolution and settings set to fail if unresolved + // Or bubble up if we have a parent + if(!String.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.Fail) { + return this._dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false; + } + + // Attempted unnamed fallback container resolution if relevant and requested + if(!String.IsNullOrEmpty(name) && options.NamedResolutionFailureAction == DependencyContainerNamedResolutionFailureAction.AttemptUnnamedResolution) { + if(this.TryGetValue(new DependencyContainer.TypeRegistration(checkType), out factory)) { + return factory.AssumeConstruction ? true : this.GetBestConstructor(factory.CreatesType, options) != null; + } + } + + // Check if type is an automatic lazy factory request or an IEnumerable + if(IsAutomaticLazyFactoryRequest(checkType) || registration.Type.IsIEnumerable()) { + return true; + } + + // Attempt unregistered construction if possible and requested + // If we cant', bubble if we have a parent + if(options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.AttemptResolve || checkType.IsGenericType && options.UnregisteredResolutionAction == DependencyContainerUnregisteredResolutionAction.GenericsOnly) { + return this.GetBestConstructor(checkType, options) != null || (this._dependencyContainer.Parent?.RegisteredTypes.CanResolve(registration, options.Clone()) ?? false); + } + + // Bubble resolution up the container tree if we have a parent + return this._dependencyContainer.Parent != null && this._dependencyContainer.Parent.RegisteredTypes.CanResolve(registration, options.Clone()); + } + + internal Object ConstructType(Type implementationType, ConstructorInfo? constructor, DependencyContainerResolveOptions? options = null) { + Type typeToConstruct = implementationType; + + if(constructor == null) { + // Try and get the best constructor that we can construct + // if we can't construct any then get the constructor + // with the least number of parameters so we can throw a meaningful + // resolve exception + constructor = this.GetBestConstructor(typeToConstruct, options) ?? GetTypeConstructors(typeToConstruct).LastOrDefault(); + } + + if(constructor == null) { + throw new DependencyContainerResolutionException(typeToConstruct); + } + + ParameterInfo[] ctorParams = constructor.GetParameters(); + Object?[] args = new Object?[ctorParams.Length]; + + for(Int32 parameterIndex = 0; parameterIndex < ctorParams.Length; parameterIndex++) { + ParameterInfo currentParam = ctorParams[parameterIndex]; + + try { + args[parameterIndex] = options?.ConstructorParameters.GetValueOrDefault(currentParam.Name, this.ResolveInternal(new DependencyContainer.TypeRegistration(currentParam.ParameterType), options.Clone())); + } catch(DependencyContainerResolutionException ex) { + // If a constructor parameter can't be resolved + // it will throw, so wrap it and throw that this can't + // be resolved. + throw new DependencyContainerResolutionException(typeToConstruct, ex); + } catch(Exception ex) { + throw new DependencyContainerResolutionException(typeToConstruct, ex); + } + } + + try { + return CreateObjectConstructionDelegateWithCache(constructor).Invoke(args); + } catch(Exception ex) { + throw new DependencyContainerResolutionException(typeToConstruct, ex); + } + } + + private static ObjectConstructor CreateObjectConstructionDelegateWithCache(ConstructorInfo constructor) { + if(ObjectConstructorCache.TryGetValue(constructor, out ObjectConstructor? objectConstructor)) { + return objectConstructor; + } + + // We could lock the cache here, but there's no real side + // effect to two threads creating the same ObjectConstructor + // at the same time, compared to the cost of a lock for + // every creation. + ParameterInfo[] constructorParams = constructor.GetParameters(); + ParameterExpression lambdaParams = Expression.Parameter(typeof(Object[]), "parameters"); + Expression[] newParams = new Expression[constructorParams.Length]; + + for(Int32 i = 0; i < constructorParams.Length; i++) { + BinaryExpression paramsParameter = Expression.ArrayIndex(lambdaParams, Expression.Constant(i)); + + newParams[i] = Expression.Convert(paramsParameter, constructorParams[i].ParameterType); + } + + NewExpression newExpression = Expression.New(constructor, newParams); + + LambdaExpression constructionLambda = Expression.Lambda(typeof(ObjectConstructor), newExpression, lambdaParams); + + objectConstructor = (ObjectConstructor)constructionLambda.Compile(); + + ObjectConstructorCache[constructor] = objectConstructor; + return objectConstructor; + } + + private static IEnumerable GetTypeConstructors(Type type) => type.GetConstructors().OrderByDescending(ctor => ctor.GetParameters().Length); + + private static Boolean IsAutomaticLazyFactoryRequest(Type type) { + if(!type.IsGenericType) { + return false; + } + + Type genericType = type.GetGenericTypeDefinition(); + + // Just a func + if(genericType == typeof(Func<>)) { + return true; + } + + // 2 parameter func with string as first parameter (name) + if(genericType == typeof(Func<,>) && type.GetGenericArguments()[0] == typeof(String)) { + return true; + } + + // 3 parameter func with string as first parameter (name) and IDictionary as second (parameters) + return genericType == typeof(Func<,,>) && type.GetGenericArguments()[0] == typeof(String) && type.GetGenericArguments()[1] == typeof(IDictionary); + } + + private ObjectFactoryBase? GetParentObjectFactory(DependencyContainer.TypeRegistration registration) => this._dependencyContainer.Parent == null + ? null + : this._dependencyContainer.Parent.RegisteredTypes.TryGetValue(registration, out ObjectFactoryBase? factory) ? factory.GetFactoryForChildContainer(registration.Type, this._dependencyContainer.Parent, this._dependencyContainer) : this._dependencyContainer.Parent.RegisteredTypes.GetParentObjectFactory(registration); + + private ConstructorInfo? GetBestConstructor(Type type, DependencyContainerResolveOptions? options) => type.IsValueType ? null : GetTypeConstructors(type).FirstOrDefault(ctor => this.CanConstruct(ctor, options)); + + private Boolean CanConstruct(MethodBase ctor, DependencyContainerResolveOptions? options) { + foreach(ParameterInfo parameter in ctor.GetParameters()) { + if(String.IsNullOrEmpty(parameter.Name)) { + return false; + } + + Boolean isParameterOverload = options!.ConstructorParameters.ContainsKey(parameter.Name); + + if(parameter.ParameterType.IsPrimitive && !isParameterOverload) { + return false; + } + + if(!isParameterOverload && !this.CanResolve(new DependencyContainer.TypeRegistration(parameter.ParameterType), options.Clone())) { + return false; + } + } + + return true; + } + + private IEnumerable GetParentRegistrationsForType(Type resolveType) => this._dependencyContainer.Parent == null ? Array.Empty() : this._dependencyContainer.Parent.RegisteredTypes.Keys.Where(tr => tr.Type == resolveType).Concat(this._dependencyContainer.Parent.RegisteredTypes.GetParentRegistrationsForType(resolveType)); + } } diff --git a/Swan/Diagnostics/RealtimeClock.cs b/Swan/Diagnostics/RealtimeClock.cs index d4ebb08..5d2648f 100644 --- a/Swan/Diagnostics/RealtimeClock.cs +++ b/Swan/Diagnostics/RealtimeClock.cs @@ -1,143 +1,128 @@ -namespace Swan.Diagnostics -{ - using System; - using System.Diagnostics; - using Threading; - +#nullable enable +using System; +using System.Diagnostics; +using Swan.Threading; + +namespace Swan.Diagnostics { + /// + /// A time measurement artifact. + /// + internal sealed class RealTimeClock : IDisposable { + private readonly Stopwatch _chrono = new Stopwatch(); + private ISyncLocker? _locker = SyncLockerFactory.Create(useSlim: true); + private Int64 _offsetTicks; + private Double _speedRatio = 1.0d; + private Boolean _isDisposed; + /// - /// A time measurement artifact. + /// Initializes a new instance of the class. + /// The clock starts paused and at the 0 position. /// - internal sealed class RealTimeClock : IDisposable - { - private readonly Stopwatch _chrono = new Stopwatch(); - private ISyncLocker? _locker = SyncLockerFactory.Create(useSlim: true); - private long _offsetTicks; - private double _speedRatio = 1.0d; - private bool _isDisposed; - - /// - /// Initializes a new instance of the class. - /// The clock starts paused and at the 0 position. - /// - public RealTimeClock() - { - Reset(); - } - - /// - /// Gets or sets the clock position. - /// - public TimeSpan Position - { - get - { - using (_locker?.AcquireReaderLock()) - { - return TimeSpan.FromTicks( - _offsetTicks + Convert.ToInt64(_chrono.Elapsed.Ticks * SpeedRatio)); - } - } - } - - /// - /// Gets a value indicating whether the clock is running. - /// - public bool IsRunning - { - get - { - using (_locker?.AcquireReaderLock()) - { - return _chrono.IsRunning; - } - } - } - - /// - /// Gets or sets the speed ratio at which the clock runs. - /// - public double SpeedRatio - { - get - { - using (_locker?.AcquireReaderLock()) - { - return _speedRatio; - } - } - set - { - using (_locker?.AcquireWriterLock()) - { - if (value < 0d) value = 0d; - - // Capture the initial position se we set it even after the Speed Ratio has changed - // this ensures a smooth position transition - var initialPosition = Position; - _speedRatio = value; - Update(initialPosition); - } - } - } - - /// - /// Sets a new position value atomically. - /// - /// The new value that the position property will hold. - public void Update(TimeSpan value) - { - using (_locker?.AcquireWriterLock()) - { - var resume = _chrono.IsRunning; - _chrono.Reset(); - _offsetTicks = value.Ticks; - if (resume) _chrono.Start(); - } - } - - /// - /// Starts or resumes the clock. - /// - public void Play() - { - using (_locker?.AcquireWriterLock()) - { - if (_chrono.IsRunning) return; - _chrono.Start(); - } - } - - /// - /// Pauses the clock. - /// - public void Pause() - { - using (_locker?.AcquireWriterLock()) - { - _chrono.Stop(); - } - } - - /// - /// Sets the clock position to 0 and stops it. - /// The speed ratio is not modified. - /// - public void Reset() - { - using (_locker?.AcquireWriterLock()) - { - _offsetTicks = 0; - _chrono.Reset(); - } - } - - /// - public void Dispose() - { - if (_isDisposed) return; - _isDisposed = true; - _locker?.Dispose(); - _locker = null; - } - } + public RealTimeClock() => this.Reset(); + + /// + /// Gets or sets the clock position. + /// + public TimeSpan Position { + get { + using(this._locker?.AcquireReaderLock()) { + return TimeSpan.FromTicks(this._offsetTicks + Convert.ToInt64(this._chrono.Elapsed.Ticks * this.SpeedRatio)); + } + } + } + + /// + /// Gets a value indicating whether the clock is running. + /// + public Boolean IsRunning { + get { + using(this._locker?.AcquireReaderLock()) { + return this._chrono.IsRunning; + } + } + } + + /// + /// Gets or sets the speed ratio at which the clock runs. + /// + public Double SpeedRatio { + get { + using(this._locker?.AcquireReaderLock()) { + return this._speedRatio; + } + } + set { + using(this._locker?.AcquireWriterLock()) { + if(value < 0d) { + value = 0d; + } + + // Capture the initial position se we set it even after the Speed Ratio has changed + // this ensures a smooth position transition + TimeSpan initialPosition = this.Position; + this._speedRatio = value; + this.Update(initialPosition); + } + } + } + + /// + /// Sets a new position value atomically. + /// + /// The new value that the position property will hold. + public void Update(TimeSpan value) { + using(this._locker?.AcquireWriterLock()) { + Boolean resume = this._chrono.IsRunning; + this._chrono.Reset(); + this._offsetTicks = value.Ticks; + if(resume) { + this._chrono.Start(); + } + } + } + + /// + /// Starts or resumes the clock. + /// + public void Play() { + using(this._locker?.AcquireWriterLock()) { + if(this._chrono.IsRunning) { + return; + } + + this._chrono.Start(); + } + } + + /// + /// Pauses the clock. + /// + public void Pause() { + using(this._locker?.AcquireWriterLock()) { + this._chrono.Stop(); + } + } + + /// + /// Sets the clock position to 0 and stops it. + /// The speed ratio is not modified. + /// + public void Reset() { + using(this._locker?.AcquireWriterLock()) { + this._offsetTicks = 0; + this._chrono.Reset(); + } + } + + /// + public void Dispose() { + if(this._isDisposed) { + return; + } + + this._isDisposed = true; + this._locker?.Dispose(); + this._locker = null; + } + } } diff --git a/Swan/Extensions.MimeMessage.cs b/Swan/Extensions.MimeMessage.cs index 2f42f3d..a4a0ca9 100644 --- a/Swan/Extensions.MimeMessage.cs +++ b/Swan/Extensions.MimeMessage.cs @@ -1,56 +1,43 @@ -namespace Swan -{ - using System; - using System.IO; - using System.Net.Mail; - using System.Reflection; - +using System; +using System.IO; +using System.Net.Mail; +using System.Reflection; + +namespace Swan { + /// + /// Extension methods. + /// + public static class SmtpExtensions { + private static readonly BindingFlags PrivateInstanceFlags = BindingFlags.Instance | BindingFlags.NonPublic; + /// - /// Extension methods. + /// The raw contents of this MailMessage as a MemoryStream. /// - public static class SmtpExtensions - { - private static readonly BindingFlags PrivateInstanceFlags = BindingFlags.Instance | BindingFlags.NonPublic; - - /// - /// The raw contents of this MailMessage as a MemoryStream. - /// - /// The caller. - /// A MemoryStream with the raw contents of this MailMessage. - public static MemoryStream ToMimeMessage(this MailMessage @this) - { - if (@this == null) - throw new ArgumentNullException(nameof(@this)); - - var result = new MemoryStream(); - var mailWriter = MimeMessageConstants.MailWriterConstructor.Invoke(new object[] { result }); - MimeMessageConstants.SendMethod.Invoke( - @this, - PrivateInstanceFlags, - null, - MimeMessageConstants.IsRunningInDotNetFourPointFive ? new[] { mailWriter, true, true } : new[] { mailWriter, true }, - null); - - result = new MemoryStream(result.ToArray()); - MimeMessageConstants.CloseMethod.Invoke( - mailWriter, - PrivateInstanceFlags, - null, - Array.Empty(), - null); - result.Position = 0; - return result; - } - - internal static class MimeMessageConstants - { + /// The caller. + /// A MemoryStream with the raw contents of this MailMessage. + public static MemoryStream ToMimeMessage(this MailMessage @this) { + if(@this == null) { + throw new ArgumentNullException(nameof(@this)); + } + + MemoryStream result = new MemoryStream(); + Object mailWriter = MimeMessageConstants.MailWriterConstructor.Invoke(new Object[] { result }); + _ = MimeMessageConstants.SendMethod.Invoke(@this, PrivateInstanceFlags, null, MimeMessageConstants.IsRunningInDotNetFourPointFive ? new[] { mailWriter, true, true } : new[] { mailWriter, true }, null); + + result = new MemoryStream(result.ToArray()); + _ = MimeMessageConstants.CloseMethod.Invoke(mailWriter, PrivateInstanceFlags, null, Array.Empty(), null); + result.Position = 0; + return result; + } + + internal static class MimeMessageConstants { #pragma warning disable DE0005 // API is deprecated - public static readonly Type MailWriter = typeof(SmtpClient).Assembly.GetType("System.Net.Mail.MailWriter"); + public static readonly Type MailWriter = typeof(SmtpClient).Assembly.GetType("System.Net.Mail.MailWriter"); #pragma warning restore DE0005 // API is deprecated - public static readonly ConstructorInfo MailWriterConstructor = MailWriter.GetConstructor(PrivateInstanceFlags, null, new[] { typeof(Stream) }, null); - public static readonly MethodInfo CloseMethod = MailWriter.GetMethod("Close", PrivateInstanceFlags); - public static readonly MethodInfo SendMethod = typeof(MailMessage).GetMethod("Send", PrivateInstanceFlags); - public static readonly bool IsRunningInDotNetFourPointFive = SendMethod.GetParameters().Length == 3; - } - } + public static readonly ConstructorInfo MailWriterConstructor = MailWriter.GetConstructor(PrivateInstanceFlags, null, new[] { typeof(Stream) }, null); + public static readonly MethodInfo CloseMethod = MailWriter.GetMethod("Close", PrivateInstanceFlags); + public static readonly MethodInfo SendMethod = typeof(MailMessage).GetMethod("Send", PrivateInstanceFlags); + public static readonly Boolean IsRunningInDotNetFourPointFive = SendMethod.GetParameters().Length == 3; + } + } } \ No newline at end of file diff --git a/Swan/Extensions.Network.cs b/Swan/Extensions.Network.cs index 32ffef6..5c4955f 100644 --- a/Swan/Extensions.Network.cs +++ b/Swan/Extensions.Network.cs @@ -1,58 +1,58 @@ -namespace Swan -{ - using System; - using System.Linq; - using System.Net; - using System.Net.Sockets; - +using System; +using System.Linq; +using System.Net; +using System.Net.Sockets; + +namespace Swan { + /// + /// Provides various extension methods for networking-related tasks. + /// + public static class NetworkExtensions { /// - /// Provides various extension methods for networking-related tasks. + /// Determines whether the IP address is private. /// - public static class NetworkExtensions - { - /// - /// Determines whether the IP address is private. - /// - /// The IP address. - /// - /// True if the IP Address is private; otherwise, false. - /// - /// address. - public static bool IsPrivateAddress(this IPAddress @this) - { - if (@this == null) - throw new ArgumentNullException(nameof(@this)); - - var octets = @this.ToString().Split(new[] { "." }, StringSplitOptions.RemoveEmptyEntries).Select(byte.Parse).ToArray(); - var is24Bit = octets[0] == 10; - var is20Bit = octets[0] == 172 && (octets[1] >= 16 && octets[1] <= 31); - var is16Bit = octets[0] == 192 && octets[1] == 168; - - return is24Bit || is20Bit || is16Bit; - } - - /// - /// Converts an IPv4 Address to its Unsigned, 32-bit integer representation. - /// - /// The address. - /// - /// A 32-bit unsigned integer converted from four bytes at a specified position in a byte array. - /// - /// address. - /// InterNetwork - address. - public static uint ToUInt32(this IPAddress @this) - { - if (@this == null) - throw new ArgumentNullException(nameof(@this)); - - if (@this.AddressFamily != AddressFamily.InterNetwork) - throw new ArgumentException($"Address has to be of family '{nameof(AddressFamily.InterNetwork)}'", nameof(@this)); - - var addressBytes = @this.GetAddressBytes(); - if (BitConverter.IsLittleEndian) - Array.Reverse(addressBytes); - - return BitConverter.ToUInt32(addressBytes, 0); - } - } + /// The IP address. + /// + /// True if the IP Address is private; otherwise, false. + /// + /// address. + public static Boolean IsPrivateAddress(this IPAddress @this) { + if(@this == null) { + throw new ArgumentNullException(nameof(@this)); + } + + Byte[] octets = @this.ToString().Split(new[] { "." }, StringSplitOptions.RemoveEmptyEntries).Select(Byte.Parse).ToArray(); + Boolean is24Bit = octets[0] == 10; + Boolean is20Bit = octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31; + Boolean is16Bit = octets[0] == 192 && octets[1] == 168; + + return is24Bit || is20Bit || is16Bit; + } + + /// + /// Converts an IPv4 Address to its Unsigned, 32-bit integer representation. + /// + /// The address. + /// + /// A 32-bit unsigned integer converted from four bytes at a specified position in a byte array. + /// + /// address. + /// InterNetwork - address. + public static UInt32 ToUInt32(this IPAddress @this) { + if(@this == null) { + throw new ArgumentNullException(nameof(@this)); + } + + if(@this.AddressFamily != AddressFamily.InterNetwork) { + throw new ArgumentException($"Address has to be of family '{nameof(AddressFamily.InterNetwork)}'", nameof(@this)); + } + + Byte[] addressBytes = @this.GetAddressBytes(); + if(BitConverter.IsLittleEndian) { + Array.Reverse(addressBytes); + } + + return BitConverter.ToUInt32(addressBytes, 0); + } + } } diff --git a/Swan/Extensions.WindowsServices.cs b/Swan/Extensions.WindowsServices.cs index 420f15d..b7e45a2 100644 --- a/Swan/Extensions.WindowsServices.cs +++ b/Swan/Extensions.WindowsServices.cs @@ -1,89 +1,81 @@ -namespace Swan -{ - using Logging; - using System; - using System.Collections.Generic; - using System.Reflection; - using System.Threading; -#if NET461 - using System.ServiceProcess; -#else - using Services; -#endif - +using Swan.Logging; +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Threading; + +using Swan.Services; + +namespace Swan { + /// + /// Extension methods. + /// + public static class WindowsServicesExtensions { /// - /// Extension methods. + /// Runs a service in console mode. /// - public static class WindowsServicesExtensions - { - /// - /// Runs a service in console mode. - /// - /// The service to run. - /// The logger source. - /// this. - [Obsolete("This extension method will be removed in version 3.0")] - public static void RunInConsoleMode(this ServiceBase @this, string loggerSource = null) - { - if (@this == null) - throw new ArgumentNullException(nameof(@this)); - - RunInConsoleMode(new[] { @this }, loggerSource); - } - - /// - /// Runs a set of services in console mode. - /// - /// The services to run. - /// The logger source. - /// this. - /// The ServiceBase class isn't available. - [Obsolete("This extension method will be removed in version 3.0")] - public static void RunInConsoleMode(this ServiceBase[] @this, string loggerSource = null) - { - if (@this == null) - throw new ArgumentNullException(nameof(@this)); - - const string onStartMethodName = "OnStart"; - const string onStopMethodName = "OnStop"; - - var onStartMethod = typeof(ServiceBase).GetMethod(onStartMethodName, - BindingFlags.Instance | BindingFlags.NonPublic); - var onStopMethod = typeof(ServiceBase).GetMethod(onStopMethodName, - BindingFlags.Instance | BindingFlags.NonPublic); - - if (onStartMethod == null || onStopMethod == null) - throw new InvalidOperationException("The ServiceBase class isn't available."); - - var serviceThreads = new List(); - "Starting services . . .".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); - - foreach (var service in @this) - { - var thread = new Thread(() => - { - onStartMethod.Invoke(service, new object[] { Array.Empty() }); - $"Started service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name); - }); - - serviceThreads.Add(thread); - thread.Start(); - } - - "Press any key to stop all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); - Terminal.ReadKey(true, true); - "Stopping services . . .".Info(SwanRuntime.EntryAssemblyName.Name); - - foreach (var service in @this) - { - onStopMethod.Invoke(service, null); - $"Stopped service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name); - } - - foreach (var thread in serviceThreads) - thread.Join(); - - "Stopped all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); - } - } + /// The service to run. + /// The logger source. + /// this. + [Obsolete("This extension method will be removed in version 3.0")] + public static void RunInConsoleMode(this ServiceBase @this, String loggerSource = null) { + if(@this == null) { + throw new ArgumentNullException(nameof(@this)); + } + + RunInConsoleMode(new[] { @this }, loggerSource); + } + + /// + /// Runs a set of services in console mode. + /// + /// The services to run. + /// The logger source. + /// this. + /// The ServiceBase class isn't available. + [Obsolete("This extension method will be removed in version 3.0")] + public static void RunInConsoleMode(this ServiceBase[] @this, String loggerSource = null) { + if(@this == null) { + throw new ArgumentNullException(nameof(@this)); + } + + const String onStartMethodName = "OnStart"; + const String onStopMethodName = "OnStop"; + + MethodInfo onStartMethod = typeof(ServiceBase).GetMethod(onStartMethodName, BindingFlags.Instance | BindingFlags.NonPublic); + MethodInfo onStopMethod = typeof(ServiceBase).GetMethod(onStopMethodName, BindingFlags.Instance | BindingFlags.NonPublic); + + if(onStartMethod == null || onStopMethod == null) { + throw new InvalidOperationException("The ServiceBase class isn't available."); + } + + List serviceThreads = new List(); + "Starting services . . .".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); + + foreach(ServiceBase service in @this) { + Thread thread = new Thread(() => { + _ = onStartMethod.Invoke(service, new Object[] { Array.Empty() }); + $"Started service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name); + }); + + serviceThreads.Add(thread); + thread.Start(); + } + + "Press any key to stop all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); + _ = Terminal.ReadKey(true, true); + "Stopping services . . .".Info(SwanRuntime.EntryAssemblyName.Name); + + foreach(ServiceBase service in @this) { + _ = onStopMethod.Invoke(service, null); + $"Stopped service '{service.GetType().Name}'".Info(loggerSource ?? service.GetType().Name); + } + + foreach(Thread thread in serviceThreads) { + thread.Join(); + } + + "Stopped all services.".Info(loggerSource ?? SwanRuntime.EntryAssemblyName.Name); + } + } } diff --git a/Swan/Messaging/IMessageHubMessage.cs b/Swan/Messaging/IMessageHubMessage.cs index 6cc2869..c9a7908 100644 --- a/Swan/Messaging/IMessageHubMessage.cs +++ b/Swan/Messaging/IMessageHubMessage.cs @@ -1,13 +1,15 @@ -namespace Swan.Messaging -{ +using System; + +namespace Swan.Messaging { + /// + /// A Message to be published/delivered by Messenger. + /// + public interface IMessageHubMessage { /// - /// A Message to be published/delivered by Messenger. + /// The sender of the message, or null if not supported by the message implementation. /// - public interface IMessageHubMessage - { - /// - /// The sender of the message, or null if not supported by the message implementation. - /// - object Sender { get; } - } + Object Sender { + get; + } + } } diff --git a/Swan/Messaging/IMessageHubSubscription.cs b/Swan/Messaging/IMessageHubSubscription.cs index 0248b85..33bb877 100644 --- a/Swan/Messaging/IMessageHubSubscription.cs +++ b/Swan/Messaging/IMessageHubSubscription.cs @@ -1,26 +1,28 @@ -namespace Swan.Messaging -{ +using System; + +namespace Swan.Messaging { + /// + /// Represents a message subscription. + /// + public interface IMessageHubSubscription { /// - /// Represents a message subscription. + /// Token returned to the subscribed to reference this subscription. /// - public interface IMessageHubSubscription - { - /// - /// Token returned to the subscribed to reference this subscription. - /// - MessageHubSubscriptionToken SubscriptionToken { get; } - - /// - /// Whether delivery should be attempted. - /// - /// Message that may potentially be delivered. - /// true - ok to send, false - should not attempt to send. - bool ShouldAttemptDelivery(IMessageHubMessage message); - - /// - /// Deliver the message. - /// - /// Message to deliver. - void Deliver(IMessageHubMessage message); - } + MessageHubSubscriptionToken SubscriptionToken { + get; + } + + /// + /// Whether delivery should be attempted. + /// + /// Message that may potentially be delivered. + /// true - ok to send, false - should not attempt to send. + Boolean ShouldAttemptDelivery(IMessageHubMessage message); + + /// + /// Deliver the message. + /// + /// Message to deliver. + void Deliver(IMessageHubMessage message); + } } \ No newline at end of file diff --git a/Swan/Messaging/MessageHub.cs b/Swan/Messaging/MessageHub.cs index 7a1de9b..d914a0d 100644 --- a/Swan/Messaging/MessageHub.cs +++ b/Swan/Messaging/MessageHub.cs @@ -1,442 +1,371 @@ -// =============================================================================== -// TinyIoC - TinyMessenger -// -// A simple messenger/event aggregator. -// -// https://github.com/grumpydev/TinyIoC/blob/master/src/TinyIoC/TinyMessenger.cs -// =============================================================================== -// Copyright © Steven Robbins. All rights reserved. -// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY -// OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT -// LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -// FITNESS FOR A PARTICULAR PURPOSE. -// =============================================================================== - -namespace Swan.Messaging -{ - using System.Threading.Tasks; - using System; - using System.Collections.Generic; - using System.Linq; - - #region Message Types / Interfaces - +// =============================================================================== +// TinyIoC - TinyMessenger +// +// A simple messenger/event aggregator. +// +// https://github.com/grumpydev/TinyIoC/blob/master/src/TinyIoC/TinyMessenger.cs +// =============================================================================== +// Copyright © Steven Robbins. All rights reserved. +// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY +// OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT +// LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +// FITNESS FOR A PARTICULAR PURPOSE. +// =============================================================================== +#nullable enable +using System.Threading.Tasks; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Swan.Messaging { + #region Message Types / Interfaces + + /// + /// Message proxy definition. + /// + /// A message proxy can be used to intercept/alter messages and/or + /// marshal delivery actions onto a particular thread. + /// + public interface IMessageHubProxy { /// - /// Message proxy definition. + /// Delivers the specified message. + /// + /// The message. + /// The subscription. + void Deliver(IMessageHubMessage message, IMessageHubSubscription subscription); + } + + /// + /// Default "pass through" proxy. + /// + /// Does nothing other than deliver the message. + /// + public sealed class MessageHubDefaultProxy : IMessageHubProxy { + private MessageHubDefaultProxy() { + // placeholder + } + + /// + /// Singleton instance of the proxy. + /// + public static MessageHubDefaultProxy Instance { get; } = new MessageHubDefaultProxy(); + + /// + /// Delivers the specified message. + /// + /// The message. + /// The subscription. + public void Deliver(IMessageHubMessage message, IMessageHubSubscription subscription) => subscription.Deliver(message); + } + + #endregion + + #region Hub Interface + + /// + /// Messenger hub responsible for taking subscriptions/publications and delivering of messages. + /// + public interface IMessageHub { + /// + /// Subscribe to a message type with the given destination and delivery action. + /// Messages will be delivered via the specified proxy. /// - /// A message proxy can be used to intercept/alter messages and/or - /// marshal delivery actions onto a particular thread. + /// All messages of this type will be delivered. /// - public interface IMessageHubProxy - { - /// - /// Delivers the specified message. - /// - /// The message. - /// The subscription. - void Deliver(IMessageHubMessage message, IMessageHubSubscription subscription); - } - + /// Type of message. + /// Action to invoke when message is delivered. + /// Use strong references to destination and deliveryAction. + /// Proxy to use when delivering the messages. + /// MessageSubscription used to unsubscribing. + MessageHubSubscriptionToken Subscribe(Action deliveryAction, Boolean useStrongReferences, IMessageHubProxy proxy) where TMessage : class, IMessageHubMessage; + /// - /// Default "pass through" proxy. + /// Subscribe to a message type with the given destination and delivery action with the given filter. + /// Messages will be delivered via the specified proxy. + /// All references are held with WeakReferences + /// Only messages that "pass" the filter will be delivered. + /// + /// Type of message. + /// Action to invoke when message is delivered. + /// The message filter. + /// Use strong references to destination and deliveryAction. + /// Proxy to use when delivering the messages. + /// + /// MessageSubscription used to unsubscribing. + /// + MessageHubSubscriptionToken Subscribe(Action deliveryAction, Func messageFilter, Boolean useStrongReferences, IMessageHubProxy proxy) where TMessage : class, IMessageHubMessage; + + /// + /// Unsubscribe from a particular message type. /// - /// Does nothing other than deliver the message. + /// Does not throw an exception if the subscription is not found. /// - public sealed class MessageHubDefaultProxy : IMessageHubProxy - { - private MessageHubDefaultProxy() - { - // placeholder - } - - /// - /// Singleton instance of the proxy. - /// - public static MessageHubDefaultProxy Instance { get; } = new MessageHubDefaultProxy(); - - /// - /// Delivers the specified message. - /// - /// The message. - /// The subscription. - public void Deliver(IMessageHubMessage message, IMessageHubSubscription subscription) - => subscription.Deliver(message); - } - - #endregion - - #region Hub Interface - + /// Type of message. + /// Subscription token received from Subscribe. + void Unsubscribe(MessageHubSubscriptionToken subscriptionToken) where TMessage : class, IMessageHubMessage; + /// - /// Messenger hub responsible for taking subscriptions/publications and delivering of messages. + /// Publish a message to any subscribers. /// - public interface IMessageHub - { - /// - /// Subscribe to a message type with the given destination and delivery action. - /// Messages will be delivered via the specified proxy. - /// - /// All messages of this type will be delivered. - /// - /// Type of message. - /// Action to invoke when message is delivered. - /// Use strong references to destination and deliveryAction. - /// Proxy to use when delivering the messages. - /// MessageSubscription used to unsubscribing. - MessageHubSubscriptionToken Subscribe( - Action deliveryAction, - bool useStrongReferences, - IMessageHubProxy proxy) - where TMessage : class, IMessageHubMessage; - - /// - /// Subscribe to a message type with the given destination and delivery action with the given filter. - /// Messages will be delivered via the specified proxy. - /// All references are held with WeakReferences - /// Only messages that "pass" the filter will be delivered. - /// - /// Type of message. - /// Action to invoke when message is delivered. - /// The message filter. - /// Use strong references to destination and deliveryAction. - /// Proxy to use when delivering the messages. - /// - /// MessageSubscription used to unsubscribing. - /// - MessageHubSubscriptionToken Subscribe( - Action deliveryAction, - Func messageFilter, - bool useStrongReferences, - IMessageHubProxy proxy) - where TMessage : class, IMessageHubMessage; - - /// - /// Unsubscribe from a particular message type. - /// - /// Does not throw an exception if the subscription is not found. - /// - /// Type of message. - /// Subscription token received from Subscribe. - void Unsubscribe(MessageHubSubscriptionToken subscriptionToken) - where TMessage : class, IMessageHubMessage; - - /// - /// Publish a message to any subscribers. - /// - /// Type of message. - /// Message to deliver. - void Publish(TMessage message) - where TMessage : class, IMessageHubMessage; - - /// - /// Publish a message to any subscribers asynchronously. - /// - /// Type of message. - /// Message to deliver. - /// A task from Publish action. - Task PublishAsync(TMessage message) - where TMessage : class, IMessageHubMessage; - } - + /// Type of message. + /// Message to deliver. + void Publish(TMessage message) where TMessage : class, IMessageHubMessage; + + /// + /// Publish a message to any subscribers asynchronously. + /// + /// Type of message. + /// Message to deliver. + /// A task from Publish action. + Task PublishAsync(TMessage message) where TMessage : class, IMessageHubMessage; + } + + #endregion + + #region Hub Implementation + + /// + /// + /// The following code describes how to use a MessageHub. Both the + /// subscription and the message sending are done in the same place but this is only for explanatory purposes. + /// + /// class Example + /// { + /// using Swan; + /// using Swan.Components; + /// + /// static void Main() + /// { + /// // using DependencyContainer to create an instance of MessageHub + /// var messageHub = DependencyContainer + /// .Current + /// .Resolve<IMessageHub>() as MessageHub; + /// + /// // create an instance of the publisher class + /// // which has a string as its content + /// var message = new MessageHubGenericMessage<string>(new object(), "SWAN"); + /// + /// // subscribe to the publisher's event + /// // and just print out the content which is a string + /// // a token is returned which can be used to unsubscribe later on + /// var token = messageHub + /// .Subscribe<MessageHubGenericMessage<string>>(m => m.Content.Info()); + /// + /// // publish the message described above which is + /// // the string 'SWAN' + /// messageHub.Publish(message); + /// + /// // unsuscribe, we will no longer receive any messages + /// messageHub.Unsubscribe<MessageHubGenericMessage<string>>(token); + /// + /// Terminal.Flush(); + /// } + /// + /// } + /// + /// + public sealed class MessageHub : IMessageHub { + #region Private Types and Interfaces + + private readonly Object _subscriptionsPadlock = new Object(); + + private readonly Dictionary> _subscriptions = new Dictionary>(); + + private class WeakMessageSubscription : IMessageHubSubscription where TMessage : class, IMessageHubMessage { + private readonly WeakReference _deliveryAction; + private readonly WeakReference _messageFilter; + + /// + /// Initializes a new instance of the class. + /// + /// The subscription token. + /// The delivery action. + /// The message filter. + /// subscriptionToken + /// or + /// deliveryAction + /// or + /// messageFilter. + public WeakMessageSubscription(MessageHubSubscriptionToken subscriptionToken, Action deliveryAction, Func messageFilter) { + this.SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken)); + this._deliveryAction = new WeakReference(deliveryAction); + this._messageFilter = new WeakReference(messageFilter); + } + + public MessageHubSubscriptionToken SubscriptionToken { + get; + } + + public Boolean ShouldAttemptDelivery(IMessageHubMessage message) => this._deliveryAction.IsAlive && this._messageFilter.IsAlive && ((Func)this._messageFilter.Target!).Invoke((TMessage)message); + + public void Deliver(IMessageHubMessage message) { + if(this._deliveryAction.IsAlive) { + ((Action)this._deliveryAction.Target!).Invoke((TMessage)message); + } + } + } + + private class StrongMessageSubscription : IMessageHubSubscription where TMessage : class, IMessageHubMessage { + private readonly Action _deliveryAction; + private readonly Func _messageFilter; + + /// + /// Initializes a new instance of the class. + /// + /// The subscription token. + /// The delivery action. + /// The message filter. + /// subscriptionToken + /// or + /// deliveryAction + /// or + /// messageFilter. + public StrongMessageSubscription(MessageHubSubscriptionToken subscriptionToken, Action deliveryAction, Func messageFilter) { + this.SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken)); + this._deliveryAction = deliveryAction; + this._messageFilter = messageFilter; + } + + public MessageHubSubscriptionToken SubscriptionToken { + get; + } + + public Boolean ShouldAttemptDelivery(IMessageHubMessage message) => this._messageFilter.Invoke((TMessage)message); + + public void Deliver(IMessageHubMessage message) => this._deliveryAction.Invoke((TMessage)message); + } + #endregion - - #region Hub Implementation - + + #region Subscription dictionary + + private class SubscriptionItem { + public SubscriptionItem(IMessageHubProxy proxy, IMessageHubSubscription subscription) { + this.Proxy = proxy; + this.Subscription = subscription; + } + + public IMessageHubProxy Proxy { + get; + } + public IMessageHubSubscription Subscription { + get; + } + } + + #endregion + + #region Public API + + /// + /// Subscribe to a message type with the given destination and delivery action. + /// Messages will be delivered via the specified proxy. + /// + /// All messages of this type will be delivered. + /// + /// Type of message. + /// Action to invoke when message is delivered. + /// Use strong references to destination and deliveryAction. + /// Proxy to use when delivering the messages. + /// MessageSubscription used to unsubscribing. + public MessageHubSubscriptionToken Subscribe(Action deliveryAction, Boolean useStrongReferences = true, IMessageHubProxy? proxy = null) where TMessage : class, IMessageHubMessage => this.Subscribe(deliveryAction, m => true, useStrongReferences, proxy); + + + /// + /// Subscribe to a message type with the given destination and delivery action with the given filter. + /// Messages will be delivered via the specified proxy. + /// All references are held with WeakReferences + /// Only messages that "pass" the filter will be delivered. + /// + /// Type of message. + /// Action to invoke when message is delivered. + /// The message filter. + /// Use strong references to destination and deliveryAction. + /// Proxy to use when delivering the messages. + /// + /// MessageSubscription used to unsubscribing. + /// + [System.Diagnostics.CodeAnalysis.SuppressMessage("Codequalität", "IDE0068:Empfohlenes Dispose-Muster verwenden", Justification = "")] + public MessageHubSubscriptionToken Subscribe(Action deliveryAction, Func messageFilter, Boolean useStrongReferences = true, IMessageHubProxy? proxy = null) where TMessage : class, IMessageHubMessage { + if(deliveryAction == null) { + throw new ArgumentNullException(nameof(deliveryAction)); + } + + if(messageFilter == null) { + throw new ArgumentNullException(nameof(messageFilter)); + } + + lock(this._subscriptionsPadlock) { + if(!this._subscriptions.TryGetValue(typeof(TMessage), out List? currentSubscriptions)) { + currentSubscriptions = new List(); + this._subscriptions[typeof(TMessage)] = currentSubscriptions; + } + + MessageHubSubscriptionToken subscriptionToken = new MessageHubSubscriptionToken(this, typeof(TMessage)); + + IMessageHubSubscription subscription = useStrongReferences ? new StrongMessageSubscription(subscriptionToken, deliveryAction, messageFilter) : (IMessageHubSubscription)new WeakMessageSubscription(subscriptionToken, deliveryAction, messageFilter); + + currentSubscriptions.Add(new SubscriptionItem(proxy ?? MessageHubDefaultProxy.Instance, subscription)); + + return subscriptionToken; + } + } + /// - /// - /// The following code describes how to use a MessageHub. Both the - /// subscription and the message sending are done in the same place but this is only for explanatory purposes. - /// - /// class Example - /// { - /// using Swan; - /// using Swan.Components; - /// - /// static void Main() - /// { - /// // using DependencyContainer to create an instance of MessageHub - /// var messageHub = DependencyContainer - /// .Current - /// .Resolve<IMessageHub>() as MessageHub; - /// - /// // create an instance of the publisher class - /// // which has a string as its content - /// var message = new MessageHubGenericMessage<string>(new object(), "SWAN"); - /// - /// // subscribe to the publisher's event - /// // and just print out the content which is a string - /// // a token is returned which can be used to unsubscribe later on - /// var token = messageHub - /// .Subscribe<MessageHubGenericMessage<string>>(m => m.Content.Info()); - /// - /// // publish the message described above which is - /// // the string 'SWAN' - /// messageHub.Publish(message); - /// - /// // unsuscribe, we will no longer receive any messages - /// messageHub.Unsubscribe<MessageHubGenericMessage<string>>(token); - /// - /// Terminal.Flush(); - /// } - /// - /// } - /// - /// - public sealed class MessageHub : IMessageHub - { - #region Private Types and Interfaces - - private readonly object _subscriptionsPadlock = new object(); - - private readonly Dictionary> _subscriptions = - new Dictionary>(); - - private class WeakMessageSubscription : IMessageHubSubscription - where TMessage : class, IMessageHubMessage - { - private readonly WeakReference _deliveryAction; - private readonly WeakReference _messageFilter; - - /// - /// Initializes a new instance of the class. - /// - /// The subscription token. - /// The delivery action. - /// The message filter. - /// subscriptionToken - /// or - /// deliveryAction - /// or - /// messageFilter. - public WeakMessageSubscription( - MessageHubSubscriptionToken subscriptionToken, - Action deliveryAction, - Func messageFilter) - { - SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken)); - _deliveryAction = new WeakReference(deliveryAction); - _messageFilter = new WeakReference(messageFilter); - } - - public MessageHubSubscriptionToken SubscriptionToken { get; } - - public bool ShouldAttemptDelivery(IMessageHubMessage message) - { - return _deliveryAction.IsAlive && _messageFilter.IsAlive && - ((Func) _messageFilter.Target).Invoke((TMessage) message); - } - - public void Deliver(IMessageHubMessage message) - { - if (_deliveryAction.IsAlive) - { - ((Action) _deliveryAction.Target).Invoke((TMessage) message); - } - } - } - - private class StrongMessageSubscription : IMessageHubSubscription - where TMessage : class, IMessageHubMessage - { - private readonly Action _deliveryAction; - private readonly Func _messageFilter; - - /// - /// Initializes a new instance of the class. - /// - /// The subscription token. - /// The delivery action. - /// The message filter. - /// subscriptionToken - /// or - /// deliveryAction - /// or - /// messageFilter. - public StrongMessageSubscription( - MessageHubSubscriptionToken subscriptionToken, - Action deliveryAction, - Func messageFilter) - { - SubscriptionToken = subscriptionToken ?? throw new ArgumentNullException(nameof(subscriptionToken)); - _deliveryAction = deliveryAction; - _messageFilter = messageFilter; - } - - public MessageHubSubscriptionToken SubscriptionToken { get; } - - public bool ShouldAttemptDelivery(IMessageHubMessage message) => _messageFilter.Invoke((TMessage) message); - - public void Deliver(IMessageHubMessage message) => _deliveryAction.Invoke((TMessage) message); - } - - #endregion - - #region Subscription dictionary - - private class SubscriptionItem - { - public SubscriptionItem(IMessageHubProxy proxy, IMessageHubSubscription subscription) - { - Proxy = proxy; - Subscription = subscription; - } - - public IMessageHubProxy Proxy { get; } - public IMessageHubSubscription Subscription { get; } - } - - #endregion - - #region Public API - - /// - /// Subscribe to a message type with the given destination and delivery action. - /// Messages will be delivered via the specified proxy. - /// - /// All messages of this type will be delivered. - /// - /// Type of message. - /// Action to invoke when message is delivered. - /// Use strong references to destination and deliveryAction. - /// Proxy to use when delivering the messages. - /// MessageSubscription used to unsubscribing. - public MessageHubSubscriptionToken Subscribe( - Action deliveryAction, - bool useStrongReferences = true, - IMessageHubProxy? proxy = null) - where TMessage : class, IMessageHubMessage - { - return Subscribe(deliveryAction, m => true, useStrongReferences, proxy); - } - - /// - /// Subscribe to a message type with the given destination and delivery action with the given filter. - /// Messages will be delivered via the specified proxy. - /// All references are held with WeakReferences - /// Only messages that "pass" the filter will be delivered. - /// - /// Type of message. - /// Action to invoke when message is delivered. - /// The message filter. - /// Use strong references to destination and deliveryAction. - /// Proxy to use when delivering the messages. - /// - /// MessageSubscription used to unsubscribing. - /// - public MessageHubSubscriptionToken Subscribe( - Action deliveryAction, - Func messageFilter, - bool useStrongReferences = true, - IMessageHubProxy? proxy = null) - where TMessage : class, IMessageHubMessage - { - if (deliveryAction == null) - throw new ArgumentNullException(nameof(deliveryAction)); - - if (messageFilter == null) - throw new ArgumentNullException(nameof(messageFilter)); - - lock (_subscriptionsPadlock) - { - if (!_subscriptions.TryGetValue(typeof(TMessage), out var currentSubscriptions)) - { - currentSubscriptions = new List(); - _subscriptions[typeof(TMessage)] = currentSubscriptions; - } - - var subscriptionToken = new MessageHubSubscriptionToken(this, typeof(TMessage)); - - IMessageHubSubscription subscription; - if (useStrongReferences) - { - subscription = new StrongMessageSubscription( - subscriptionToken, - deliveryAction, - messageFilter); - } - else - { - subscription = new WeakMessageSubscription( - subscriptionToken, - deliveryAction, - messageFilter); - } - - currentSubscriptions.Add(new SubscriptionItem(proxy ?? MessageHubDefaultProxy.Instance, subscription)); - - return subscriptionToken; - } - } - - /// - public void Unsubscribe(MessageHubSubscriptionToken subscriptionToken) - where TMessage : class, IMessageHubMessage - { - if (subscriptionToken == null) - throw new ArgumentNullException(nameof(subscriptionToken)); - - lock (_subscriptionsPadlock) - { - if (!_subscriptions.TryGetValue(typeof(TMessage), out var currentSubscriptions)) - return; - - var currentlySubscribed = currentSubscriptions - .Where(sub => ReferenceEquals(sub.Subscription.SubscriptionToken, subscriptionToken)) - .ToList(); - - currentlySubscribed.ForEach(sub => currentSubscriptions.Remove(sub)); - } - } - - /// - /// Publish a message to any subscribers. - /// - /// Type of message. - /// Message to deliver. - public void Publish(TMessage message) - where TMessage : class, IMessageHubMessage - { - if (message == null) - throw new ArgumentNullException(nameof(message)); - - List currentlySubscribed; - lock (_subscriptionsPadlock) - { - if (!_subscriptions.TryGetValue(typeof(TMessage), out var currentSubscriptions)) - return; - - currentlySubscribed = currentSubscriptions - .Where(sub => sub.Subscription.ShouldAttemptDelivery(message)) - .ToList(); - } - - currentlySubscribed.ForEach(sub => - { - try - { - sub.Proxy.Deliver(message, sub.Subscription); - } - catch - { - // Ignore any errors and carry on - } - }); - } - - /// - /// Publish a message to any subscribers asynchronously. - /// - /// Type of message. - /// Message to deliver. - /// A task with the publish. - public Task PublishAsync(TMessage message) - where TMessage : class, IMessageHubMessage - { - return Task.Run(() => Publish(message)); - } - - #endregion - } - + public void Unsubscribe(MessageHubSubscriptionToken subscriptionToken) where TMessage : class, IMessageHubMessage { + if(subscriptionToken == null) { + throw new ArgumentNullException(nameof(subscriptionToken)); + } + + lock(this._subscriptionsPadlock) { + if(!this._subscriptions.TryGetValue(typeof(TMessage), out List? currentSubscriptions)) { + return; + } + + List currentlySubscribed = currentSubscriptions.Where(sub => ReferenceEquals(sub.Subscription.SubscriptionToken, subscriptionToken)).ToList(); + + currentlySubscribed.ForEach(sub => currentSubscriptions.Remove(sub)); + } + } + + /// + /// Publish a message to any subscribers. + /// + /// Type of message. + /// Message to deliver. + public void Publish(TMessage message) where TMessage : class, IMessageHubMessage { + if(message == null) { + throw new ArgumentNullException(nameof(message)); + } + + List currentlySubscribed; + lock(this._subscriptionsPadlock) { + if(!this._subscriptions.TryGetValue(typeof(TMessage), out List? currentSubscriptions)) { + return; + } + + currentlySubscribed = currentSubscriptions.Where(sub => sub.Subscription.ShouldAttemptDelivery(message)).ToList(); + } + + currentlySubscribed.ForEach(sub => { + try { + sub.Proxy.Deliver(message, sub.Subscription); + } catch { + // Ignore any errors and carry on + } + }); + } + + /// + /// Publish a message to any subscribers asynchronously. + /// + /// Type of message. + /// Message to deliver. + /// A task with the publish. + public Task PublishAsync(TMessage message) where TMessage : class, IMessageHubMessage => Task.Run(() => this.Publish(message)); + #endregion + } + + #endregion } diff --git a/Swan/Messaging/MessageHubMessageBase.cs b/Swan/Messaging/MessageHubMessageBase.cs index bdcdf6f..ff606b0 100644 --- a/Swan/Messaging/MessageHubMessageBase.cs +++ b/Swan/Messaging/MessageHubMessageBase.cs @@ -1,57 +1,50 @@ -namespace Swan.Messaging -{ - using System; - +using System; + +namespace Swan.Messaging { + /// + /// Base class for messages that provides weak reference storage of the sender. + /// + public abstract class MessageHubMessageBase : IMessageHubMessage { /// - /// Base class for messages that provides weak reference storage of the sender. + /// Store a WeakReference to the sender just in case anyone is daft enough to + /// keep the message around and prevent the sender from being collected. /// - public abstract class MessageHubMessageBase - : IMessageHubMessage - { - /// - /// Store a WeakReference to the sender just in case anyone is daft enough to - /// keep the message around and prevent the sender from being collected. - /// - private readonly WeakReference _sender; - - /// - /// Initializes a new instance of the class. - /// - /// The sender. - /// sender. - protected MessageHubMessageBase(object sender) - { - if (sender == null) - throw new ArgumentNullException(nameof(sender)); - - _sender = new WeakReference(sender); - } - - /// - public object Sender => _sender.Target; - } - + private readonly WeakReference _sender; + /// - /// Generic message with user specified content. + /// Initializes a new instance of the class. /// - /// Content type to store. - public class MessageHubGenericMessage - : MessageHubMessageBase - { - /// - /// Initializes a new instance of the class. - /// - /// The sender. - /// The content. - public MessageHubGenericMessage(object sender, TContent content) - : base(sender) - { - Content = content; - } - - /// - /// Contents of the message. - /// - public TContent Content { get; protected set; } - } + /// The sender. + /// sender. + protected MessageHubMessageBase(Object sender) { + if(sender == null) { + throw new ArgumentNullException(nameof(sender)); + } + + this._sender = new WeakReference(sender); + } + + /// + public Object Sender => this._sender.Target; + } + + /// + /// Generic message with user specified content. + /// + /// Content type to store. + public class MessageHubGenericMessage : MessageHubMessageBase { + /// + /// Initializes a new instance of the class. + /// + /// The sender. + /// The content. + public MessageHubGenericMessage(Object sender, TContent content) : base(sender) => this.Content = content; + + /// + /// Contents of the message. + /// + public TContent Content { + get; protected set; + } + } } diff --git a/Swan/Messaging/MessageHubSubscriptionToken.cs b/Swan/Messaging/MessageHubSubscriptionToken.cs index 4cc6a71..d6bd50a 100644 --- a/Swan/Messaging/MessageHubSubscriptionToken.cs +++ b/Swan/Messaging/MessageHubSubscriptionToken.cs @@ -1,51 +1,43 @@ -namespace Swan.Messaging -{ - using System; - +using System; +using System.Reflection; + +namespace Swan.Messaging { + /// + /// Represents an active subscription to a message. + /// + public sealed class MessageHubSubscriptionToken : IDisposable { + private readonly WeakReference _hub; + private readonly Type _messageType; + /// - /// Represents an active subscription to a message. + /// Initializes a new instance of the class. /// - public sealed class MessageHubSubscriptionToken - : IDisposable - { - private readonly WeakReference _hub; - private readonly Type _messageType; - - /// - /// Initializes a new instance of the class. - /// - /// The hub. - /// Type of the message. - /// hub. - /// messageType. - public MessageHubSubscriptionToken(IMessageHub hub, Type messageType) - { - if (hub == null) - { - throw new ArgumentNullException(nameof(hub)); - } - - if (!typeof(IMessageHubMessage).IsAssignableFrom(messageType)) - { - throw new ArgumentOutOfRangeException(nameof(messageType)); - } - - _hub = new WeakReference(hub); - _messageType = messageType; - } - - /// - public void Dispose() - { - if (_hub.IsAlive && _hub.Target is IMessageHub hub) - { - var unsubscribeMethod = typeof(IMessageHub).GetMethod(nameof(IMessageHub.Unsubscribe), - new[] {typeof(MessageHubSubscriptionToken)}); - unsubscribeMethod = unsubscribeMethod.MakeGenericMethod(_messageType); - unsubscribeMethod.Invoke(hub, new object[] {this}); - } - - GC.SuppressFinalize(this); - } - } + /// The hub. + /// Type of the message. + /// hub. + /// messageType. + public MessageHubSubscriptionToken(IMessageHub hub, Type messageType) { + if(hub == null) { + throw new ArgumentNullException(nameof(hub)); + } + + if(!typeof(IMessageHubMessage).IsAssignableFrom(messageType)) { + throw new ArgumentOutOfRangeException(nameof(messageType)); + } + + this._hub = new WeakReference(hub); + this._messageType = messageType; + } + + /// + public void Dispose() { + if(this._hub.IsAlive && this._hub.Target is IMessageHub hub) { + MethodInfo unsubscribeMethod = typeof(IMessageHub).GetMethod(nameof(IMessageHub.Unsubscribe), new[] { typeof(MessageHubSubscriptionToken) }); + unsubscribeMethod = unsubscribeMethod.MakeGenericMethod(this._messageType); + _ = unsubscribeMethod.Invoke(hub, new Object[] { this }); + } + + GC.SuppressFinalize(this); + } + } } \ No newline at end of file diff --git a/Swan/Net/Connection.cs b/Swan/Net/Connection.cs index ff25bee..fa504af 100644 --- a/Swan/Net/Connection.cs +++ b/Swan/Net/Connection.cs @@ -1,886 +1,836 @@ -namespace Swan.Net -{ - using Logging; - using System; - using System.Collections.Generic; - using System.IO; - using System.Linq; - using System.Net; - using System.Net.Security; - using System.Net.Sockets; - using System.Security.Cryptography.X509Certificates; - using System.Text; - using System.Threading; - using System.Threading.Tasks; - +#nullable enable +using Swan.Logging; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Net { + /// + /// Represents a network connection either on the server or on the client. It wraps a TcpClient + /// and its corresponding network streams. It is capable of working in 2 modes. Typically on the server side + /// you will need to enable continuous reading and events. On the client side you may want to disable continuous reading + /// and use the Read methods available. In continuous reading mode Read methods are not available and will throw + /// an invalid operation exceptions if they are used. + /// Continuous Reading Mode: Subscribe to data reception events, it runs a background thread, don't use Read methods + /// Manual Reading Mode: Data reception events are NEVER fired. No background threads are used. Use Read methods to receive data. + /// + /// + /// + /// The following code explains how to create a TCP server. + /// + /// using System.Text; + /// using Swan.Net; + /// + /// class Example + /// { + /// static void Main() + /// { + /// // create a new connection listener on a specific port + /// var connectionListener = new ConnectionListener(1337); + /// + /// // handle the OnConnectionAccepting event + /// connectionListener.OnConnectionAccepted += async (s, e) => + /// { + /// // create a new connection + /// using (var con = new Connection(e.Client)) + /// { + /// await con.WriteLineAsync("Hello world!"); + /// } + /// }; + /// + /// connectionListener.Start(); + /// Console.ReadLine)=ñ + /// } + /// } + /// + /// The following code describes how to create a TCP client. + /// + /// using System.Net.Sockets; + /// using System.Text; + /// using System.Threading.Tasks; + /// using Swan.Net; + /// + /// class Example + /// { + /// static async Task Main() + /// { + /// // create a new TcpClient object + /// var client = new TcpClient(); + /// + /// // connect to a specific address and port + /// client.Connect("localhost", 1337); + /// + /// //create a new connection with specific encoding, + /// //new line sequence and continuous reading disabled + /// using (var cn = new Connection(client, Encoding.UTF8, "\r\n", true, 0)) + /// { + /// var response = await cn.ReadTextAsync(); + /// } + /// } + /// } + /// + /// + public sealed class Connection : IDisposable { + // New Line definitions for reading. This applies to both, events and read methods + private readonly String _newLineSequence; + + private readonly Byte[] _newLineSequenceBytes; + private readonly Char[] _newLineSequenceChars; + private readonly String[] _newLineSequenceLineSplitter; + private readonly Byte[] _receiveBuffer; + private readonly TimeSpan _continuousReadingInterval = TimeSpan.FromMilliseconds(5); + private readonly Queue _readLineBuffer = new Queue(); + private readonly ManualResetEvent _writeDone = new ManualResetEvent(true); + + // Disconnect and Dispose + private Boolean _hasDisposed; + + private Int32 _disconnectCalls; + + // Continuous Reading + private Thread? _continuousReadingThread; + + private Int32 _receiveBufferPointer; + + // Reading and writing + private Task? _readTask; + /// - /// Represents a network connection either on the server or on the client. It wraps a TcpClient - /// and its corresponding network streams. It is capable of working in 2 modes. Typically on the server side - /// you will need to enable continuous reading and events. On the client side you may want to disable continuous reading - /// and use the Read methods available. In continuous reading mode Read methods are not available and will throw - /// an invalid operation exceptions if they are used. - /// Continuous Reading Mode: Subscribe to data reception events, it runs a background thread, don't use Read methods - /// Manual Reading Mode: Data reception events are NEVER fired. No background threads are used. Use Read methods to receive data. + /// Initializes a new instance of the class. /// - /// - /// - /// The following code explains how to create a TCP server. - /// - /// using System.Text; - /// using Swan.Net; - /// - /// class Example - /// { - /// static void Main() - /// { - /// // create a new connection listener on a specific port - /// var connectionListener = new ConnectionListener(1337); - /// - /// // handle the OnConnectionAccepting event - /// connectionListener.OnConnectionAccepted += async (s, e) => - /// { - /// // create a new connection - /// using (var con = new Connection(e.Client)) - /// { - /// await con.WriteLineAsync("Hello world!"); - /// } - /// }; - /// - /// connectionListener.Start(); - /// Console.ReadLine)=ñ - /// } - /// } - /// - /// The following code describes how to create a TCP client. - /// - /// using System.Net.Sockets; - /// using System.Text; - /// using System.Threading.Tasks; - /// using Swan.Net; - /// - /// class Example - /// { - /// static async Task Main() - /// { - /// // create a new TcpClient object - /// var client = new TcpClient(); - /// - /// // connect to a specific address and port - /// client.Connect("localhost", 1337); - /// - /// //create a new connection with specific encoding, - /// //new line sequence and continuous reading disabled - /// using (var cn = new Connection(client, Encoding.UTF8, "\r\n", true, 0)) - /// { - /// var response = await cn.ReadTextAsync(); - /// } - /// } - /// } - /// - /// - public sealed class Connection : IDisposable - { - // New Line definitions for reading. This applies to both, events and read methods - private readonly string _newLineSequence; - - private readonly byte[] _newLineSequenceBytes; - private readonly char[] _newLineSequenceChars; - private readonly string[] _newLineSequenceLineSplitter; - private readonly byte[] _receiveBuffer; - private readonly TimeSpan _continuousReadingInterval = TimeSpan.FromMilliseconds(5); - private readonly Queue _readLineBuffer = new Queue(); - private readonly ManualResetEvent _writeDone = new ManualResetEvent(true); - - // Disconnect and Dispose - private bool _hasDisposed; - - private int _disconnectCalls; - - // Continuous Reading - private Thread _continuousReadingThread; - - private int _receiveBufferPointer; - - // Reading and writing - private Task _readTask; - - /// - /// Initializes a new instance of the class. - /// - /// The client. - /// The text encoding. - /// The new line sequence used for read and write operations. - /// if set to true [disable continuous reading]. - /// Size of the block. -- set to 0 or less to disable. - public Connection( - TcpClient client, - Encoding textEncoding, - string newLineSequence, - bool disableContinuousReading, - int blockSize) - { - // Setup basic properties - Id = Guid.NewGuid(); - TextEncoding = textEncoding; - - // Setup new line sequence - if (string.IsNullOrEmpty(newLineSequence)) - throw new ArgumentException("Argument cannot be null", nameof(newLineSequence)); - - _newLineSequence = newLineSequence; - _newLineSequenceBytes = TextEncoding.GetBytes(_newLineSequence); - _newLineSequenceChars = _newLineSequence.ToCharArray(); - _newLineSequenceLineSplitter = new[] { _newLineSequence }; - - // Setup Connection timers - ConnectionStartTimeUtc = DateTime.UtcNow; - DataReceivedLastTimeUtc = ConnectionStartTimeUtc; - DataSentLastTimeUtc = ConnectionStartTimeUtc; - - // Setup connection properties - RemoteClient = client; - LocalEndPoint = client.Client.LocalEndPoint as IPEndPoint; - NetworkStream = RemoteClient.GetStream(); - RemoteEndPoint = RemoteClient.Client.RemoteEndPoint as IPEndPoint; - - // Setup buffers - _receiveBuffer = new byte[RemoteClient.ReceiveBufferSize * 2]; - ProtocolBlockSize = blockSize; - _receiveBufferPointer = 0; - - // Setup continuous reading mode if enabled - if (disableContinuousReading) return; - - ThreadPool.GetAvailableThreads(out var availableWorkerThreads, out _); - ThreadPool.GetMaxThreads(out var maxWorkerThreads, out _); - - var activeThreadPoolTreads = maxWorkerThreads - availableWorkerThreads; - - if (activeThreadPoolTreads < Environment.ProcessorCount / 4) - { - ThreadPool.QueueUserWorkItem(PerformContinuousReading, this); - } - else - { - new Thread(PerformContinuousReading) { IsBackground = true }.Start(); - } - } - - /// - /// Initializes a new instance of the class in continuous reading mode. - /// It uses UTF8 encoding, CRLF as a new line sequence and disables a protocol block size. - /// - /// The client. - public Connection(TcpClient client) - : this(client, Encoding.UTF8, "\r\n", false, 0) - { - // placeholder - } - - /// - /// Initializes a new instance of the class in continuous reading mode. - /// It uses UTF8 encoding, disables line sequences, and uses a protocol block size instead. - /// - /// The client. - /// Size of the block. - public Connection(TcpClient client, int blockSize) - : this(client, Encoding.UTF8, new string('\n', blockSize + 1), false, blockSize) - { - // placeholder - } - - #region Events - - /// - /// Occurs when the receive buffer has encounters a new line sequence, the buffer is flushed or the buffer is full. - /// - public event EventHandler DataReceived = (s, e) => { }; - - /// - /// Occurs when an error occurs while upgrading, sending, or receiving data in this client - /// - public event EventHandler ConnectionFailure = (s, e) => { }; - - /// - /// Occurs when a client is disconnected - /// - public event EventHandler ClientDisconnected = (s, e) => { }; - - #endregion - - #region Properties - - /// - /// Gets the unique identifier of this connection. - /// This field is filled out upon instantiation of this class. - /// - /// - /// The identifier. - /// - public Guid Id { get; } - - /// - /// Gets the active stream. Returns an SSL stream if the connection is secure, otherwise returns - /// the underlying NetworkStream. - /// - /// - /// The active stream. - /// - public Stream ActiveStream => SecureStream ?? NetworkStream as Stream; - - /// - /// Gets a value indicating whether the current connection stream is an SSL stream. - /// - /// - /// true if this instance is active stream secure; otherwise, false. - /// - public bool IsActiveStreamSecure => SecureStream != null; - - /// - /// Gets the text encoding for send and receive operations. - /// - /// - /// The text encoding. - /// - public Encoding TextEncoding { get; } - - /// - /// Gets the remote end point of this TCP connection. - /// - /// - /// The remote end point. - /// - public IPEndPoint RemoteEndPoint { get; } - - /// - /// Gets the local end point of this TCP connection. - /// - /// - /// The local end point. - /// - public IPEndPoint LocalEndPoint { get; } - - /// - /// Gets the remote client of this TCP connection. - /// - /// - /// The remote client. - /// - public TcpClient RemoteClient { get; private set; } - - /// - /// When in continuous reading mode, and if set to greater than 0, - /// a Data reception event will be fired whenever the amount of bytes - /// determined by this property has been received. Useful for fixed-length message protocols. - /// - /// - /// The size of the protocol block. - /// - public int ProtocolBlockSize { get; } - - /// - /// Gets a value indicating whether this connection is in continuous reading mode. - /// Remark: Whenever a disconnect event occurs, the background thread is terminated - /// and this property will return false whenever the reading thread is not active. - /// Therefore, even if continuous reading was not disabled in the constructor, this property - /// might return false. - /// - /// - /// true if this instance is continuous reading enabled; otherwise, false. - /// - public bool IsContinuousReadingEnabled => _continuousReadingThread != null; - - /// - /// Gets the start time at which the connection was started in UTC. - /// - /// - /// The connection start time UTC. - /// - public DateTime ConnectionStartTimeUtc { get; } - - /// - /// Gets the start time at which the connection was started in local time. - /// - /// - /// The connection start time. - /// - public DateTime ConnectionStartTime => ConnectionStartTimeUtc.ToLocalTime(); - - /// - /// Gets the duration of the connection. - /// - /// - /// The duration of the connection. - /// - public TimeSpan ConnectionDuration => DateTime.UtcNow.Subtract(ConnectionStartTimeUtc); - - /// - /// Gets the last time data was received at in UTC. - /// - /// - /// The data received last time UTC. - /// - public DateTime DataReceivedLastTimeUtc { get; private set; } - - /// - /// Gets how long has elapsed since data was last received. - /// - public TimeSpan DataReceivedIdleDuration => DateTime.UtcNow.Subtract(DataReceivedLastTimeUtc); - - /// - /// Gets the last time at which data was sent in UTC. - /// - /// - /// The data sent last time UTC. - /// - public DateTime DataSentLastTimeUtc { get; private set; } - - /// - /// Gets how long has elapsed since data was last sent. - /// - /// - /// The duration of the data sent idle. - /// - public TimeSpan DataSentIdleDuration => DateTime.UtcNow.Subtract(DataSentLastTimeUtc); - - /// - /// Gets a value indicating whether this connection is connected. - /// Remarks: This property polls the socket internally and checks if it is available to read data from it. - /// If disconnect has been called, then this property will return false. - /// - /// - /// true if this instance is connected; otherwise, false. - /// - public bool IsConnected - { - get - { - if (_disconnectCalls > 0) - return false; - - try - { - var socket = RemoteClient.Client; - var pollResult = !((socket.Poll(1000, SelectMode.SelectRead) - && (NetworkStream.DataAvailable == false)) || !socket.Connected); - - if (pollResult == false) - Disconnect(); - - return pollResult; - } - catch - { - Disconnect(); - return false; - } - } - } - - private NetworkStream NetworkStream { get; set; } - - private SslStream SecureStream { get; set; } - - #endregion - - #region Read Methods - - /// - /// Reads data from the remote client asynchronously and with the given timeout. - /// - /// The timeout. - /// The cancellation token. - /// A byte array containing the results of encoding the specified set of characters. - /// Read methods have been disabled because continuous reading is enabled. - /// Reading data from {ActiveStream} timed out in {timeout.TotalMilliseconds} m. - public async Task ReadDataAsync(TimeSpan timeout, CancellationToken cancellationToken = default) - { - if (IsContinuousReadingEnabled) - { - throw new InvalidOperationException( - "Read methods have been disabled because continuous reading is enabled."); - } - - if (RemoteClient == null) - { - throw new InvalidOperationException("An open connection is required"); - } - - var receiveBuffer = new byte[RemoteClient.ReceiveBufferSize * 2]; - var receiveBuilder = new List(receiveBuffer.Length); - - try - { - var startTime = DateTime.UtcNow; - - while (receiveBuilder.Count <= 0) - { - if (DateTime.UtcNow.Subtract(startTime) >= timeout) - { - throw new TimeoutException( - $"Reading data from {ActiveStream} timed out in {timeout.TotalMilliseconds} ms"); - } - - if (_readTask == null) - _readTask = ActiveStream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length, cancellationToken); - - if (_readTask.Wait(_continuousReadingInterval)) - { - var bytesReceivedCount = _readTask.Result; - if (bytesReceivedCount > 0) - { - DataReceivedLastTimeUtc = DateTime.UtcNow; - var buffer = new byte[bytesReceivedCount]; - Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount); - receiveBuilder.AddRange(buffer); - } - - _readTask = null; - } - else - { - await Task.Delay(_continuousReadingInterval, cancellationToken).ConfigureAwait(false); - } - } - } - catch (Exception ex) - { - ex.Error(typeof(Connection).FullName, "Error while reading network stream data asynchronously."); - throw; - } - - return receiveBuilder.ToArray(); - } - - /// - /// Reads data asynchronously from the remote stream with a 5000 millisecond timeout. - /// - /// The cancellation token. - /// - /// A byte array containing the results the specified sequence of bytes. - /// - public Task ReadDataAsync(CancellationToken cancellationToken = default) - => ReadDataAsync(TimeSpan.FromSeconds(5), cancellationToken); - - /// - /// Asynchronously reads data as text with the given timeout. - /// - /// The timeout. - /// The cancellation token. - /// - /// A that contains the results of decoding the specified sequence of bytes. - /// - public async Task ReadTextAsync(TimeSpan timeout, CancellationToken cancellationToken = default) - { - var buffer = await ReadDataAsync(timeout, cancellationToken).ConfigureAwait(false); - return buffer == null ? null : TextEncoding.GetString(buffer); - } - - /// - /// Asynchronously reads data as text with a 5000 millisecond timeout. - /// - /// The cancellation token. - /// - /// When this method completes successfully, it returns the contents of the file as a text string. - /// - public Task ReadTextAsync(CancellationToken cancellationToken = default) - => ReadTextAsync(TimeSpan.FromSeconds(5), cancellationToken); - - /// - /// Performs the same task as this method's overload but it defaults to a read timeout of 30 seconds. - /// - /// The cancellation token. - /// - /// A task that represents the asynchronous read operation. The value of the TResult parameter - /// contains the next line from the stream, or is null if all the characters have been read. - /// - public Task ReadLineAsync(CancellationToken cancellationToken = default) - => ReadLineAsync(TimeSpan.FromSeconds(30), cancellationToken); - - /// - /// Reads the next available line of text in queue. Return null when no text is read. - /// This method differs from the rest of the read methods because it keeps an internal - /// queue of lines that are read from the stream and only returns the one line next in the queue. - /// It is only recommended to use this method when you are working with text-based protocols - /// and the rest of the read methods are not called. - /// - /// The timeout. - /// The cancellation token. - /// A task with a string line from the queue. - /// Read methods have been disabled because continuous reading is enabled. - public async Task ReadLineAsync(TimeSpan timeout, CancellationToken cancellationToken = default) - { - if (IsContinuousReadingEnabled) - { - throw new InvalidOperationException( - "Read methods have been disabled because continuous reading is enabled."); - } - - if (_readLineBuffer.Count > 0) - return _readLineBuffer.Dequeue(); - - var builder = new StringBuilder(); - - while (true) - { - var text = await ReadTextAsync(timeout, cancellationToken).ConfigureAwait(false); - - if (string.IsNullOrEmpty(text)) - break; - - builder.Append(text); - - if (!text.EndsWith(_newLineSequence)) continue; - - var lines = builder.ToString().TrimEnd(_newLineSequenceChars) - .Split(_newLineSequenceLineSplitter, StringSplitOptions.None); - foreach (var item in lines) - _readLineBuffer.Enqueue(item); - - break; - } - - return _readLineBuffer.Count > 0 ? _readLineBuffer.Dequeue() : null; - } - - #endregion - - #region Write Methods - - /// - /// Writes data asynchronously. - /// - /// The buffer. - /// if set to true [force flush]. - /// The cancellation token. - /// A task that represents the asynchronous write operation. - public async Task WriteDataAsync(byte[] buffer, bool forceFlush, CancellationToken cancellationToken = default) - { - try - { - _writeDone.WaitOne(); - _writeDone.Reset(); - await ActiveStream.WriteAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); - if (forceFlush) - await ActiveStream.FlushAsync(cancellationToken).ConfigureAwait(false); - - DataSentLastTimeUtc = DateTime.UtcNow; - } - finally - { - _writeDone.Set(); - } - } - - /// - /// Writes text asynchronously. - /// - /// The text. - /// The cancellation token. - /// A task that represents the asynchronous write operation. - public Task WriteTextAsync(string text, CancellationToken cancellationToken = default) - => WriteTextAsync(text, TextEncoding, cancellationToken); - - /// - /// Writes text asynchronously. - /// - /// The text. - /// The encoding. - /// The cancellation token. - /// A task that represents the asynchronous write operation. - public Task WriteTextAsync(string text, Encoding encoding, CancellationToken cancellationToken = default) - => WriteDataAsync(encoding.GetBytes(text), true, cancellationToken); - - /// - /// Writes a line of text asynchronously. - /// The new line sequence is added automatically at the end of the line. - /// - /// The line. - /// The encoding. - /// The cancellation token. - /// A task that represents the asynchronous write operation. - public Task WriteLineAsync(string line, Encoding encoding, CancellationToken cancellationToken = default) - => WriteDataAsync(encoding.GetBytes($"{line}{_newLineSequence}"), true, cancellationToken); - - /// - /// Writes a line of text asynchronously. - /// The new line sequence is added automatically at the end of the line. - /// - /// The line. - /// The cancellation token. - /// A task that represents the asynchronous write operation. - public Task WriteLineAsync(string line, CancellationToken cancellationToken = default) - => WriteLineAsync(line, TextEncoding, cancellationToken); - - #endregion - - #region Socket Methods - - /// - /// Upgrades the active stream to an SSL stream if this connection object is hosted in the server. - /// - /// The server certificate. - /// true if the object is hosted in the server; otherwise, false. - public async Task UpgradeToSecureAsServerAsync(X509Certificate2 serverCertificate) - { - if (IsActiveStreamSecure) - return true; - - _writeDone.WaitOne(); - - SslStream? secureStream = null; - - try - { - secureStream = new SslStream(NetworkStream, true); - await secureStream.AuthenticateAsServerAsync(serverCertificate).ConfigureAwait(false); - SecureStream = secureStream; - return true; - } - catch (Exception ex) - { - ConnectionFailure(this, new ConnectionFailureEventArgs(ex)); - secureStream?.Dispose(); - - return false; - } - } - - /// - /// Upgrades the active stream to an SSL stream if this connection object is hosted in the client. - /// - /// The hostname. - /// The callback. - /// A tasks with true if the upgrade to SSL was successful; otherwise, false. - public async Task UpgradeToSecureAsClientAsync( - string? hostname = null, - RemoteCertificateValidationCallback? callback = null) - { - if (IsActiveStreamSecure) - return true; - - var secureStream = callback == null - ? new SslStream(NetworkStream, true) - : new SslStream(NetworkStream, true, callback); - - try - { - await secureStream.AuthenticateAsClientAsync(hostname ?? Network.HostName.ToLowerInvariant()).ConfigureAwait(false); - SecureStream = secureStream; - } - catch (Exception ex) - { - secureStream.Dispose(); - ConnectionFailure(this, new ConnectionFailureEventArgs(ex)); - return false; - } - - return true; - } - - /// - /// Disconnects this connection. - /// - public void Disconnect() - { - if (_disconnectCalls > 0) - return; - - _disconnectCalls++; - _writeDone.WaitOne(); - - try - { - ClientDisconnected(this, EventArgs.Empty); - } - catch - { - // ignore - } - - try - { -#if !NET461 - RemoteClient.Dispose(); - SecureStream?.Dispose(); - NetworkStream?.Dispose(); -#else - RemoteClient.Close(); - SecureStream?.Close(); - NetworkStream?.Close(); -#endif - } - finally - { - NetworkStream = null; - SecureStream = null; - RemoteClient = null; - _continuousReadingThread = null; - } - } - - #endregion - - #region Dispose - - /// - public void Dispose() - { - if (_hasDisposed) - return; - - // Release managed resources - Disconnect(); - _continuousReadingThread = null; - _writeDone.Dispose(); - - _hasDisposed = true; - } - - #endregion - - #region Continuous Read Methods - - private void RaiseReceiveBufferEvents(IEnumerable receivedData) - { - var moreAvailable = RemoteClient.Available > 0; - - foreach (var data in receivedData) - { - ProcessReceivedBlock(data, moreAvailable); - } - - // Check if we are left with some more stuff to handle - if (_receiveBufferPointer <= 0) - return; - - // Extract the segments split by newline terminated bytes - var sequences = _receiveBuffer.Skip(0).Take(_receiveBufferPointer).ToArray() - .Split(0, _newLineSequenceBytes); - - // Something really wrong happened - if (sequences.Count == 0) - throw new InvalidOperationException("Split function failed! This is terribly wrong!"); - - // We only have one sequence and it is not newline-terminated - // we don't have to do anything. - if (sequences.Count == 1 && sequences[0].EndsWith(_newLineSequenceBytes) == false) - return; - - // Process the events for each sequence - for (var i = 0; i < sequences.Count; i++) - { - var sequenceBytes = sequences[i]; - var isNewLineTerminated = sequences[i].EndsWith(_newLineSequenceBytes); - var isLast = i == sequences.Count - 1; - - if (isNewLineTerminated) - { - var eventArgs = new ConnectionDataReceivedEventArgs( - sequenceBytes, - ConnectionDataReceivedTrigger.NewLineSequenceEncountered, - isLast == false); - DataReceived(this, eventArgs); - } - - // Depending on the last segment determine what to do with the receive buffer - if (!isLast) continue; - - if (isNewLineTerminated) - { - // Simply reset the buffer pointer if the last segment was also terminated - _receiveBufferPointer = 0; - } - else - { - // If we have not received the termination sequence, then just shift the receive buffer to the left - // and adjust the pointer - Array.Copy(sequenceBytes, _receiveBuffer, sequenceBytes.Length); - _receiveBufferPointer = sequenceBytes.Length; - } - } - } - - private void ProcessReceivedBlock(byte data, bool moreAvailable) - { - _receiveBuffer[_receiveBufferPointer] = data; - _receiveBufferPointer++; - - // Block size reached - if (ProtocolBlockSize > 0 && _receiveBufferPointer >= ProtocolBlockSize) - { - SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BlockSizeReached); - return; - } - - // The receive buffer is full. Time to flush - if (_receiveBufferPointer >= _receiveBuffer.Length) - { - SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BufferFull); - } - } - - private void SendBuffer(bool moreAvailable, ConnectionDataReceivedTrigger trigger) - { - var eventBuffer = new byte[_receiveBuffer.Length]; - Array.Copy(_receiveBuffer, eventBuffer, eventBuffer.Length); - - DataReceived(this, - new ConnectionDataReceivedEventArgs( - eventBuffer, - trigger, - moreAvailable)); - _receiveBufferPointer = 0; - } - - private void PerformContinuousReading(object threadContext) - { - _continuousReadingThread = Thread.CurrentThread; - - // Check if the RemoteClient is still there - if (RemoteClient == null) return; - - var receiveBuffer = new byte[RemoteClient.ReceiveBufferSize * 2]; - - while (IsConnected && _disconnectCalls <= 0) - { - var doThreadSleep = false; - - try - { - if (_readTask == null) - _readTask = ActiveStream.ReadAsync(receiveBuffer, 0, receiveBuffer.Length); - - if (_readTask.Wait(_continuousReadingInterval)) - { - var bytesReceivedCount = _readTask.Result; - if (bytesReceivedCount > 0) - { - DataReceivedLastTimeUtc = DateTime.UtcNow; - var buffer = new byte[bytesReceivedCount]; - Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount); - RaiseReceiveBufferEvents(buffer); - } - - _readTask = null; - } - else - { - doThreadSleep = _disconnectCalls <= 0; - } - } - catch (Exception ex) - { - ex.Log(nameof(PerformContinuousReading), "Continuous Read operation errored"); - } - finally - { - if (doThreadSleep) - Thread.Sleep(_continuousReadingInterval); - } - } - } - - #endregion - } + /// The client. + /// The text encoding. + /// The new line sequence used for read and write operations. + /// if set to true [disable continuous reading]. + /// Size of the block. -- set to 0 or less to disable. + public Connection(TcpClient client, Encoding textEncoding, String newLineSequence, Boolean disableContinuousReading, Int32 blockSize) { + // Setup basic properties + this.Id = Guid.NewGuid(); + this.TextEncoding = textEncoding; + + // Setup new line sequence + if(String.IsNullOrEmpty(newLineSequence)) { + throw new ArgumentException("Argument cannot be null", nameof(newLineSequence)); + } + + this._newLineSequence = newLineSequence; + this._newLineSequenceBytes = this.TextEncoding.GetBytes(this._newLineSequence); + this._newLineSequenceChars = this._newLineSequence.ToCharArray(); + this._newLineSequenceLineSplitter = new[] { this._newLineSequence }; + + // Setup Connection timers + this.ConnectionStartTimeUtc = DateTime.UtcNow; + this.DataReceivedLastTimeUtc = this.ConnectionStartTimeUtc; + this.DataSentLastTimeUtc = this.ConnectionStartTimeUtc; + + // Setup connection properties + this.RemoteClient = client; + this.LocalEndPoint = client.Client.LocalEndPoint as IPEndPoint; + this.NetworkStream = this.RemoteClient.GetStream(); + this.RemoteEndPoint = this.RemoteClient.Client.RemoteEndPoint as IPEndPoint; + + // Setup buffers + this._receiveBuffer = new Byte[this.RemoteClient.ReceiveBufferSize * 2]; + this.ProtocolBlockSize = blockSize; + this._receiveBufferPointer = 0; + + // Setup continuous reading mode if enabled + if(disableContinuousReading) { + return; + } + + ThreadPool.GetAvailableThreads(out Int32 availableWorkerThreads, out _); + ThreadPool.GetMaxThreads(out Int32 maxWorkerThreads, out _); + + Int32 activeThreadPoolTreads = maxWorkerThreads - availableWorkerThreads; + + if(activeThreadPoolTreads < Environment.ProcessorCount / 4) { + _ = ThreadPool.QueueUserWorkItem(this.PerformContinuousReading!, this); + } else { + new Thread(this.PerformContinuousReading!) { IsBackground = true }.Start(); + } + } + + /// + /// Initializes a new instance of the class in continuous reading mode. + /// It uses UTF8 encoding, CRLF as a new line sequence and disables a protocol block size. + /// + /// The client. + public Connection(TcpClient client) : this(client, Encoding.UTF8, "\r\n", false, 0) { + // placeholder + } + + /// + /// Initializes a new instance of the class in continuous reading mode. + /// It uses UTF8 encoding, disables line sequences, and uses a protocol block size instead. + /// + /// The client. + /// Size of the block. + public Connection(TcpClient client, Int32 blockSize) : this(client, Encoding.UTF8, new String('\n', blockSize + 1), false, blockSize) { + // placeholder + } + + #region Events + + /// + /// Occurs when the receive buffer has encounters a new line sequence, the buffer is flushed or the buffer is full. + /// + public event EventHandler DataReceived = (s, e) => { }; + + /// + /// Occurs when an error occurs while upgrading, sending, or receiving data in this client + /// + public event EventHandler ConnectionFailure = (s, e) => { }; + + /// + /// Occurs when a client is disconnected + /// + public event EventHandler ClientDisconnected = (s, e) => { }; + + #endregion + + #region Properties + + /// + /// Gets the unique identifier of this connection. + /// This field is filled out upon instantiation of this class. + /// + /// + /// The identifier. + /// + public Guid Id { + get; + } + + /// + /// Gets the active stream. Returns an SSL stream if the connection is secure, otherwise returns + /// the underlying NetworkStream. + /// + /// + /// The active stream. + /// + public Stream? ActiveStream => this.SecureStream ?? this.NetworkStream as Stream; + + /// + /// Gets a value indicating whether the current connection stream is an SSL stream. + /// + /// + /// true if this instance is active stream secure; otherwise, false. + /// + public Boolean IsActiveStreamSecure => this.SecureStream != null; + + /// + /// Gets the text encoding for send and receive operations. + /// + /// + /// The text encoding. + /// + public Encoding TextEncoding { + get; + } + + /// + /// Gets the remote end point of this TCP connection. + /// + /// + /// The remote end point. + /// + public IPEndPoint? RemoteEndPoint { + get; + } + + /// + /// Gets the local end point of this TCP connection. + /// + /// + /// The local end point. + /// + public IPEndPoint? LocalEndPoint { + get; + } + + /// + /// Gets the remote client of this TCP connection. + /// + /// + /// The remote client. + /// + public TcpClient? RemoteClient { + get; private set; + } + + /// + /// When in continuous reading mode, and if set to greater than 0, + /// a Data reception event will be fired whenever the amount of bytes + /// determined by this property has been received. Useful for fixed-length message protocols. + /// + /// + /// The size of the protocol block. + /// + public Int32 ProtocolBlockSize { + get; + } + + /// + /// Gets a value indicating whether this connection is in continuous reading mode. + /// Remark: Whenever a disconnect event occurs, the background thread is terminated + /// and this property will return false whenever the reading thread is not active. + /// Therefore, even if continuous reading was not disabled in the constructor, this property + /// might return false. + /// + /// + /// true if this instance is continuous reading enabled; otherwise, false. + /// + public Boolean IsContinuousReadingEnabled => this._continuousReadingThread != null; + + /// + /// Gets the start time at which the connection was started in UTC. + /// + /// + /// The connection start time UTC. + /// + public DateTime ConnectionStartTimeUtc { + get; + } + + /// + /// Gets the start time at which the connection was started in local time. + /// + /// + /// The connection start time. + /// + public DateTime ConnectionStartTime => this.ConnectionStartTimeUtc.ToLocalTime(); + + /// + /// Gets the duration of the connection. + /// + /// + /// The duration of the connection. + /// + public TimeSpan ConnectionDuration => DateTime.UtcNow.Subtract(this.ConnectionStartTimeUtc); + + /// + /// Gets the last time data was received at in UTC. + /// + /// + /// The data received last time UTC. + /// + public DateTime DataReceivedLastTimeUtc { + get; private set; + } + + /// + /// Gets how long has elapsed since data was last received. + /// + public TimeSpan DataReceivedIdleDuration => DateTime.UtcNow.Subtract(this.DataReceivedLastTimeUtc); + + /// + /// Gets the last time at which data was sent in UTC. + /// + /// + /// The data sent last time UTC. + /// + public DateTime DataSentLastTimeUtc { + get; private set; + } + + /// + /// Gets how long has elapsed since data was last sent. + /// + /// + /// The duration of the data sent idle. + /// + public TimeSpan DataSentIdleDuration => DateTime.UtcNow.Subtract(this.DataSentLastTimeUtc); + + /// + /// Gets a value indicating whether this connection is connected. + /// Remarks: This property polls the socket internally and checks if it is available to read data from it. + /// If disconnect has been called, then this property will return false. + /// + /// + /// true if this instance is connected; otherwise, false. + /// + public Boolean IsConnected { + get { + if(this._disconnectCalls > 0) { + return false; + } + + try { + Socket? socket = this.RemoteClient?.Client; + if(socket == null || this.NetworkStream == null) { + return false; + } + Boolean pollResult = !(socket.Poll(1000, SelectMode.SelectRead) && this.NetworkStream.DataAvailable == false || !socket.Connected); + + if(pollResult == false) { + this.Disconnect(); + } + + return pollResult; + } catch { + this.Disconnect(); + return false; + } + } + } + + private NetworkStream? NetworkStream { + get; set; + } + + private SslStream? SecureStream { + get; set; + } + + #endregion + + #region Read Methods + + /// + /// Reads data from the remote client asynchronously and with the given timeout. + /// + /// The timeout. + /// The cancellation token. + /// A byte array containing the results of encoding the specified set of characters. + /// Read methods have been disabled because continuous reading is enabled. + /// Reading data from {ActiveStream} timed out in {timeout.TotalMilliseconds} m. + public async Task ReadDataAsync(TimeSpan timeout, CancellationToken cancellationToken = default) { + if(this.IsContinuousReadingEnabled) { + throw new InvalidOperationException("Read methods have been disabled because continuous reading is enabled."); + } + + if(this.RemoteClient == null) { + throw new InvalidOperationException("An open connection is required"); + } + + Byte[] receiveBuffer = new Byte[this.RemoteClient.ReceiveBufferSize * 2]; + List receiveBuilder = new List(receiveBuffer.Length); + + try { + DateTime startTime = DateTime.UtcNow; + + while(receiveBuilder.Count <= 0) { + if(DateTime.UtcNow.Subtract(startTime) >= timeout) { + throw new TimeoutException($"Reading data from {this.ActiveStream} timed out in {timeout.TotalMilliseconds} ms"); + } + + if(this._readTask == null) { + this._readTask = this.ActiveStream?.ReadAsync(receiveBuffer, 0, receiveBuffer.Length, cancellationToken); + } + + if(this._readTask != null && this._readTask.Wait(this._continuousReadingInterval)) { + Int32 bytesReceivedCount = this._readTask.Result; + if(bytesReceivedCount > 0) { + this.DataReceivedLastTimeUtc = DateTime.UtcNow; + Byte[] buffer = new Byte[bytesReceivedCount]; + Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount); + receiveBuilder.AddRange(buffer); + } + + this._readTask = null; + } else { + await Task.Delay(this._continuousReadingInterval, cancellationToken).ConfigureAwait(false); + } + } + } catch(Exception ex) { + ex.Error(typeof(Connection).FullName, "Error while reading network stream data asynchronously."); + throw; + } + + return receiveBuilder.ToArray(); + } + + /// + /// Reads data asynchronously from the remote stream with a 5000 millisecond timeout. + /// + /// The cancellation token. + /// + /// A byte array containing the results the specified sequence of bytes. + /// + public Task ReadDataAsync(CancellationToken cancellationToken = default) => this.ReadDataAsync(TimeSpan.FromSeconds(5), cancellationToken); + + /// + /// Asynchronously reads data as text with the given timeout. + /// + /// The timeout. + /// The cancellation token. + /// + /// A that contains the results of decoding the specified sequence of bytes. + /// + public async Task ReadTextAsync(TimeSpan timeout, CancellationToken cancellationToken = default) { + Byte[] buffer = await this.ReadDataAsync(timeout, cancellationToken).ConfigureAwait(false); + return buffer == null ? null : this.TextEncoding.GetString(buffer); + } + + /// + /// Asynchronously reads data as text with a 5000 millisecond timeout. + /// + /// The cancellation token. + /// + /// When this method completes successfully, it returns the contents of the file as a text string. + /// + public Task ReadTextAsync(CancellationToken cancellationToken = default) => this.ReadTextAsync(TimeSpan.FromSeconds(5), cancellationToken); + + /// + /// Performs the same task as this method's overload but it defaults to a read timeout of 30 seconds. + /// + /// The cancellation token. + /// + /// A task that represents the asynchronous read operation. The value of the TResult parameter + /// contains the next line from the stream, or is null if all the characters have been read. + /// + public Task ReadLineAsync(CancellationToken cancellationToken = default) => this.ReadLineAsync(TimeSpan.FromSeconds(30), cancellationToken); + + /// + /// Reads the next available line of text in queue. Return null when no text is read. + /// This method differs from the rest of the read methods because it keeps an internal + /// queue of lines that are read from the stream and only returns the one line next in the queue. + /// It is only recommended to use this method when you are working with text-based protocols + /// and the rest of the read methods are not called. + /// + /// The timeout. + /// The cancellation token. + /// A task with a string line from the queue. + /// Read methods have been disabled because continuous reading is enabled. + public async Task ReadLineAsync(TimeSpan timeout, CancellationToken cancellationToken = default) { + if(this.IsContinuousReadingEnabled) { + throw new InvalidOperationException("Read methods have been disabled because continuous reading is enabled."); + } + + if(this._readLineBuffer.Count > 0) { + return this._readLineBuffer.Dequeue(); + } + + StringBuilder builder = new StringBuilder(); + + while(true) { + String? text = await this.ReadTextAsync(timeout, cancellationToken).ConfigureAwait(false); + + if(String.IsNullOrEmpty(text)) { + break; + } + + _ = builder.Append(text); + + if(!text.EndsWith(this._newLineSequence)) { + continue; + } + + String[] lines = builder.ToString().TrimEnd(this._newLineSequenceChars).Split(this._newLineSequenceLineSplitter, StringSplitOptions.None); + foreach(String item in lines) { + this._readLineBuffer.Enqueue(item); + } + + break; + } + + return this._readLineBuffer.Count > 0 ? this._readLineBuffer.Dequeue() : null; + } + + #endregion + + #region Write Methods + + /// + /// Writes data asynchronously. + /// + /// The buffer. + /// if set to true [force flush]. + /// The cancellation token. + /// A task that represents the asynchronous write operation. + public async Task WriteDataAsync(Byte[] buffer, Boolean forceFlush, CancellationToken cancellationToken = default) { + try { + _ = this._writeDone.WaitOne(); + _ = this._writeDone.Reset(); + if(this.ActiveStream == null) { + return; + } + await this.ActiveStream.WriteAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + if(forceFlush) { + await this.ActiveStream.FlushAsync(cancellationToken).ConfigureAwait(false); + } + + this.DataSentLastTimeUtc = DateTime.UtcNow; + } finally { + _ = this._writeDone.Set(); + } + } + + /// + /// Writes text asynchronously. + /// + /// The text. + /// The cancellation token. + /// A task that represents the asynchronous write operation. + public Task WriteTextAsync(String text, CancellationToken cancellationToken = default) => this.WriteTextAsync(text, this.TextEncoding, cancellationToken); + + /// + /// Writes text asynchronously. + /// + /// The text. + /// The encoding. + /// The cancellation token. + /// A task that represents the asynchronous write operation. + public Task WriteTextAsync(String text, Encoding encoding, CancellationToken cancellationToken = default) => this.WriteDataAsync(encoding.GetBytes(text), true, cancellationToken); + + /// + /// Writes a line of text asynchronously. + /// The new line sequence is added automatically at the end of the line. + /// + /// The line. + /// The encoding. + /// The cancellation token. + /// A task that represents the asynchronous write operation. + public Task WriteLineAsync(String line, Encoding encoding, CancellationToken cancellationToken = default) => this.WriteDataAsync(encoding.GetBytes($"{line}{this._newLineSequence}"), true, cancellationToken); + + /// + /// Writes a line of text asynchronously. + /// The new line sequence is added automatically at the end of the line. + /// + /// The line. + /// The cancellation token. + /// A task that represents the asynchronous write operation. + public Task WriteLineAsync(String line, CancellationToken cancellationToken = default) => this.WriteLineAsync(line, this.TextEncoding, cancellationToken); + + #endregion + + #region Socket Methods + + /// + /// Upgrades the active stream to an SSL stream if this connection object is hosted in the server. + /// + /// The server certificate. + /// true if the object is hosted in the server; otherwise, false. + public async Task UpgradeToSecureAsServerAsync(X509Certificate2 serverCertificate) { + if(this.IsActiveStreamSecure) { + return true; + } + + _ = this._writeDone.WaitOne(); + + SslStream? secureStream = null; + + try { + secureStream = new SslStream(this.NetworkStream, true); + await secureStream.AuthenticateAsServerAsync(serverCertificate).ConfigureAwait(false); + this.SecureStream = secureStream; + return true; + } catch(Exception ex) { + ConnectionFailure(this, new ConnectionFailureEventArgs(ex)); + secureStream?.Dispose(); + + return false; + } + } + + /// + /// Upgrades the active stream to an SSL stream if this connection object is hosted in the client. + /// + /// The hostname. + /// The callback. + /// A tasks with true if the upgrade to SSL was successful; otherwise, false. + public async Task UpgradeToSecureAsClientAsync(String? hostname = null, RemoteCertificateValidationCallback? callback = null) { + if(this.IsActiveStreamSecure) { + return true; + } + + SslStream secureStream = callback == null ? new SslStream(this.NetworkStream, true) : new SslStream(this.NetworkStream, true, callback); + + try { + await secureStream.AuthenticateAsClientAsync(hostname ?? Network.HostName.ToLowerInvariant()).ConfigureAwait(false); + this.SecureStream = secureStream; + } catch(Exception ex) { + secureStream.Dispose(); + ConnectionFailure(this, new ConnectionFailureEventArgs(ex)); + return false; + } + + return true; + } + + /// + /// Disconnects this connection. + /// + public void Disconnect() { + if(this._disconnectCalls > 0) { + return; + } + + this._disconnectCalls++; + _ = this._writeDone.WaitOne(); + + try { + ClientDisconnected(this, EventArgs.Empty); + } catch { + // ignore + } + + try { + this.RemoteClient?.Dispose(); + this.SecureStream?.Dispose(); + this.NetworkStream?.Dispose(); + + } finally { + this.NetworkStream = null; + this.SecureStream = null; + this.RemoteClient = null; + this._continuousReadingThread = null; + } + } + + #endregion + + #region Dispose + + /// + public void Dispose() { + if(this._hasDisposed) { + return; + } + + // Release managed resources + this.Disconnect(); + this._continuousReadingThread = null; + this._writeDone.Dispose(); + + this._hasDisposed = true; + } + + #endregion + + #region Continuous Read Methods + + private void RaiseReceiveBufferEvents(IEnumerable receivedData) { + if(this.RemoteClient == null) { + return; + } + Boolean moreAvailable = this.RemoteClient.Available > 0; + + foreach(Byte data in receivedData) { + this.ProcessReceivedBlock(data, moreAvailable); + } + + // Check if we are left with some more stuff to handle + if(this._receiveBufferPointer <= 0) { + return; + } + + // Extract the segments split by newline terminated bytes + List sequences = this._receiveBuffer.Skip(0).Take(this._receiveBufferPointer).ToArray().Split(0, this._newLineSequenceBytes); + + // Something really wrong happened + if(sequences.Count == 0) { + throw new InvalidOperationException("Split function failed! This is terribly wrong!"); + } + + // We only have one sequence and it is not newline-terminated + // we don't have to do anything. + if(sequences.Count == 1 && sequences[0].EndsWith(this._newLineSequenceBytes) == false) { + return; + } + + // Process the events for each sequence + for(Int32 i = 0; i < sequences.Count; i++) { + Byte[] sequenceBytes = sequences[i]; + Boolean isNewLineTerminated = sequences[i].EndsWith(this._newLineSequenceBytes); + Boolean isLast = i == sequences.Count - 1; + + if(isNewLineTerminated) { + ConnectionDataReceivedEventArgs eventArgs = new ConnectionDataReceivedEventArgs(sequenceBytes, ConnectionDataReceivedTrigger.NewLineSequenceEncountered, isLast == false); + DataReceived(this, eventArgs); + } + + // Depending on the last segment determine what to do with the receive buffer + if(!isLast) { + continue; + } + + if(isNewLineTerminated) { + // Simply reset the buffer pointer if the last segment was also terminated + this._receiveBufferPointer = 0; + } else { + // If we have not received the termination sequence, then just shift the receive buffer to the left + // and adjust the pointer + Array.Copy(sequenceBytes, this._receiveBuffer, sequenceBytes.Length); + this._receiveBufferPointer = sequenceBytes.Length; + } + } + } + + private void ProcessReceivedBlock(Byte data, Boolean moreAvailable) { + this._receiveBuffer[this._receiveBufferPointer] = data; + this._receiveBufferPointer++; + + // Block size reached + if(this.ProtocolBlockSize > 0 && this._receiveBufferPointer >= this.ProtocolBlockSize) { + this.SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BlockSizeReached); + return; + } + + // The receive buffer is full. Time to flush + if(this._receiveBufferPointer >= this._receiveBuffer.Length) { + this.SendBuffer(moreAvailable, ConnectionDataReceivedTrigger.BufferFull); + } + } + + private void SendBuffer(Boolean moreAvailable, ConnectionDataReceivedTrigger trigger) { + Byte[] eventBuffer = new Byte[this._receiveBuffer.Length]; + Array.Copy(this._receiveBuffer, eventBuffer, eventBuffer.Length); + + DataReceived(this, new ConnectionDataReceivedEventArgs(eventBuffer, trigger, moreAvailable)); + this._receiveBufferPointer = 0; + } + + private void PerformContinuousReading(Object threadContext) { + this._continuousReadingThread = Thread.CurrentThread; + + // Check if the RemoteClient is still there + if(this.RemoteClient == null) { + return; + } + + Byte[] receiveBuffer = new Byte[this.RemoteClient.ReceiveBufferSize * 2]; + + while(this.IsConnected && this._disconnectCalls <= 0) { + Boolean doThreadSleep = false; + + try { + if(this._readTask == null) { + this._readTask = this.ActiveStream?.ReadAsync(receiveBuffer, 0, receiveBuffer.Length); + } + + if(this._readTask != null && this._readTask.Wait(this._continuousReadingInterval)) { + Int32 bytesReceivedCount = this._readTask.Result; + if(bytesReceivedCount > 0) { + this.DataReceivedLastTimeUtc = DateTime.UtcNow; + Byte[] buffer = new Byte[bytesReceivedCount]; + Array.Copy(receiveBuffer, 0, buffer, 0, bytesReceivedCount); + this.RaiseReceiveBufferEvents(buffer); + } + + this._readTask = null; + } else { + doThreadSleep = this._disconnectCalls <= 0; + } + } catch(Exception ex) { + ex.Log(nameof(PerformContinuousReading), "Continuous Read operation errored"); + } finally { + if(doThreadSleep) { + Thread.Sleep(this._continuousReadingInterval); + } + } + } + } + + #endregion + } } diff --git a/Swan/Net/ConnectionDataReceivedTrigger.cs b/Swan/Net/ConnectionDataReceivedTrigger.cs index e6d592d..3d883fb 100644 --- a/Swan/Net/ConnectionDataReceivedTrigger.cs +++ b/Swan/Net/ConnectionDataReceivedTrigger.cs @@ -1,28 +1,26 @@ -namespace Swan -{ +namespace Swan { + /// + /// Enumerates the possible causes of the DataReceived event occurring. + /// + public enum ConnectionDataReceivedTrigger { /// - /// Enumerates the possible causes of the DataReceived event occurring. + /// The trigger was a forceful flush of the buffer /// - public enum ConnectionDataReceivedTrigger - { - /// - /// The trigger was a forceful flush of the buffer - /// - Flush, - - /// - /// The new line sequence bytes were received - /// - NewLineSequenceEncountered, - - /// - /// The buffer was full - /// - BufferFull, - - /// - /// The block size reached - /// - BlockSizeReached, - } + Flush, + + /// + /// The new line sequence bytes were received + /// + NewLineSequenceEncountered, + + /// + /// The buffer was full + /// + BufferFull, + + /// + /// The block size reached + /// + BlockSizeReached, + } } diff --git a/Swan/Net/ConnectionListener.cs b/Swan/Net/ConnectionListener.cs index 1c4b2bb..76138b3 100644 --- a/Swan/Net/ConnectionListener.cs +++ b/Swan/Net/ConnectionListener.cs @@ -1,253 +1,226 @@ -namespace Swan.Net -{ - using System; - using System.Net; - using System.Net.Sockets; - using System.Threading; - using System.Threading.Tasks; - +#nullable enable +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Net { + /// + /// TCP Listener manager with built-in events and asynchronous functionality. + /// This networking component is typically used when writing server software. + /// + /// + public sealed class ConnectionListener : IDisposable { + private readonly Object _stateLock = new Object(); + private TcpListener? _listenerSocket; + private Boolean _cancellationPending; + [System.Diagnostics.CodeAnalysis.SuppressMessage("Codequalität", "IDE0069:Verwerfbare Felder verwerfen", Justification = "")] + private CancellationTokenSource? _cancelListening; + private Task? _backgroundWorkerTask; + private Boolean _hasDisposed; + + #region Events + /// - /// TCP Listener manager with built-in events and asynchronous functionality. - /// This networking component is typically used when writing server software. + /// Occurs when a new connection requests a socket from the listener. + /// Set Cancel = true to prevent the TCP client from being accepted. /// - /// - public sealed class ConnectionListener : IDisposable - { - private readonly object _stateLock = new object(); - private TcpListener _listenerSocket; - private bool _cancellationPending; - private CancellationTokenSource _cancelListening; - private Task? _backgroundWorkerTask; - private bool _hasDisposed; + public event EventHandler OnConnectionAccepting = (s, e) => { }; + + /// + /// Occurs when a new connection is accepted. + /// + public event EventHandler OnConnectionAccepted = (s, e) => { }; + + /// + /// Occurs when a connection fails to get accepted + /// + public event EventHandler OnConnectionFailure = (s, e) => { }; + + /// + /// Occurs when the listener stops. + /// + public event EventHandler OnListenerStopped = (s, e) => { }; + + #endregion + + #region Constructors + + /// + /// Initializes a new instance of the class. + /// + /// The listen end point. + public ConnectionListener(IPEndPoint listenEndPoint) { + this.Id = Guid.NewGuid(); + this.LocalEndPoint = listenEndPoint ?? throw new ArgumentNullException(nameof(listenEndPoint)); + } + + /// + /// Initializes a new instance of the class. + /// It uses the loopback address for listening. + /// + /// The listen port. + public ConnectionListener(Int32 listenPort) : this(new IPEndPoint(IPAddress.Loopback, listenPort)) { + } + + /// + /// Initializes a new instance of the class. + /// + /// The listen address. + /// The listen port. + public ConnectionListener(IPAddress listenAddress, Int32 listenPort) : this(new IPEndPoint(listenAddress, listenPort)) { + } + + /// + /// Finalizes an instance of the class. + /// + ~ConnectionListener() { + this.Dispose(false); + } + + #endregion + + #region Public Properties + + /// + /// Gets the local end point on which we are listening. + /// + /// + /// The local end point. + /// + public IPEndPoint LocalEndPoint { + get; + } + + /// + /// Gets a value indicating whether this listener is active. + /// + /// + /// true if this instance is listening; otherwise, false. + /// + public Boolean IsListening => this._backgroundWorkerTask != null; + + /// + /// Gets a unique identifier that gets automatically assigned upon instantiation of this class. + /// + /// + /// The unique identifier. + /// + public Guid Id { + get; + } + + #endregion + + #region Start and Stop + + /// + /// Starts the listener in an asynchronous, non-blocking fashion. + /// Subscribe to the events of this class to gain access to connected client sockets. + /// + /// Cancellation has already been requested. This listener is not reusable. + public void Start() { + lock(this._stateLock) { + if(this._backgroundWorkerTask != null) { + return; + } + + if(this._cancellationPending) { + throw new InvalidOperationException("Cancellation has already been requested. This listener is not reusable."); + } + + this._backgroundWorkerTask = this.DoWorkAsync(); + } + } + + /// + /// Stops the listener from receiving new connections. + /// This does not prevent the listener from . + /// + public void Stop() { + lock(this._stateLock) { + this._cancellationPending = true; + this._listenerSocket?.Stop(); + this._cancelListening?.Cancel(); + this._backgroundWorkerTask?.Wait(); + this._backgroundWorkerTask = null; + this._cancellationPending = false; + } + } + + /// + /// Returns a that represents this instance. + /// + /// + /// A that represents this instance. + /// + public override String ToString() => this.LocalEndPoint.ToString(); + + /// + public void Dispose() { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Releases unmanaged and - optionally - managed resources. + /// + /// true to release both managed and unmanaged resources; false to release only unmanaged resources. + private void Dispose(Boolean disposing) { + if(this._hasDisposed) { + return; + } + + if(disposing) { + // Release managed resources + this.Stop(); + } + + this._hasDisposed = true; + } + + /// + /// Continuously checks for client connections until the Close method has been called. + /// + /// A task that represents the asynchronous connection operation. + private async Task DoWorkAsync() { + this._cancellationPending = false; + this._listenerSocket = new TcpListener(this.LocalEndPoint); + this._listenerSocket.Start(); + this._cancelListening = new CancellationTokenSource(); + + try { + while(this._cancellationPending == false) { + try { + TcpClient client = await Task.Run(() => this._listenerSocket.AcceptTcpClientAsync(), this._cancelListening.Token).ConfigureAwait(false); + ConnectionAcceptingEventArgs acceptingArgs = new ConnectionAcceptingEventArgs(client); + OnConnectionAccepting(this, acceptingArgs); + + if(acceptingArgs.Cancel) { + client.Dispose(); - #region Events - - /// - /// Occurs when a new connection requests a socket from the listener. - /// Set Cancel = true to prevent the TCP client from being accepted. - /// - public event EventHandler OnConnectionAccepting = (s, e) => { }; - - /// - /// Occurs when a new connection is accepted. - /// - public event EventHandler OnConnectionAccepted = (s, e) => { }; - - /// - /// Occurs when a connection fails to get accepted - /// - public event EventHandler OnConnectionFailure = (s, e) => { }; - - /// - /// Occurs when the listener stops. - /// - public event EventHandler OnListenerStopped = (s, e) => { }; - - #endregion - - #region Constructors - - /// - /// Initializes a new instance of the class. - /// - /// The listen end point. - public ConnectionListener(IPEndPoint listenEndPoint) - { - Id = Guid.NewGuid(); - LocalEndPoint = listenEndPoint ?? throw new ArgumentNullException(nameof(listenEndPoint)); - } - - /// - /// Initializes a new instance of the class. - /// It uses the loopback address for listening. - /// - /// The listen port. - public ConnectionListener(int listenPort) - : this(new IPEndPoint(IPAddress.Loopback, listenPort)) - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The listen address. - /// The listen port. - public ConnectionListener(IPAddress listenAddress, int listenPort) - : this(new IPEndPoint(listenAddress, listenPort)) - { - } - - /// - /// Finalizes an instance of the class. - /// - ~ConnectionListener() - { - Dispose(false); - } - - #endregion - - #region Public Properties - - /// - /// Gets the local end point on which we are listening. - /// - /// - /// The local end point. - /// - public IPEndPoint LocalEndPoint { get; } - - /// - /// Gets a value indicating whether this listener is active. - /// - /// - /// true if this instance is listening; otherwise, false. - /// - public bool IsListening => _backgroundWorkerTask != null; - - /// - /// Gets a unique identifier that gets automatically assigned upon instantiation of this class. - /// - /// - /// The unique identifier. - /// - public Guid Id { get; } - - #endregion - - #region Start and Stop - - /// - /// Starts the listener in an asynchronous, non-blocking fashion. - /// Subscribe to the events of this class to gain access to connected client sockets. - /// - /// Cancellation has already been requested. This listener is not reusable. - public void Start() - { - lock (_stateLock) - { - if (_backgroundWorkerTask != null) - { - return; - } - - if (_cancellationPending) - { - throw new InvalidOperationException( - "Cancellation has already been requested. This listener is not reusable."); - } - - _backgroundWorkerTask = DoWorkAsync(); - } - } - - /// - /// Stops the listener from receiving new connections. - /// This does not prevent the listener from . - /// - public void Stop() - { - lock (_stateLock) - { - _cancellationPending = true; - _listenerSocket?.Stop(); - _cancelListening?.Cancel(); - _backgroundWorkerTask?.Wait(); - _backgroundWorkerTask = null; - _cancellationPending = false; - } - } - - /// - /// Returns a that represents this instance. - /// - /// - /// A that represents this instance. - /// - public override string ToString() => LocalEndPoint.ToString(); - - /// - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - /// - /// Releases unmanaged and - optionally - managed resources. - /// - /// true to release both managed and unmanaged resources; false to release only unmanaged resources. - private void Dispose(bool disposing) - { - if (_hasDisposed) - return; - - if (disposing) - { - // Release managed resources - Stop(); - } - - _hasDisposed = true; - } - - /// - /// Continuously checks for client connections until the Close method has been called. - /// - /// A task that represents the asynchronous connection operation. - private async Task DoWorkAsync() - { - _cancellationPending = false; - _listenerSocket = new TcpListener(LocalEndPoint); - _listenerSocket.Start(); - _cancelListening = new CancellationTokenSource(); - - try - { - while (_cancellationPending == false) - { - try - { - var client = await Task.Run(() => _listenerSocket.AcceptTcpClientAsync(), _cancelListening.Token).ConfigureAwait(false); - var acceptingArgs = new ConnectionAcceptingEventArgs(client); - OnConnectionAccepting(this, acceptingArgs); - - if (acceptingArgs.Cancel) - { -#if !NET461 - client.Dispose(); -#else - client.Close(); -#endif - continue; - } - - OnConnectionAccepted(this, new ConnectionAcceptedEventArgs(client)); - } - catch (Exception ex) - { - OnConnectionFailure(this, new ConnectionFailureEventArgs(ex)); - } - } - - OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(LocalEndPoint)); - } - catch (ObjectDisposedException) - { - OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(LocalEndPoint)); - } - catch (Exception ex) - { - OnListenerStopped(this, - new ConnectionListenerStoppedEventArgs(LocalEndPoint, _cancellationPending ? null : ex)); - } - finally - { - _backgroundWorkerTask = null; - _cancellationPending = false; - } - } - - #endregion - } + continue; + } + + OnConnectionAccepted(this, new ConnectionAcceptedEventArgs(client)); + } catch(Exception ex) { + OnConnectionFailure(this, new ConnectionFailureEventArgs(ex)); + } + } + + OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(this.LocalEndPoint)); + } catch(ObjectDisposedException) { + OnListenerStopped(this, new ConnectionListenerStoppedEventArgs(this.LocalEndPoint)); + } catch(Exception ex) { + OnListenerStopped(this, + new ConnectionListenerStoppedEventArgs(this.LocalEndPoint, this._cancellationPending ? null : ex)); + } finally { + this._backgroundWorkerTask = null; + this._cancellationPending = false; + } + } + + #endregion + } } diff --git a/Swan/Net/Dns/DnsClient.Interfaces.cs b/Swan/Net/Dns/DnsClient.Interfaces.cs index 1d301f1..d869e38 100644 --- a/Swan/Net/Dns/DnsClient.Interfaces.cs +++ b/Swan/Net/Dns/DnsClient.Interfaces.cs @@ -1,62 +1,96 @@ -namespace Swan.Net.Dns -{ - using System; - using System.Threading.Tasks; - using System.Collections.Generic; - - /// - /// DnsClient public interfaces. - /// - internal partial class DnsClient - { - public interface IDnsMessage - { - IList Questions { get; } - - int Size { get; } - byte[] ToArray(); - } - - public interface IDnsMessageEntry - { - DnsDomain Name { get; } - DnsRecordType Type { get; } - DnsRecordClass Class { get; } - - int Size { get; } - byte[] ToArray(); - } - - public interface IDnsResourceRecord : IDnsMessageEntry - { - TimeSpan TimeToLive { get; } - int DataLength { get; } - byte[] Data { get; } - } - - public interface IDnsRequest : IDnsMessage - { - int Id { get; set; } - DnsOperationCode OperationCode { get; set; } - bool RecursionDesired { get; set; } - } - - public interface IDnsResponse : IDnsMessage - { - int Id { get; set; } - IList AnswerRecords { get; } - IList AuthorityRecords { get; } - IList AdditionalRecords { get; } - bool IsRecursionAvailable { get; set; } - bool IsAuthorativeServer { get; set; } - bool IsTruncated { get; set; } - DnsOperationCode OperationCode { get; set; } - DnsResponseCode ResponseCode { get; set; } - } - - public interface IDnsRequestResolver - { - Task Request(DnsClientRequest request); - } - } +using System; +using System.Threading.Tasks; +using System.Collections.Generic; + +namespace Swan.Net.Dns { + /// + /// DnsClient public interfaces. + /// + internal partial class DnsClient { + public interface IDnsMessage { + IList Questions { + get; + } + + Int32 Size { + get; + } + Byte[] ToArray(); + } + + public interface IDnsMessageEntry { + DnsDomain Name { + get; + } + DnsRecordType Type { + get; + } + DnsRecordClass Class { + get; + } + + Int32 Size { + get; + } + Byte[] ToArray(); + } + + public interface IDnsResourceRecord : IDnsMessageEntry { + TimeSpan TimeToLive { + get; + } + Int32 DataLength { + get; + } + Byte[] Data { + get; + } + } + + public interface IDnsRequest : IDnsMessage { + Int32 Id { + get; set; + } + DnsOperationCode OperationCode { + get; set; + } + Boolean RecursionDesired { + get; set; + } + } + + public interface IDnsResponse : IDnsMessage { + Int32 Id { + get; set; + } + IList AnswerRecords { + get; + } + IList AuthorityRecords { + get; + } + IList AdditionalRecords { + get; + } + Boolean IsRecursionAvailable { + get; set; + } + Boolean IsAuthorativeServer { + get; set; + } + Boolean IsTruncated { + get; set; + } + DnsOperationCode OperationCode { + get; set; + } + DnsResponseCode ResponseCode { + get; set; + } + } + + public interface IDnsRequestResolver { + Task Request(DnsClientRequest request); + } + } } diff --git a/Swan/Net/Dns/DnsClient.Request.cs b/Swan/Net/Dns/DnsClient.Request.cs index e962844..8b9f396 100644 --- a/Swan/Net/Dns/DnsClient.Request.cs +++ b/Swan/Net/Dns/DnsClient.Request.cs @@ -1,681 +1,558 @@ -namespace Swan.Net.Dns -{ - using Formatters; - using System; - using System.Collections.Generic; - using System.IO; - using System.Threading.Tasks; - using System.Linq; - using System.Net; - using System.Net.Sockets; - using System.Runtime.InteropServices; - using System.Text; - - /// - /// DnsClient Request inner class. - /// - internal partial class DnsClient - { - public class DnsClientRequest : IDnsRequest - { - private readonly IDnsRequestResolver _resolver; - private readonly IDnsRequest _request; - - public DnsClientRequest(IPEndPoint dns, IDnsRequest? request = null, IDnsRequestResolver? resolver = null) - { - Dns = dns; - _request = request == null ? new DnsRequest() : new DnsRequest(request); - _resolver = resolver ?? new DnsUdpRequestResolver(); - } - - public int Id - { - get => _request.Id; - set => _request.Id = value; - } - - public DnsOperationCode OperationCode - { - get => _request.OperationCode; - set => _request.OperationCode = value; - } - - public bool RecursionDesired - { - get => _request.RecursionDesired; - set => _request.RecursionDesired = value; - } - - public IList Questions => _request.Questions; - - public int Size => _request.Size; - - public IPEndPoint Dns { get; set; } - - public byte[] ToArray() => _request.ToArray(); - - public override string ToString() => _request.ToString(); - - /// - /// Resolves this request into a response using the provided DNS information. The given - /// request strategy is used to retrieve the response. - /// - /// Throw if a malformed response is received from the server. - /// Thrown if a IO error occurs. - /// Thrown if a the reading or writing to the socket fails. - /// The response received from server. - public async Task Resolve() - { - try - { - var response = await _resolver.Request(this).ConfigureAwait(false); - - if (response.Id != Id) - { - throw new DnsQueryException(response, "Mismatching request/response IDs"); - } - - if (response.ResponseCode != DnsResponseCode.NoError) - { - throw new DnsQueryException(response); - } - - return response; - } - catch (Exception e) - { - if (e is ArgumentException || e is SocketException) - throw new DnsQueryException("Invalid response", e); - - throw; - } - } - } - - public class DnsRequest : IDnsRequest - { - private static readonly Random Random = new Random(); - - private DnsHeader header; - - public DnsRequest() - { - Questions = new List(); - header = new DnsHeader - { - OperationCode = DnsOperationCode.Query, - Response = false, - Id = Random.Next(ushort.MaxValue), - }; - } - - public DnsRequest(IDnsRequest request) - { - header = new DnsHeader(); - Questions = new List(request.Questions); - - header.Response = false; - - Id = request.Id; - OperationCode = request.OperationCode; - RecursionDesired = request.RecursionDesired; - } - - public IList Questions { get; } - - public int Size => header.Size + Questions.Sum(q => q.Size); - - public int Id - { - get => header.Id; - set => header.Id = value; - } - - public DnsOperationCode OperationCode - { - get => header.OperationCode; - set => header.OperationCode = value; - } - - public bool RecursionDesired - { - get => header.RecursionDesired; - set => header.RecursionDesired = value; - } - - public byte[] ToArray() - { - UpdateHeader(); - using var result = new MemoryStream(Size); - - return result - .Append(header.ToArray()) - .Append(Questions.Select(q => q.ToArray())) - .ToArray(); - } - - public override string ToString() - { - UpdateHeader(); - - return Json.Serialize(this, true); - } - - private void UpdateHeader() - { - header.QuestionCount = Questions.Count; - } - } - - public class DnsTcpRequestResolver : IDnsRequestResolver - { - public async Task Request(DnsClientRequest request) - { - var tcp = new TcpClient(); - - try - { -#if !NET461 - await tcp.Client.ConnectAsync(request.Dns).ConfigureAwait(false); -#else - tcp.Client.Connect(request.Dns); -#endif - var stream = tcp.GetStream(); - var buffer = request.ToArray(); - var length = BitConverter.GetBytes((ushort)buffer.Length); - - if (BitConverter.IsLittleEndian) - Array.Reverse(length); - - await stream.WriteAsync(length, 0, length.Length).ConfigureAwait(false); - await stream.WriteAsync(buffer, 0, buffer.Length).ConfigureAwait(false); - - buffer = new byte[2]; - await Read(stream, buffer).ConfigureAwait(false); - - if (BitConverter.IsLittleEndian) - Array.Reverse(buffer); - - buffer = new byte[BitConverter.ToUInt16(buffer, 0)]; - await Read(stream, buffer).ConfigureAwait(false); - - var response = DnsResponse.FromArray(buffer); - - return new DnsClientResponse(request, response, buffer); - } - finally - { -#if NET461 - tcp.Close(); -#else - tcp.Dispose(); -#endif - } - } - - private static async Task Read(Stream stream, byte[] buffer) - { - var length = buffer.Length; - var offset = 0; - int size; - - while (length > 0 && (size = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false)) > 0) - { - offset += size; - length -= size; - } - - if (length > 0) - { - throw new IOException("Unexpected end of stream"); - } - } - } - - public class DnsUdpRequestResolver : IDnsRequestResolver - { - private readonly IDnsRequestResolver _fallback; - - public DnsUdpRequestResolver(IDnsRequestResolver fallback) - { - _fallback = fallback; - } - - public DnsUdpRequestResolver() - { - _fallback = new DnsNullRequestResolver(); - } - - public async Task Request(DnsClientRequest request) - { - var udp = new UdpClient(); - var dns = request.Dns; - - try - { - udp.Client.SendTimeout = 7000; - udp.Client.ReceiveTimeout = 7000; -#if !NET461 - await udp.Client.ConnectAsync(dns).ConfigureAwait(false); -#else - udp.Client.Connect(dns); -#endif - - await udp.SendAsync(request.ToArray(), request.Size).ConfigureAwait(false); - - var bufferList = new List(); - - do - { - var tempBuffer = new byte[1024]; - var receiveCount = udp.Client.Receive(tempBuffer); - bufferList.AddRange(tempBuffer.Skip(0).Take(receiveCount)); - } - while (udp.Client.Available > 0 || bufferList.Count == 0); - - var buffer = bufferList.ToArray(); - var response = DnsResponse.FromArray(buffer); - - return response.IsTruncated - ? await _fallback.Request(request).ConfigureAwait(false) - : new DnsClientResponse(request, response, buffer); - } - finally - { -#if NET461 - udp.Close(); -#else - udp.Dispose(); -#endif - } - } - } - - public class DnsNullRequestResolver : IDnsRequestResolver - { - public Task Request(DnsClientRequest request) => throw new DnsQueryException("Request failed"); - } - - // 12 bytes message header - [StructEndianness(Endianness.Big)] - [StructLayout(LayoutKind.Sequential, Pack = 1)] - public struct DnsHeader - { - public const int SIZE = 12; - - private ushort id; - - private byte flag0; - private byte flag1; - - // Question count: number of questions in the Question section - private ushort questionCount; - - // Answer record count: number of records in the Answer section - private ushort answerCount; - - // Authority record count: number of records in the Authority section - private ushort authorityCount; - - // Additional record count: number of records in the Additional section - private ushort addtionalCount; - - public int Id - { - get => id; - set => id = (ushort)value; - } - - public int QuestionCount - { - get => questionCount; - set => questionCount = (ushort)value; - } - - public int AnswerRecordCount - { - get => answerCount; - set => answerCount = (ushort)value; - } - - public int AuthorityRecordCount - { - get => authorityCount; - set => authorityCount = (ushort)value; - } - - public int AdditionalRecordCount - { - get => addtionalCount; - set => addtionalCount = (ushort)value; - } - - public bool Response - { - get => Qr == 1; - set => Qr = Convert.ToByte(value); - } - - public DnsOperationCode OperationCode - { - get => (DnsOperationCode)Opcode; - set => Opcode = (byte)value; - } - - public bool AuthorativeServer - { - get => Aa == 1; - set => Aa = Convert.ToByte(value); - } - - public bool Truncated - { - get => Tc == 1; - set => Tc = Convert.ToByte(value); - } - - public bool RecursionDesired - { - get => Rd == 1; - set => Rd = Convert.ToByte(value); - } - - public bool RecursionAvailable - { - get => Ra == 1; - set => Ra = Convert.ToByte(value); - } - - public DnsResponseCode ResponseCode - { - get => (DnsResponseCode)RCode; - set => RCode = (byte)value; - } - - public int Size => SIZE; - - // Query/Response Flag - private byte Qr - { - get => Flag0.GetBitValueAt(7); - set => Flag0 = Flag0.SetBitValueAt(7, 1, value); - } - - // Operation Code - private byte Opcode - { - get => Flag0.GetBitValueAt(3, 4); - set => Flag0 = Flag0.SetBitValueAt(3, 4, value); - } - - // Authorative Answer Flag - private byte Aa - { - get => Flag0.GetBitValueAt(2); - set => Flag0 = Flag0.SetBitValueAt(2, 1, value); - } - - // Truncation Flag - private byte Tc - { - get => Flag0.GetBitValueAt(1); - set => Flag0 = Flag0.SetBitValueAt(1, 1, value); - } - - // Recursion Desired - private byte Rd - { - get => Flag0.GetBitValueAt(0); - set => Flag0 = Flag0.SetBitValueAt(0, 1, value); - } - - // Recursion Available - private byte Ra - { - get => Flag1.GetBitValueAt(7); - set => Flag1 = Flag1.SetBitValueAt(7, 1, value); - } - - // Zero (Reserved) - private byte Z - { - get => Flag1.GetBitValueAt(4, 3); - set { } - } - - // Response Code - private byte RCode - { - get => Flag1.GetBitValueAt(0, 4); - set => Flag1 = Flag1.SetBitValueAt(0, 4, value); - } - - private byte Flag0 - { - get => flag0; - set => flag0 = value; - } - - private byte Flag1 - { - get => flag1; - set => flag1 = value; - } - - public static DnsHeader FromArray(byte[] header) => - header.Length < SIZE - ? throw new ArgumentException("Header length too small") - : header.ToStruct(0, SIZE); - - public byte[] ToArray() => this.ToBytes(); - - public override string ToString() - => Json.SerializeExcluding(this, true, nameof(Size)); - } - - public class DnsDomain : IComparable - { - private readonly string[] _labels; - - public DnsDomain(string domain) - : this(domain.Split('.')) - { - } - - public DnsDomain(string[] labels) - { - _labels = labels; - } - - public int Size => _labels.Sum(l => l.Length) + _labels.Length + 1; - - public static DnsDomain FromArray(byte[] message, int offset) - => FromArray(message, offset, out offset); - - public static DnsDomain FromArray(byte[] message, int offset, out int endOffset) - { - var labels = new List(); - var endOffsetAssigned = false; - endOffset = 0; - byte lengthOrPointer; - - while ((lengthOrPointer = message[offset++]) > 0) - { - // Two heighest bits are set (pointer) - if (lengthOrPointer.GetBitValueAt(6, 2) == 3) - { - if (!endOffsetAssigned) - { - endOffsetAssigned = true; - endOffset = offset + 1; - } - - ushort pointer = lengthOrPointer.GetBitValueAt(0, 6); - offset = (pointer << 8) | message[offset]; - - continue; - } - - if (lengthOrPointer.GetBitValueAt(6, 2) != 0) - { - throw new ArgumentException("Unexpected bit pattern in label length"); - } - - var length = lengthOrPointer; - var label = new byte[length]; - Array.Copy(message, offset, label, 0, length); - - labels.Add(label); - - offset += length; - } - - if (!endOffsetAssigned) - { - endOffset = offset; - } - - return new DnsDomain(labels.Select(l => l.ToText(Encoding.ASCII)).ToArray()); - } - - public static DnsDomain PointerName(IPAddress ip) - => new DnsDomain(FormatReverseIP(ip)); - - public byte[] ToArray() - { - var result = new byte[Size]; - var offset = 0; - - foreach (var l in _labels.Select(label => Encoding.ASCII.GetBytes(label))) - { - result[offset++] = (byte)l.Length; - l.CopyTo(result, offset); - - offset += l.Length; - } - - result[offset] = 0; - - return result; - } - - public override string ToString() - => string.Join(".", _labels); - - public int CompareTo(DnsDomain other) - => string.Compare(ToString(), other.ToString(), StringComparison.Ordinal); - - public override bool Equals(object obj) - => obj is DnsDomain domain && CompareTo(domain) == 0; - - public override int GetHashCode() => ToString().GetHashCode(); - - private static string FormatReverseIP(IPAddress ip) - { - var address = ip.GetAddressBytes(); - - if (address.Length == 4) - { - return string.Join(".", address.Reverse().Select(b => b.ToString())) + ".in-addr.arpa"; - } - - var nibbles = new byte[address.Length * 2]; - - for (int i = 0, j = 0; i < address.Length; i++, j = 2 * i) - { - var b = address[i]; - - nibbles[j] = b.GetBitValueAt(4, 4); - nibbles[j + 1] = b.GetBitValueAt(0, 4); - } - - return string.Join(".", nibbles.Reverse().Select(b => b.ToString("x"))) + ".ip6.arpa"; - } - } - - public class DnsQuestion : IDnsMessageEntry - { - private readonly DnsRecordType _type; - private readonly DnsRecordClass _klass; - - public static IList GetAllFromArray(byte[] message, int offset, int questionCount) => - GetAllFromArray(message, offset, questionCount, out offset); - - public static IList GetAllFromArray( - byte[] message, - int offset, - int questionCount, - out int endOffset) - { - IList questions = new List(questionCount); - - for (var i = 0; i < questionCount; i++) - { - questions.Add(FromArray(message, offset, out offset)); - } - - endOffset = offset; - return questions; - } - - public static DnsQuestion FromArray(byte[] message, int offset, out int endOffset) - { - var domain = DnsDomain.FromArray(message, offset, out offset); - var tail = message.ToStruct(offset, Tail.SIZE); - - endOffset = offset + Tail.SIZE; - - return new DnsQuestion(domain, tail.Type, tail.Class); - } - - public DnsQuestion( - DnsDomain domain, - DnsRecordType type = DnsRecordType.A, - DnsRecordClass klass = DnsRecordClass.IN) - { - Name = domain; - _type = type; - _klass = klass; - } - - public DnsDomain Name { get; } - - public DnsRecordType Type => _type; - - public DnsRecordClass Class => _klass; - - public int Size => Name.Size + Tail.SIZE; - - public byte[] ToArray() => - new MemoryStream(Size) - .Append(Name.ToArray()) - .Append(new Tail { Type = Type, Class = Class }.ToBytes()) - .ToArray(); - - public override string ToString() - => Json.SerializeOnly(this, true, nameof(Name), nameof(Type), nameof(Class)); - - [StructEndianness(Endianness.Big)] - [StructLayout(LayoutKind.Sequential, Pack = 2)] - private struct Tail - { - public const int SIZE = 4; - - private ushort type; - private ushort klass; - - public DnsRecordType Type - { - get => (DnsRecordType)type; - set => type = (ushort)value; - } - - public DnsRecordClass Class - { - get => (DnsRecordClass)klass; - set => klass = (ushort)value; - } - } - } - } +#nullable enable +using Swan.Formatters; +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Text; + +namespace Swan.Net.Dns { + /// + /// DnsClient Request inner class. + /// + internal partial class DnsClient { + public class DnsClientRequest : IDnsRequest { + private readonly IDnsRequestResolver _resolver; + private readonly IDnsRequest _request; + + public DnsClientRequest(IPEndPoint dns, IDnsRequest? request = null, IDnsRequestResolver? resolver = null) { + this.Dns = dns; + this._request = request == null ? new DnsRequest() : new DnsRequest(request); + this._resolver = resolver ?? new DnsUdpRequestResolver(); + } + + public Int32 Id { + get => this._request.Id; + set => this._request.Id = value; + } + + public DnsOperationCode OperationCode { + get => this._request.OperationCode; + set => this._request.OperationCode = value; + } + + public Boolean RecursionDesired { + get => this._request.RecursionDesired; + set => this._request.RecursionDesired = value; + } + + public IList Questions => this._request.Questions; + + public Int32 Size => this._request.Size; + + public IPEndPoint Dns { + get; set; + } + + public Byte[] ToArray() => this._request.ToArray(); + + public override String ToString() => this._request.ToString()!; + + /// + /// Resolves this request into a response using the provided DNS information. The given + /// request strategy is used to retrieve the response. + /// + /// Throw if a malformed response is received from the server. + /// Thrown if a IO error occurs. + /// Thrown if a the reading or writing to the socket fails. + /// The response received from server. + public async Task Resolve() { + try { + DnsClientResponse response = await this._resolver.Request(this).ConfigureAwait(false); + + if(response.Id != this.Id) { + throw new DnsQueryException(response, "Mismatching request/response IDs"); + } + + if(response.ResponseCode != DnsResponseCode.NoError) { + throw new DnsQueryException(response); + } + + return response; + } catch(Exception e) { + if(e is ArgumentException || e is SocketException) { + throw new DnsQueryException("Invalid response", e); + } + + throw; + } + } + } + + public class DnsRequest : IDnsRequest { + private static readonly Random Random = new Random(); + + private DnsHeader header; + + public DnsRequest() { + this.Questions = new List(); + this.header = new DnsHeader { + OperationCode = DnsOperationCode.Query, + Response = false, + Id = Random.Next(UInt16.MaxValue), + }; + } + + public DnsRequest(IDnsRequest request) { + this.header = new DnsHeader(); + this.Questions = new List(request.Questions); + + this.header.Response = false; + + this.Id = request.Id; + this.OperationCode = request.OperationCode; + this.RecursionDesired = request.RecursionDesired; + } + + public IList Questions { + get; + } + + public Int32 Size => this.header.Size + this.Questions.Sum(q => q.Size); + + public Int32 Id { + get => this.header.Id; + set => this.header.Id = value; + } + + public DnsOperationCode OperationCode { + get => this.header.OperationCode; + set => this.header.OperationCode = value; + } + + public Boolean RecursionDesired { + get => this.header.RecursionDesired; + set => this.header.RecursionDesired = value; + } + + public Byte[] ToArray() { + this.UpdateHeader(); + using MemoryStream result = new MemoryStream(this.Size); + + return result.Append(this.header.ToArray()).Append(this.Questions.Select(q => q.ToArray())).ToArray(); + } + + public override String ToString() { + this.UpdateHeader(); + + return Json.Serialize(this, true); + } + + private void UpdateHeader() => this.header.QuestionCount = this.Questions.Count; + } + + public class DnsTcpRequestResolver : IDnsRequestResolver { + public async Task Request(DnsClientRequest request) { + TcpClient tcp = new TcpClient(); + + try { + await tcp.Client.ConnectAsync(request.Dns).ConfigureAwait(false); + + NetworkStream stream = tcp.GetStream(); + Byte[] buffer = request.ToArray(); + Byte[] length = BitConverter.GetBytes((UInt16)buffer.Length); + + if(BitConverter.IsLittleEndian) { + Array.Reverse(length); + } + + await stream.WriteAsync(length, 0, length.Length).ConfigureAwait(false); + await stream.WriteAsync(buffer, 0, buffer.Length).ConfigureAwait(false); + + buffer = new Byte[2]; + await Read(stream, buffer).ConfigureAwait(false); + + if(BitConverter.IsLittleEndian) { + Array.Reverse(buffer); + } + + buffer = new Byte[BitConverter.ToUInt16(buffer, 0)]; + await Read(stream, buffer).ConfigureAwait(false); + + DnsResponse response = DnsResponse.FromArray(buffer); + + return new DnsClientResponse(request, response, buffer); + } finally { + tcp.Dispose(); + } + } + + private static async Task Read(Stream stream, Byte[] buffer) { + Int32 length = buffer.Length; + Int32 offset = 0; + Int32 size; + + while(length > 0 && (size = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false)) > 0) { + offset += size; + length -= size; + } + + if(length > 0) { + throw new IOException("Unexpected end of stream"); + } + } + } + + public class DnsUdpRequestResolver : IDnsRequestResolver { + private readonly IDnsRequestResolver _fallback; + + public DnsUdpRequestResolver(IDnsRequestResolver fallback) => this._fallback = fallback; + + public DnsUdpRequestResolver() => this._fallback = new DnsNullRequestResolver(); + + public async Task Request(DnsClientRequest request) { + UdpClient udp = new UdpClient(); + IPEndPoint dns = request.Dns; + + try { + udp.Client.SendTimeout = 7000; + udp.Client.ReceiveTimeout = 7000; + + await udp.Client.ConnectAsync(dns).ConfigureAwait(false); + + + _ = await udp.SendAsync(request.ToArray(), request.Size).ConfigureAwait(false); + + List bufferList = new List(); + + do { + Byte[] tempBuffer = new Byte[1024]; + Int32 receiveCount = udp.Client.Receive(tempBuffer); + bufferList.AddRange(tempBuffer.Skip(0).Take(receiveCount)); + } + while(udp.Client.Available > 0 || bufferList.Count == 0); + + Byte[] buffer = bufferList.ToArray(); + DnsResponse response = DnsResponse.FromArray(buffer); + + return response.IsTruncated + ? await this._fallback.Request(request).ConfigureAwait(false) + : new DnsClientResponse(request, response, buffer); + } finally { + udp.Dispose(); + } + } + } + + public class DnsNullRequestResolver : IDnsRequestResolver { + public Task Request(DnsClientRequest request) => throw new DnsQueryException("Request failed"); + } + + // 12 bytes message header + [StructEndianness(Endianness.Big)] + [StructLayout(LayoutKind.Sequential, Pack = 1)] + public struct DnsHeader { + public const Int32 SIZE = 12; + + private UInt16 id; + + // Question count: number of questions in the Question section + private UInt16 questionCount; + + // Answer record count: number of records in the Answer section + private UInt16 answerCount; + + // Authority record count: number of records in the Authority section + private UInt16 authorityCount; + + // Additional record count: number of records in the Additional section + private UInt16 addtionalCount; + + public Int32 Id { + get => this.id; + set => this.id = (UInt16)value; + } + + public Int32 QuestionCount { + get => this.questionCount; + set => this.questionCount = (UInt16)value; + } + + public Int32 AnswerRecordCount { + get => this.answerCount; + set => this.answerCount = (UInt16)value; + } + + public Int32 AuthorityRecordCount { + get => this.authorityCount; + set => this.authorityCount = (UInt16)value; + } + + public Int32 AdditionalRecordCount { + get => this.addtionalCount; + set => this.addtionalCount = (UInt16)value; + } + + public Boolean Response { + get => this.Qr == 1; + set => this.Qr = Convert.ToByte(value); + } + + public DnsOperationCode OperationCode { + get => (DnsOperationCode)this.Opcode; + set => this.Opcode = (Byte)value; + } + + public Boolean AuthorativeServer { + get => this.Aa == 1; + set => this.Aa = Convert.ToByte(value); + } + + public Boolean Truncated { + get => this.Tc == 1; + set => this.Tc = Convert.ToByte(value); + } + + public Boolean RecursionDesired { + get => this.Rd == 1; + set => this.Rd = Convert.ToByte(value); + } + + public Boolean RecursionAvailable { + get => this.Ra == 1; + set => this.Ra = Convert.ToByte(value); + } + + public DnsResponseCode ResponseCode { + get => (DnsResponseCode)this.RCode; + set => this.RCode = (Byte)value; + } + + public Int32 Size => SIZE; + + // Query/Response Flag + private Byte Qr { + get => this.Flag0.GetBitValueAt(7); + set => this.Flag0 = this.Flag0.SetBitValueAt(7, 1, value); + } + + // Operation Code + private Byte Opcode { + get => this.Flag0.GetBitValueAt(3, 4); + set => this.Flag0 = this.Flag0.SetBitValueAt(3, 4, value); + } + + // Authorative Answer Flag + private Byte Aa { + get => this.Flag0.GetBitValueAt(2); + set => this.Flag0 = this.Flag0.SetBitValueAt(2, 1, value); + } + + // Truncation Flag + private Byte Tc { + get => this.Flag0.GetBitValueAt(1); + set => this.Flag0 = this.Flag0.SetBitValueAt(1, 1, value); + } + + // Recursion Desired + private Byte Rd { + get => this.Flag0.GetBitValueAt(0); + set => this.Flag0 = this.Flag0.SetBitValueAt(0, 1, value); + } + + // Recursion Available + private Byte Ra { + get => this.Flag1.GetBitValueAt(7); + set => this.Flag1 = this.Flag1.SetBitValueAt(7, 1, value); + } + + // Zero (Reserved) + private Byte Z { + get => this.Flag1.GetBitValueAt(4, 3); + set { + } + } + + // Response Code + private Byte RCode { + get => this.Flag1.GetBitValueAt(0, 4); + set => this.Flag1 = this.Flag1.SetBitValueAt(0, 4, value); + } + + private Byte Flag0 { + get; + set; + } + + private Byte Flag1 { + get; + set; + } + + public static DnsHeader FromArray(Byte[] header) => header.Length < SIZE ? throw new ArgumentException("Header length too small") : header.ToStruct(0, SIZE); + + public Byte[] ToArray() => this.ToBytes(); + + public override String ToString() => Json.SerializeExcluding(this, true, nameof(this.Size)); + } + + public class DnsDomain : IComparable { + private readonly String[] _labels; + + public DnsDomain(String domain) : this(domain.Split('.')) { + } + + public DnsDomain(String[] labels) => this._labels = labels; + + public Int32 Size => this._labels.Sum(l => l.Length) + this._labels.Length + 1; + + public static DnsDomain FromArray(Byte[] message, Int32 offset) => FromArray(message, offset, out _); + + public static DnsDomain FromArray(Byte[] message, Int32 offset, out Int32 endOffset) { + List labels = new List(); + Boolean endOffsetAssigned = false; + endOffset = 0; + Byte lengthOrPointer; + + while((lengthOrPointer = message[offset++]) > 0) { + // Two heighest bits are set (pointer) + if(lengthOrPointer.GetBitValueAt(6, 2) == 3) { + if(!endOffsetAssigned) { + endOffsetAssigned = true; + endOffset = offset + 1; + } + + UInt16 pointer = lengthOrPointer.GetBitValueAt(0, 6); + offset = (pointer << 8) | message[offset]; + + continue; + } + + if(lengthOrPointer.GetBitValueAt(6, 2) != 0) { + throw new ArgumentException("Unexpected bit pattern in label length"); + } + + Byte length = lengthOrPointer; + Byte[] label = new Byte[length]; + Array.Copy(message, offset, label, 0, length); + + labels.Add(label); + + offset += length; + } + + if(!endOffsetAssigned) { + endOffset = offset; + } + + return new DnsDomain(labels.Select(l => l.ToText(Encoding.ASCII)).ToArray()); + } + + public static DnsDomain PointerName(IPAddress ip) => new DnsDomain(FormatReverseIP(ip)); + + public Byte[] ToArray() { + Byte[] result = new Byte[this.Size]; + Int32 offset = 0; + + foreach(Byte[] l in this._labels.Select(label => Encoding.ASCII.GetBytes(label))) { + result[offset++] = (Byte)l.Length; + l.CopyTo(result, offset); + + offset += l.Length; + } + + result[offset] = 0; + + return result; + } + + public override String ToString() => String.Join(".", this._labels); + + public Int32 CompareTo(DnsDomain other) => String.Compare(this.ToString(), other.ToString(), StringComparison.Ordinal); + + public override Boolean Equals(Object? obj) => obj is DnsDomain domain && this.CompareTo(domain) == 0; + + public override Int32 GetHashCode() => this.ToString().GetHashCode(); + + private static String FormatReverseIP(IPAddress ip) { + Byte[] address = ip.GetAddressBytes(); + + if(address.Length == 4) { + return String.Join(".", address.Reverse().Select(b => b.ToString())) + ".in-addr.arpa"; + } + + Byte[] nibbles = new Byte[address.Length * 2]; + + for(Int32 i = 0, j = 0; i < address.Length; i++, j = 2 * i) { + Byte b = address[i]; + + nibbles[j] = b.GetBitValueAt(4, 4); + nibbles[j + 1] = b.GetBitValueAt(0, 4); + } + + return String.Join(".", nibbles.Reverse().Select(b => b.ToString("x"))) + ".ip6.arpa"; + } + } + + public class DnsQuestion : IDnsMessageEntry { + public static IList GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount) => GetAllFromArray(message, offset, questionCount, out _); + + public static IList GetAllFromArray(Byte[] message, Int32 offset, Int32 questionCount, out Int32 endOffset) { + IList questions = new List(questionCount); + + for(Int32 i = 0; i < questionCount; i++) { + questions.Add(FromArray(message, offset, out offset)); + } + + endOffset = offset; + return questions; + } + + public static DnsQuestion FromArray(Byte[] message, Int32 offset, out Int32 endOffset) { + DnsDomain domain = DnsDomain.FromArray(message, offset, out offset); + Tail tail = message.ToStruct(offset, Tail.SIZE); + + endOffset = offset + Tail.SIZE; + + return new DnsQuestion(domain, tail.Type, tail.Class); + } + + public DnsQuestion(DnsDomain domain, DnsRecordType type = DnsRecordType.A, DnsRecordClass klass = DnsRecordClass.IN) { + this.Name = domain; + this.Type = type; + this.Class = klass; + } + + public DnsDomain Name { + get; + } + + public DnsRecordType Type { + get; + } + + public DnsRecordClass Class { + get; + } + + public Int32 Size => this.Name.Size + Tail.SIZE; + + public Byte[] ToArray() => new MemoryStream(this.Size).Append(this.Name.ToArray()).Append(new Tail { Type = Type, Class = Class }.ToBytes()).ToArray(); + + public override String ToString() => Json.SerializeOnly(this, true, nameof(this.Name), nameof(this.Type), nameof(this.Class)); + + [StructEndianness(Endianness.Big)] + [StructLayout(LayoutKind.Sequential, Pack = 2)] + private struct Tail { + public const Int32 SIZE = 4; + + private UInt16 type; + private UInt16 klass; + + public DnsRecordType Type { + get => (DnsRecordType)this.type; + set => this.type = (UInt16)value; + } + + public DnsRecordClass Class { + get => (DnsRecordClass)this.klass; + set => this.klass = (UInt16)value; + } + } + } + } } diff --git a/Swan/Net/Dns/DnsClient.ResourceRecords.cs b/Swan/Net/Dns/DnsClient.ResourceRecords.cs index 1ad6c38..10cf023 100644 --- a/Swan/Net/Dns/DnsClient.ResourceRecords.cs +++ b/Swan/Net/Dns/DnsClient.ResourceRecords.cs @@ -1,419 +1,344 @@ -namespace Swan.Net.Dns -{ - using Formatters; - using System; - using System.Collections.Generic; - using System.IO; - using System.Net; - using System.Runtime.InteropServices; - - /// - /// DnsClient public methods. - /// - internal partial class DnsClient - { - public abstract class DnsResourceRecordBase : IDnsResourceRecord - { - private readonly IDnsResourceRecord _record; - - protected DnsResourceRecordBase(IDnsResourceRecord record) - { - _record = record; - } - - public DnsDomain Name => _record.Name; - - public DnsRecordType Type => _record.Type; - - public DnsRecordClass Class => _record.Class; - - public TimeSpan TimeToLive => _record.TimeToLive; - - public int DataLength => _record.DataLength; - - public byte[] Data => _record.Data; - - public int Size => _record.Size; - - protected virtual string[] IncludedProperties - => new[] {nameof(Name), nameof(Type), nameof(Class), nameof(TimeToLive), nameof(DataLength)}; - - public byte[] ToArray() => _record.ToArray(); - - public override string ToString() - => Json.SerializeOnly(this, true, IncludedProperties); - } - - public class DnsResourceRecord : IDnsResourceRecord - { - public DnsResourceRecord( - DnsDomain domain, - byte[] data, - DnsRecordType type, - DnsRecordClass klass = DnsRecordClass.IN, - TimeSpan ttl = default) - { - Name = domain; - Type = type; - Class = klass; - TimeToLive = ttl; - Data = data; - } - - public DnsDomain Name { get; } - - public DnsRecordType Type { get; } - - public DnsRecordClass Class { get; } - - public TimeSpan TimeToLive { get; } - - public int DataLength => Data.Length; - - public byte[] Data { get; } - - public int Size => Name.Size + Tail.SIZE + Data.Length; - - public static DnsResourceRecord FromArray(byte[] message, int offset, out int endOffset) - { - var domain = DnsDomain.FromArray(message, offset, out offset); - var tail = message.ToStruct(offset, Tail.SIZE); - - var data = new byte[tail.DataLength]; - - offset += Tail.SIZE; - Array.Copy(message, offset, data, 0, data.Length); - - endOffset = offset + data.Length; - - return new DnsResourceRecord(domain, data, tail.Type, tail.Class, tail.TimeToLive); - } - - public byte[] ToArray() => - new MemoryStream(Size) - .Append(Name.ToArray()) - .Append(new Tail() - { - Type = Type, - Class = Class, - TimeToLive = TimeToLive, - DataLength = Data.Length, - }.ToBytes()) - .Append(Data) - .ToArray(); - - public override string ToString() - { - return Json.SerializeOnly( - this, - true, - nameof(Name), - nameof(Type), - nameof(Class), - nameof(TimeToLive), - nameof(DataLength)); - } - - [StructEndianness(Endianness.Big)] - [StructLayout(LayoutKind.Sequential, Pack = 2)] - private struct Tail - { - public const int SIZE = 10; - - private ushort type; - private ushort klass; - private uint ttl; - private ushort dataLength; - - public DnsRecordType Type - { - get => (DnsRecordType) type; - set => type = (ushort) value; - } - - public DnsRecordClass Class - { - get => (DnsRecordClass) klass; - set => klass = (ushort) value; - } - - public TimeSpan TimeToLive - { - get => TimeSpan.FromSeconds(ttl); - set => ttl = (uint) value.TotalSeconds; - } - - public int DataLength - { - get => dataLength; - set => dataLength = (ushort) value; - } - } - } - - public class DnsPointerResourceRecord : DnsResourceRecordBase - { - public DnsPointerResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset) - : base(record) - { - PointerDomainName = DnsDomain.FromArray(message, dataOffset); - } - - public DnsDomain PointerDomainName { get; } - - protected override string[] IncludedProperties - { - get - { - var temp = new List(base.IncludedProperties) {nameof(PointerDomainName)}; - return temp.ToArray(); - } - } - } - - public class DnsIPAddressResourceRecord : DnsResourceRecordBase - { - public DnsIPAddressResourceRecord(IDnsResourceRecord record) - : base(record) - { - IPAddress = new IPAddress(Data); - } - - public IPAddress IPAddress { get; } - - protected override string[] IncludedProperties - => new List(base.IncludedProperties) {nameof(IPAddress)}.ToArray(); - } - - public class DnsNameServerResourceRecord : DnsResourceRecordBase - { - public DnsNameServerResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset) - : base(record) - { - NSDomainName = DnsDomain.FromArray(message, dataOffset); - } - - public DnsDomain NSDomainName { get; } - - protected override string[] IncludedProperties - => new List(base.IncludedProperties) {nameof(NSDomainName)}.ToArray(); - } - - public class DnsCanonicalNameResourceRecord : DnsResourceRecordBase - { - public DnsCanonicalNameResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset) - : base(record) - { - CanonicalDomainName = DnsDomain.FromArray(message, dataOffset); - } - - public DnsDomain CanonicalDomainName { get; } - - protected override string[] IncludedProperties - => new List(base.IncludedProperties) {nameof(CanonicalDomainName)}.ToArray(); - } - - public class DnsMailExchangeResourceRecord : DnsResourceRecordBase - { - private const int PreferenceSize = 2; - - public DnsMailExchangeResourceRecord( - IDnsResourceRecord record, - byte[] message, - int dataOffset) - : base(record) - { - var preference = new byte[PreferenceSize]; - Array.Copy(message, dataOffset, preference, 0, preference.Length); - - if (BitConverter.IsLittleEndian) - { - Array.Reverse(preference); - } - - dataOffset += PreferenceSize; - - Preference = BitConverter.ToUInt16(preference, 0); - ExchangeDomainName = DnsDomain.FromArray(message, dataOffset); - } - - public int Preference { get; } - - public DnsDomain ExchangeDomainName { get; } - - protected override string[] IncludedProperties => new List(base.IncludedProperties) - { - nameof(Preference), - nameof(ExchangeDomainName), - }.ToArray(); - } - - public class DnsStartOfAuthorityResourceRecord : DnsResourceRecordBase - { - public DnsStartOfAuthorityResourceRecord(IDnsResourceRecord record, byte[] message, int dataOffset) - : base(record) - { - MasterDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset); - ResponsibleDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset); - - var tail = message.ToStruct(dataOffset, Options.SIZE); - - SerialNumber = tail.SerialNumber; - RefreshInterval = tail.RefreshInterval; - RetryInterval = tail.RetryInterval; - ExpireInterval = tail.ExpireInterval; - MinimumTimeToLive = tail.MinimumTimeToLive; - } - - public DnsStartOfAuthorityResourceRecord( - DnsDomain domain, - DnsDomain master, - DnsDomain responsible, - long serial, - TimeSpan refresh, - TimeSpan retry, - TimeSpan expire, - TimeSpan minTtl, - TimeSpan ttl = default) - : base(Create(domain, master, responsible, serial, refresh, retry, expire, minTtl, ttl)) - { - MasterDomainName = master; - ResponsibleDomainName = responsible; - - SerialNumber = serial; - RefreshInterval = refresh; - RetryInterval = retry; - ExpireInterval = expire; - MinimumTimeToLive = minTtl; - } - - public DnsDomain MasterDomainName { get; } - - public DnsDomain ResponsibleDomainName { get; } - - public long SerialNumber { get; } - - public TimeSpan RefreshInterval { get; } - - public TimeSpan RetryInterval { get; } - - public TimeSpan ExpireInterval { get; } - - public TimeSpan MinimumTimeToLive { get; } - - protected override string[] IncludedProperties => new List(base.IncludedProperties) - { - nameof(MasterDomainName), - nameof(ResponsibleDomainName), - nameof(SerialNumber), - }.ToArray(); - - private static IDnsResourceRecord Create( - DnsDomain domain, - DnsDomain master, - DnsDomain responsible, - long serial, - TimeSpan refresh, - TimeSpan retry, - TimeSpan expire, - TimeSpan minTtl, - TimeSpan ttl) - { - var data = new MemoryStream(Options.SIZE + master.Size + responsible.Size); - var tail = new Options - { - SerialNumber = serial, - RefreshInterval = refresh, - RetryInterval = retry, - ExpireInterval = expire, - MinimumTimeToLive = minTtl, - }; - - data.Append(master.ToArray()).Append(responsible.ToArray()).Append(tail.ToBytes()); - - return new DnsResourceRecord(domain, data.ToArray(), DnsRecordType.SOA, DnsRecordClass.IN, ttl); - } - - [StructEndianness(Endianness.Big)] - [StructLayout(LayoutKind.Sequential, Pack = 4)] - public struct Options - { - public const int SIZE = 20; - - private uint serialNumber; - private uint refreshInterval; - private uint retryInterval; - private uint expireInterval; - private uint ttl; - - public long SerialNumber - { - get => serialNumber; - set => serialNumber = (uint) value; - } - - public TimeSpan RefreshInterval - { - get => TimeSpan.FromSeconds(refreshInterval); - set => refreshInterval = (uint) value.TotalSeconds; - } - - public TimeSpan RetryInterval - { - get => TimeSpan.FromSeconds(retryInterval); - set => retryInterval = (uint) value.TotalSeconds; - } - - public TimeSpan ExpireInterval - { - get => TimeSpan.FromSeconds(expireInterval); - set => expireInterval = (uint) value.TotalSeconds; - } - - public TimeSpan MinimumTimeToLive - { - get => TimeSpan.FromSeconds(ttl); - set => ttl = (uint) value.TotalSeconds; - } - } - } - - private static class DnsResourceRecordFactory - { - public static IList GetAllFromArray( - byte[] message, - int offset, - int count, - out int endOffset) - { - var result = new List(count); - - for (var i = 0; i < count; i++) - { - result.Add(GetFromArray(message, offset, out offset)); - } - - endOffset = offset; - return result; - } - - private static IDnsResourceRecord GetFromArray(byte[] message, int offset, out int endOffset) - { - var record = DnsResourceRecord.FromArray(message, offset, out endOffset); - var dataOffset = endOffset - record.DataLength; - - return record.Type switch - { - DnsRecordType.A => (IDnsResourceRecord) new DnsIPAddressResourceRecord(record), - DnsRecordType.AAAA => new DnsIPAddressResourceRecord(record), - DnsRecordType.NS => new DnsNameServerResourceRecord(record, message, dataOffset), - DnsRecordType.CNAME => new DnsCanonicalNameResourceRecord(record, message, dataOffset), - DnsRecordType.SOA => new DnsStartOfAuthorityResourceRecord(record, message, dataOffset), - DnsRecordType.PTR => new DnsPointerResourceRecord(record, message, dataOffset), - DnsRecordType.MX => new DnsMailExchangeResourceRecord(record, message, dataOffset), - _ => record - }; - } - } - } +using Swan.Formatters; +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Runtime.InteropServices; + +namespace Swan.Net.Dns { + /// + /// DnsClient public methods. + /// + internal partial class DnsClient { + public abstract class DnsResourceRecordBase : IDnsResourceRecord { + private readonly IDnsResourceRecord _record; + + protected DnsResourceRecordBase(IDnsResourceRecord record) => this._record = record; + + public DnsDomain Name => this._record.Name; + + public DnsRecordType Type => this._record.Type; + + public DnsRecordClass Class => this._record.Class; + + public TimeSpan TimeToLive => this._record.TimeToLive; + + public Int32 DataLength => this._record.DataLength; + + public Byte[] Data => this._record.Data; + + public Int32 Size => this._record.Size; + + protected virtual String[] IncludedProperties => new[] { nameof(this.Name), nameof(this.Type), nameof(this.Class), nameof(this.TimeToLive), nameof(this.DataLength) }; + + public Byte[] ToArray() => this._record.ToArray(); + + public override String ToString() => Json.SerializeOnly(this, true, this.IncludedProperties); + } + + public class DnsResourceRecord : IDnsResourceRecord { + public DnsResourceRecord(DnsDomain domain, Byte[] data, DnsRecordType type, DnsRecordClass klass = DnsRecordClass.IN, TimeSpan ttl = default) { + this.Name = domain; + this.Type = type; + this.Class = klass; + this.TimeToLive = ttl; + this.Data = data; + } + + public DnsDomain Name { + get; + } + + public DnsRecordType Type { + get; + } + + public DnsRecordClass Class { + get; + } + + public TimeSpan TimeToLive { + get; + } + + public Int32 DataLength => this.Data.Length; + + public Byte[] Data { + get; + } + + public Int32 Size => this.Name.Size + Tail.SIZE + this.Data.Length; + + public static DnsResourceRecord FromArray(Byte[] message, Int32 offset, out Int32 endOffset) { + DnsDomain domain = DnsDomain.FromArray(message, offset, out offset); + Tail tail = message.ToStruct(offset, Tail.SIZE); + + Byte[] data = new Byte[tail.DataLength]; + + offset += Tail.SIZE; + Array.Copy(message, offset, data, 0, data.Length); + + endOffset = offset + data.Length; + + return new DnsResourceRecord(domain, data, tail.Type, tail.Class, tail.TimeToLive); + } + + public Byte[] ToArray() => new MemoryStream(this.Size).Append(this.Name.ToArray()).Append(new Tail() { Type = Type, Class = Class, TimeToLive = TimeToLive, DataLength = this.Data.Length, }.ToBytes()).Append(this.Data).ToArray(); + + public override String ToString() => Json.SerializeOnly(this, true, nameof(this.Name), nameof(this.Type), nameof(this.Class), nameof(this.TimeToLive), nameof(this.DataLength)); + + [StructEndianness(Endianness.Big)] + [StructLayout(LayoutKind.Sequential, Pack = 2)] + private struct Tail { + public const Int32 SIZE = 10; + + private UInt16 type; + private UInt16 klass; + private UInt32 ttl; + private UInt16 dataLength; + + public DnsRecordType Type { + get => (DnsRecordType)this.type; + set => this.type = (UInt16)value; + } + + public DnsRecordClass Class { + get => (DnsRecordClass)this.klass; + set => this.klass = (UInt16)value; + } + + public TimeSpan TimeToLive { + get => TimeSpan.FromSeconds(this.ttl); + set => this.ttl = (UInt32)value.TotalSeconds; + } + + public Int32 DataLength { + get => this.dataLength; + set => this.dataLength = (UInt16)value; + } + } + } + + public class DnsPointerResourceRecord : DnsResourceRecordBase { + public DnsPointerResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) => this.PointerDomainName = DnsDomain.FromArray(message, dataOffset); + + public DnsDomain PointerDomainName { + get; + } + + protected override String[] IncludedProperties { + get { + List temp = new List(base.IncludedProperties) { nameof(this.PointerDomainName) }; + return temp.ToArray(); + } + } + } + + public class DnsIPAddressResourceRecord : DnsResourceRecordBase { + public DnsIPAddressResourceRecord(IDnsResourceRecord record) : base(record) => this.IPAddress = new IPAddress(this.Data); + + public IPAddress IPAddress { + get; + } + + protected override String[] IncludedProperties => new List(base.IncludedProperties) { nameof(this.IPAddress) }.ToArray(); + } + + public class DnsNameServerResourceRecord : DnsResourceRecordBase { + public DnsNameServerResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) => this.NSDomainName = DnsDomain.FromArray(message, dataOffset); + + public DnsDomain NSDomainName { + get; + } + + protected override String[] IncludedProperties => new List(base.IncludedProperties) { nameof(this.NSDomainName) }.ToArray(); + } + + public class DnsCanonicalNameResourceRecord : DnsResourceRecordBase { + public DnsCanonicalNameResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) => this.CanonicalDomainName = DnsDomain.FromArray(message, dataOffset); + + public DnsDomain CanonicalDomainName { + get; + } + + protected override String[] IncludedProperties => new List(base.IncludedProperties) { nameof(this.CanonicalDomainName) }.ToArray(); + } + + public class DnsMailExchangeResourceRecord : DnsResourceRecordBase { + private const Int32 PreferenceSize = 2; + + public DnsMailExchangeResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) + : base(record) { + Byte[] preference = new Byte[PreferenceSize]; + Array.Copy(message, dataOffset, preference, 0, preference.Length); + + if(BitConverter.IsLittleEndian) { + Array.Reverse(preference); + } + + dataOffset += PreferenceSize; + + this.Preference = BitConverter.ToUInt16(preference, 0); + this.ExchangeDomainName = DnsDomain.FromArray(message, dataOffset); + } + + public Int32 Preference { + get; + } + + public DnsDomain ExchangeDomainName { + get; + } + + protected override String[] IncludedProperties => new List(base.IncludedProperties) + { + nameof(this.Preference), + nameof(this.ExchangeDomainName), + }.ToArray(); + } + + public class DnsStartOfAuthorityResourceRecord : DnsResourceRecordBase { + public DnsStartOfAuthorityResourceRecord(IDnsResourceRecord record, Byte[] message, Int32 dataOffset) : base(record) { + this.MasterDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset); + this.ResponsibleDomainName = DnsDomain.FromArray(message, dataOffset, out dataOffset); + + Options tail = message.ToStruct(dataOffset, Options.SIZE); + + this.SerialNumber = tail.SerialNumber; + this.RefreshInterval = tail.RefreshInterval; + this.RetryInterval = tail.RetryInterval; + this.ExpireInterval = tail.ExpireInterval; + this.MinimumTimeToLive = tail.MinimumTimeToLive; + } + + public DnsStartOfAuthorityResourceRecord(DnsDomain domain, DnsDomain master, DnsDomain responsible, Int64 serial, TimeSpan refresh, TimeSpan retry, TimeSpan expire, TimeSpan minTtl, TimeSpan ttl = default) + : base(Create(domain, master, responsible, serial, refresh, retry, expire, minTtl, ttl)) { + this.MasterDomainName = master; + this.ResponsibleDomainName = responsible; + + this.SerialNumber = serial; + this.RefreshInterval = refresh; + this.RetryInterval = retry; + this.ExpireInterval = expire; + this.MinimumTimeToLive = minTtl; + } + + public DnsDomain MasterDomainName { + get; + } + + public DnsDomain ResponsibleDomainName { + get; + } + + public Int64 SerialNumber { + get; + } + + public TimeSpan RefreshInterval { + get; + } + + public TimeSpan RetryInterval { + get; + } + + public TimeSpan ExpireInterval { + get; + } + + public TimeSpan MinimumTimeToLive { + get; + } + + protected override String[] IncludedProperties => new List(base.IncludedProperties) + { + nameof(this.MasterDomainName), + nameof(this.ResponsibleDomainName), + nameof(this.SerialNumber), + }.ToArray(); + + private static IDnsResourceRecord Create(DnsDomain domain, DnsDomain master, DnsDomain responsible, Int64 serial, TimeSpan refresh, TimeSpan retry, TimeSpan expire, TimeSpan minTtl, TimeSpan ttl) { + MemoryStream data = new MemoryStream(Options.SIZE + master.Size + responsible.Size); + Options tail = new Options { + SerialNumber = serial, + RefreshInterval = refresh, + RetryInterval = retry, + ExpireInterval = expire, + MinimumTimeToLive = minTtl, + }; + + _ = data.Append(master.ToArray()).Append(responsible.ToArray()).Append(tail.ToBytes()); + + return new DnsResourceRecord(domain, data.ToArray(), DnsRecordType.SOA, DnsRecordClass.IN, ttl); + } + + [StructEndianness(Endianness.Big)] + [StructLayout(LayoutKind.Sequential, Pack = 4)] + public struct Options { + public const Int32 SIZE = 20; + + private UInt32 serialNumber; + private UInt32 refreshInterval; + private UInt32 retryInterval; + private UInt32 expireInterval; + private UInt32 ttl; + + public Int64 SerialNumber { + get => this.serialNumber; + set => this.serialNumber = (UInt32)value; + } + + public TimeSpan RefreshInterval { + get => TimeSpan.FromSeconds(this.refreshInterval); + set => this.refreshInterval = (UInt32)value.TotalSeconds; + } + + public TimeSpan RetryInterval { + get => TimeSpan.FromSeconds(this.retryInterval); + set => this.retryInterval = (UInt32)value.TotalSeconds; + } + + public TimeSpan ExpireInterval { + get => TimeSpan.FromSeconds(this.expireInterval); + set => this.expireInterval = (UInt32)value.TotalSeconds; + } + + public TimeSpan MinimumTimeToLive { + get => TimeSpan.FromSeconds(this.ttl); + set => this.ttl = (UInt32)value.TotalSeconds; + } + } + } + + private static class DnsResourceRecordFactory { + public static IList GetAllFromArray(Byte[] message, Int32 offset, Int32 count, out Int32 endOffset) { + List result = new List(count); + + for(Int32 i = 0; i < count; i++) { + result.Add(GetFromArray(message, offset, out offset)); + } + + endOffset = offset; + return result; + } + + private static IDnsResourceRecord GetFromArray(Byte[] message, Int32 offset, out Int32 endOffset) { + DnsResourceRecord record = DnsResourceRecord.FromArray(message, offset, out endOffset); + Int32 dataOffset = endOffset - record.DataLength; + + return record.Type switch + { + DnsRecordType.A => (new DnsIPAddressResourceRecord(record)), + DnsRecordType.AAAA => new DnsIPAddressResourceRecord(record), + DnsRecordType.NS => new DnsNameServerResourceRecord(record, message, dataOffset), + DnsRecordType.CNAME => new DnsCanonicalNameResourceRecord(record, message, dataOffset), + DnsRecordType.SOA => new DnsStartOfAuthorityResourceRecord(record, message, dataOffset), + DnsRecordType.PTR => new DnsPointerResourceRecord(record, message, dataOffset), + DnsRecordType.MX => new DnsMailExchangeResourceRecord(record, message, dataOffset), + _ => record + }; + } + } + } } diff --git a/Swan/Net/Dns/DnsClient.Response.cs b/Swan/Net/Dns/DnsClient.Response.cs index 5dcdb38..5799ef8 100644 --- a/Swan/Net/Dns/DnsClient.Response.cs +++ b/Swan/Net/Dns/DnsClient.Response.cs @@ -1,215 +1,174 @@ -namespace Swan.Net.Dns -{ - using Formatters; - using System; - using System.Collections.Generic; - using System.Collections.ObjectModel; - using System.IO; - using System.Linq; - - /// - /// DnsClient Response inner class. - /// - internal partial class DnsClient - { - public class DnsClientResponse : IDnsResponse - { - private readonly DnsResponse _response; - private readonly byte[] _message; - - internal DnsClientResponse(DnsClientRequest request, DnsResponse response, byte[] message) - { - Request = request; - - _message = message; - _response = response; - } - - public DnsClientRequest Request { get; } - - public int Id - { - get { return _response.Id; } - set { } - } - - public IList AnswerRecords => _response.AnswerRecords; - - public IList AuthorityRecords => - new ReadOnlyCollection(_response.AuthorityRecords); - - public IList AdditionalRecords => - new ReadOnlyCollection(_response.AdditionalRecords); - - public bool IsRecursionAvailable - { - get { return _response.IsRecursionAvailable; } - set { } - } - - public bool IsAuthorativeServer - { - get { return _response.IsAuthorativeServer; } - set { } - } - - public bool IsTruncated - { - get { return _response.IsTruncated; } - set { } - } - - public DnsOperationCode OperationCode - { - get { return _response.OperationCode; } - set { } - } - - public DnsResponseCode ResponseCode - { - get { return _response.ResponseCode; } - set { } - } - - public IList Questions => new ReadOnlyCollection(_response.Questions); - - public int Size => _message.Length; - - public byte[] ToArray() => _message; - - public override string ToString() => _response.ToString(); - } - - public class DnsResponse : IDnsResponse - { - private DnsHeader _header; - - public DnsResponse( - DnsHeader header, - IList questions, - IList answers, - IList authority, - IList additional) - { - _header = header; - Questions = questions; - AnswerRecords = answers; - AuthorityRecords = authority; - AdditionalRecords = additional; - } - - public IList Questions { get; } - - public IList AnswerRecords { get; } - - public IList AuthorityRecords { get; } - - public IList AdditionalRecords { get; } - - public int Id - { - get => _header.Id; - set => _header.Id = value; - } - - public bool IsRecursionAvailable - { - get => _header.RecursionAvailable; - set => _header.RecursionAvailable = value; - } - - public bool IsAuthorativeServer - { - get => _header.AuthorativeServer; - set => _header.AuthorativeServer = value; - } - - public bool IsTruncated - { - get => _header.Truncated; - set => _header.Truncated = value; - } - - public DnsOperationCode OperationCode - { - get => _header.OperationCode; - set => _header.OperationCode = value; - } - - public DnsResponseCode ResponseCode - { - get => _header.ResponseCode; - set => _header.ResponseCode = value; - } - - public int Size - => _header.Size + - Questions.Sum(q => q.Size) + - AnswerRecords.Sum(a => a.Size) + - AuthorityRecords.Sum(a => a.Size) + - AdditionalRecords.Sum(a => a.Size); - - public static DnsResponse FromArray(byte[] message) - { - var header = DnsHeader.FromArray(message); - var offset = header.Size; - - if (!header.Response || header.QuestionCount == 0) - { - throw new ArgumentException("Invalid response message"); - } - - if (header.Truncated) - { - return new DnsResponse(header, - DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount), - new List(), - new List(), - new List()); - } - - return new DnsResponse(header, - DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount, out offset), - DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AnswerRecordCount, out offset), - DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AuthorityRecordCount, out offset), - DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AdditionalRecordCount, out offset)); - } - - public byte[] ToArray() - { - UpdateHeader(); - var result = new MemoryStream(Size); - - result - .Append(_header.ToArray()) - .Append(Questions.Select(q => q.ToArray())) - .Append(AnswerRecords.Select(a => a.ToArray())) - .Append(AuthorityRecords.Select(a => a.ToArray())) - .Append(AdditionalRecords.Select(a => a.ToArray())); - - return result.ToArray(); - } - - public override string ToString() - { - UpdateHeader(); - - return Json.SerializeOnly( - this, - true, - nameof(Questions), - nameof(AnswerRecords), - nameof(AuthorityRecords), - nameof(AdditionalRecords)); - } - - private void UpdateHeader() - { - _header.QuestionCount = Questions.Count; - _header.AnswerRecordCount = AnswerRecords.Count; - _header.AuthorityRecordCount = AuthorityRecords.Count; - _header.AdditionalRecordCount = AdditionalRecords.Count; - } - } - } +using Swan.Formatters; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.IO; +using System.Linq; + +namespace Swan.Net.Dns { + /// + /// DnsClient Response inner class. + /// + internal partial class DnsClient { + public class DnsClientResponse : IDnsResponse { + private readonly DnsResponse _response; + private readonly Byte[] _message; + + internal DnsClientResponse(DnsClientRequest request, DnsResponse response, Byte[] message) { + this.Request = request; + + this._message = message; + this._response = response; + } + + public DnsClientRequest Request { + get; + } + + public Int32 Id { + get => this._response.Id; + set { + } + } + + public IList AnswerRecords => this._response.AnswerRecords; + + public IList AuthorityRecords => new ReadOnlyCollection(this._response.AuthorityRecords); + + public IList AdditionalRecords => new ReadOnlyCollection(this._response.AdditionalRecords); + + public Boolean IsRecursionAvailable { + get => this._response.IsRecursionAvailable; + set { + } + } + + public Boolean IsAuthorativeServer { + get => this._response.IsAuthorativeServer; + set { + } + } + + public Boolean IsTruncated { + get => this._response.IsTruncated; + set { + } + } + + public DnsOperationCode OperationCode { + get => this._response.OperationCode; + set { + } + } + + public DnsResponseCode ResponseCode { + get => this._response.ResponseCode; + set { + } + } + + public IList Questions => new ReadOnlyCollection(this._response.Questions); + + public Int32 Size => this._message.Length; + + public Byte[] ToArray() => this._message; + + public override String ToString() => this._response.ToString(); + } + + public class DnsResponse : IDnsResponse { + private DnsHeader _header; + + public DnsResponse(DnsHeader header, IList questions, IList answers, IList authority, IList additional) { + this._header = header; + this.Questions = questions; + this.AnswerRecords = answers; + this.AuthorityRecords = authority; + this.AdditionalRecords = additional; + } + + public IList Questions { + get; + } + + public IList AnswerRecords { + get; + } + + public IList AuthorityRecords { + get; + } + + public IList AdditionalRecords { + get; + } + + public Int32 Id { + get => this._header.Id; + set => this._header.Id = value; + } + + public Boolean IsRecursionAvailable { + get => this._header.RecursionAvailable; + set => this._header.RecursionAvailable = value; + } + + public Boolean IsAuthorativeServer { + get => this._header.AuthorativeServer; + set => this._header.AuthorativeServer = value; + } + + public Boolean IsTruncated { + get => this._header.Truncated; + set => this._header.Truncated = value; + } + + public DnsOperationCode OperationCode { + get => this._header.OperationCode; + set => this._header.OperationCode = value; + } + + public DnsResponseCode ResponseCode { + get => this._header.ResponseCode; + set => this._header.ResponseCode = value; + } + + public Int32 Size => this._header.Size + this.Questions.Sum(q => q.Size) + this.AnswerRecords.Sum(a => a.Size) + this.AuthorityRecords.Sum(a => a.Size) + this.AdditionalRecords.Sum(a => a.Size); + + public static DnsResponse FromArray(Byte[] message) { + DnsHeader header = DnsHeader.FromArray(message); + Int32 offset = header.Size; + + if(!header.Response || header.QuestionCount == 0) { + throw new ArgumentException("Invalid response message"); + } + + return header.Truncated + ? new DnsResponse(header, DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount), new List(), new List(), new List()) + : new DnsResponse(header, DnsQuestion.GetAllFromArray(message, offset, header.QuestionCount, out offset), DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AnswerRecordCount, out offset), DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AuthorityRecordCount, out offset), DnsResourceRecordFactory.GetAllFromArray(message, offset, header.AdditionalRecordCount, out _)); + } + + public Byte[] ToArray() { + this.UpdateHeader(); + MemoryStream result = new MemoryStream(this.Size); + + _ = result.Append(this._header.ToArray()).Append(this.Questions.Select(q => q.ToArray())).Append(this.AnswerRecords.Select(a => a.ToArray())).Append(this.AuthorityRecords.Select(a => a.ToArray())).Append(this.AdditionalRecords.Select(a => a.ToArray())); + + return result.ToArray(); + } + + public override String ToString() { + this.UpdateHeader(); + + return Json.SerializeOnly(this, true, nameof(this.Questions), nameof(this.AnswerRecords), nameof(this.AuthorityRecords), nameof(this.AdditionalRecords)); + } + + private void UpdateHeader() { + this._header.QuestionCount = this.Questions.Count; + this._header.AnswerRecordCount = this.AnswerRecords.Count; + this._header.AuthorityRecordCount = this.AuthorityRecords.Count; + this._header.AdditionalRecordCount = this.AdditionalRecords.Count; + } + } + } } \ No newline at end of file diff --git a/Swan/Net/Dns/DnsClient.cs b/Swan/Net/Dns/DnsClient.cs index 8cb5aca..d030418 100644 --- a/Swan/Net/Dns/DnsClient.cs +++ b/Swan/Net/Dns/DnsClient.cs @@ -1,79 +1,65 @@ -namespace Swan.Net.Dns -{ - using System; - using System.Collections.Generic; - using System.Linq; - using System.Net; - using System.Threading.Tasks; - - /// - /// DnsClient public methods. - /// - internal partial class DnsClient - { - private readonly IPEndPoint _dns; - private readonly IDnsRequestResolver _resolver; - - public DnsClient(IPEndPoint dns, IDnsRequestResolver? resolver = null) - { - _dns = dns; - _resolver = resolver ?? new DnsUdpRequestResolver(new DnsTcpRequestResolver()); - } - - public DnsClient(IPAddress ip, int port = Network.DnsDefaultPort, IDnsRequestResolver? resolver = null) - : this(new IPEndPoint(ip, port), resolver) - { - } - - public DnsClientRequest Create(IDnsRequest? request = null) - => new DnsClientRequest(_dns, request, _resolver); - - public async Task> Lookup(string domain, DnsRecordType type = DnsRecordType.A) - { - if (string.IsNullOrWhiteSpace(domain)) - throw new ArgumentNullException(nameof(domain)); - - if (type != DnsRecordType.A && type != DnsRecordType.AAAA) - { - throw new ArgumentException("Invalid record type " + type); - } - - var response = await Resolve(domain, type).ConfigureAwait(false); - var ips = response.AnswerRecords - .Where(r => r.Type == type) - .Cast() - .Select(r => r.IPAddress) - .ToList(); - - return ips.Count == 0 ? throw new DnsQueryException(response, "No matching records") : ips; - } - - public async Task Reverse(IPAddress ip) - { - if (ip == null) - throw new ArgumentNullException(nameof(ip)); - - var response = await Resolve(DnsDomain.PointerName(ip), DnsRecordType.PTR); - var ptr = response.AnswerRecords.FirstOrDefault(r => r.Type == DnsRecordType.PTR); - - return ptr == null - ? throw new DnsQueryException(response, "No matching records") - : ((DnsPointerResourceRecord) ptr).PointerDomainName.ToString(); - } - - public Task Resolve(string domain, DnsRecordType type) => - Resolve(new DnsDomain(domain), type); - - public Task Resolve(DnsDomain domain, DnsRecordType type) - { - var request = Create(); - var question = new DnsQuestion(domain, type); - - request.Questions.Add(question); - request.OperationCode = DnsOperationCode.Query; - request.RecursionDesired = true; - - return request.Resolve(); - } - } +using System; +using System.Collections.Generic; +using System.Linq; +#nullable enable +using System.Net; +using System.Threading.Tasks; + +namespace Swan.Net.Dns { + /// + /// DnsClient public methods. + /// + internal partial class DnsClient { + private readonly IPEndPoint _dns; + private readonly IDnsRequestResolver _resolver; + + public DnsClient(IPEndPoint dns, IDnsRequestResolver? resolver = null) { + this._dns = dns; + this._resolver = resolver ?? new DnsUdpRequestResolver(new DnsTcpRequestResolver()); + } + + public DnsClient(IPAddress ip, Int32 port = Network.DnsDefaultPort, IDnsRequestResolver? resolver = null) : this(new IPEndPoint(ip, port), resolver) { + } + + public DnsClientRequest Create(IDnsRequest? request = null) => new DnsClientRequest(this._dns, request, this._resolver); + + public async Task> Lookup(String domain, DnsRecordType type = DnsRecordType.A) { + if(String.IsNullOrWhiteSpace(domain)) { + throw new ArgumentNullException(nameof(domain)); + } + + if(type != DnsRecordType.A && type != DnsRecordType.AAAA) { + throw new ArgumentException("Invalid record type " + type); + } + + DnsClientResponse response = await this.Resolve(domain, type).ConfigureAwait(false); + List ips = response.AnswerRecords.Where(r => r.Type == type).Cast().Select(r => r.IPAddress).ToList(); + + return ips.Count == 0 ? throw new DnsQueryException(response, "No matching records") : ips; + } + + public async Task Reverse(IPAddress ip) { + if(ip == null) { + throw new ArgumentNullException(nameof(ip)); + } + + DnsClientResponse response = await this.Resolve(DnsDomain.PointerName(ip), DnsRecordType.PTR); + IDnsResourceRecord ptr = response.AnswerRecords.FirstOrDefault(r => r.Type == DnsRecordType.PTR); + + return ptr == null ? throw new DnsQueryException(response, "No matching records") : ((DnsPointerResourceRecord)ptr).PointerDomainName.ToString(); + } + + public Task Resolve(String domain, DnsRecordType type) => this.Resolve(new DnsDomain(domain), type); + + public Task Resolve(DnsDomain domain, DnsRecordType type) { + DnsClientRequest request = this.Create(); + DnsQuestion question = new DnsQuestion(domain, type); + + request.Questions.Add(question); + request.OperationCode = DnsOperationCode.Query; + request.RecursionDesired = true; + + return request.Resolve(); + } + } } diff --git a/Swan/Net/Dns/DnsQueryException.cs b/Swan/Net/Dns/DnsQueryException.cs index f07eeae..c0f9e7f 100644 --- a/Swan/Net/Dns/DnsQueryException.cs +++ b/Swan/Net/Dns/DnsQueryException.cs @@ -1,37 +1,28 @@ -namespace Swan.Net.Dns -{ - using System; - - /// - /// An exception thrown when the DNS query fails. - /// - /// - [Serializable] - public class DnsQueryException : Exception - { - internal DnsQueryException(string message) - : base(message) - { - } - - internal DnsQueryException(string message, Exception e) - : base(message, e) - { - } - - internal DnsQueryException(DnsClient.IDnsResponse response) - : this(response, Format(response)) - { - } - - internal DnsQueryException(DnsClient.IDnsResponse response, string message) - : base(message) - { - Response = response; - } - - internal DnsClient.IDnsResponse? Response { get; } - - private static string Format(DnsClient.IDnsResponse response) => $"Invalid response received with code {response.ResponseCode}"; - } +#nullable enable +using System; + +namespace Swan.Net.Dns { + /// + /// An exception thrown when the DNS query fails. + /// + /// + [Serializable] + public class DnsQueryException : Exception { + internal DnsQueryException(String message) : base(message) { + } + + internal DnsQueryException(String message, Exception e) : base(message, e) { + } + + internal DnsQueryException(DnsClient.IDnsResponse response) : this(response, Format(response)) { + } + + internal DnsQueryException(DnsClient.IDnsResponse response, String message) : base(message) => this.Response = response; + + internal DnsClient.IDnsResponse? Response { + get; + } + + private static String Format(DnsClient.IDnsResponse response) => $"Invalid response received with code {response.ResponseCode}"; + } } diff --git a/Swan/Net/Dns/DnsQueryResult.cs b/Swan/Net/Dns/DnsQueryResult.cs index 31d164c..76b8fa7 100644 --- a/Swan/Net/Dns/DnsQueryResult.cs +++ b/Swan/Net/Dns/DnsQueryResult.cs @@ -1,123 +1,130 @@ -namespace Swan.Net.Dns -{ - using System.Collections.Generic; - +namespace Swan.Net.Dns { + using System.Collections.Generic; + + /// + /// Represents a response from a DNS server. + /// + public class DnsQueryResult { + private readonly List _mAnswerRecords = new List(); + private readonly List _mAdditionalRecords = new List(); + private readonly List _mAuthorityRecords = new List(); + /// - /// Represents a response from a DNS server. + /// Initializes a new instance of the class. /// - public class DnsQueryResult - { - private readonly List _mAnswerRecords = new List(); - private readonly List _mAdditionalRecords = new List(); - private readonly List _mAuthorityRecords = new List(); - - /// - /// Initializes a new instance of the class. - /// - /// The response. - internal DnsQueryResult(DnsClient.IDnsResponse response) - : this() - { - Id = response.Id; - IsAuthoritativeServer = response.IsAuthorativeServer; - IsRecursionAvailable = response.IsRecursionAvailable; - IsTruncated = response.IsTruncated; - OperationCode = response.OperationCode; - ResponseCode = response.ResponseCode; - - if (response.AnswerRecords != null) - { - foreach (var record in response.AnswerRecords) - AnswerRecords.Add(new DnsRecord(record)); - } - - if (response.AuthorityRecords != null) - { - foreach (var record in response.AuthorityRecords) - AuthorityRecords.Add(new DnsRecord(record)); - } - - if (response.AdditionalRecords != null) - { - foreach (var record in response.AdditionalRecords) - AdditionalRecords.Add(new DnsRecord(record)); - } - } - - private DnsQueryResult() - { - } - - /// - /// Gets the identifier. - /// - /// - /// The identifier. - /// - public int Id { get; } - - /// - /// Gets a value indicating whether this instance is authoritative server. - /// - /// - /// true if this instance is authoritative server; otherwise, false. - /// - public bool IsAuthoritativeServer { get; } - - /// - /// Gets a value indicating whether this instance is truncated. - /// - /// - /// true if this instance is truncated; otherwise, false. - /// - public bool IsTruncated { get; } - - /// - /// Gets a value indicating whether this instance is recursion available. - /// - /// - /// true if this instance is recursion available; otherwise, false. - /// - public bool IsRecursionAvailable { get; } - - /// - /// Gets the operation code. - /// - /// - /// The operation code. - /// - public DnsOperationCode OperationCode { get; } - - /// - /// Gets the response code. - /// - /// - /// The response code. - /// - public DnsResponseCode ResponseCode { get; } - - /// - /// Gets the answer records. - /// - /// - /// The answer records. - /// - public IList AnswerRecords => _mAnswerRecords; - - /// - /// Gets the additional records. - /// - /// - /// The additional records. - /// - public IList AdditionalRecords => _mAdditionalRecords; - - /// - /// Gets the authority records. - /// - /// - /// The authority records. - /// - public IList AuthorityRecords => _mAuthorityRecords; - } + /// The response. + internal DnsQueryResult(DnsClient.IDnsResponse response) : this() { + this.Id = response.Id; + this.IsAuthoritativeServer = response.IsAuthorativeServer; + this.IsRecursionAvailable = response.IsRecursionAvailable; + this.IsTruncated = response.IsTruncated; + this.OperationCode = response.OperationCode; + this.ResponseCode = response.ResponseCode; + + if(response.AnswerRecords != null) { + foreach(DnsClient.IDnsResourceRecord record in response.AnswerRecords) { + this.AnswerRecords.Add(new DnsRecord(record)); + } + } + + if(response.AuthorityRecords != null) { + foreach(DnsClient.IDnsResourceRecord record in response.AuthorityRecords) { + this.AuthorityRecords.Add(new DnsRecord(record)); + } + } + + if(response.AdditionalRecords != null) { + foreach(DnsClient.IDnsResourceRecord record in response.AdditionalRecords) { + this.AdditionalRecords.Add(new DnsRecord(record)); + } + } + } + + private DnsQueryResult() { + } + + /// + /// Gets the identifier. + /// + /// + /// The identifier. + /// + public System.Int32 Id { + get; + } + + /// + /// Gets a value indicating whether this instance is authoritative server. + /// + /// + /// true if this instance is authoritative server; otherwise, false. + /// + public System.Boolean IsAuthoritativeServer { + get; + } + + /// + /// Gets a value indicating whether this instance is truncated. + /// + /// + /// true if this instance is truncated; otherwise, false. + /// + public System.Boolean IsTruncated { + get; + } + + /// + /// Gets a value indicating whether this instance is recursion available. + /// + /// + /// true if this instance is recursion available; otherwise, false. + /// + public System.Boolean IsRecursionAvailable { + get; + } + + /// + /// Gets the operation code. + /// + /// + /// The operation code. + /// + public DnsOperationCode OperationCode { + get; + } + + /// + /// Gets the response code. + /// + /// + /// The response code. + /// + public DnsResponseCode ResponseCode { + get; + } + + /// + /// Gets the answer records. + /// + /// + /// The answer records. + /// + public IList AnswerRecords => this._mAnswerRecords; + + /// + /// Gets the additional records. + /// + /// + /// The additional records. + /// + public IList AdditionalRecords => this._mAdditionalRecords; + + /// + /// Gets the authority records. + /// + /// + /// The authority records. + /// + public IList AuthorityRecords => this._mAuthorityRecords; + } } diff --git a/Swan/Net/Dns/DnsRecord.cs b/Swan/Net/Dns/DnsRecord.cs index 78f19ac..699e75a 100644 --- a/Swan/Net/Dns/DnsRecord.cs +++ b/Swan/Net/Dns/DnsRecord.cs @@ -1,208 +1,239 @@ -namespace Swan.Net.Dns -{ - using System; - using System.Net; - using System.Text; - +using System; +using System.Net; +using System.Text; + +namespace Swan.Net.Dns { + /// + /// Represents a DNS record entry. + /// + public class DnsRecord { /// - /// Represents a DNS record entry. + /// Initializes a new instance of the class. /// - public class DnsRecord - { - /// - /// Initializes a new instance of the class. - /// - /// The record. - internal DnsRecord(DnsClient.IDnsResourceRecord record) - : this() - { - Name = record.Name.ToString(); - Type = record.Type; - Class = record.Class; - TimeToLive = record.TimeToLive; - Data = record.Data; - - // PTR - PointerDomainName = (record as DnsClient.DnsPointerResourceRecord)?.PointerDomainName?.ToString(); - - // A - IPAddress = (record as DnsClient.DnsIPAddressResourceRecord)?.IPAddress; - - // NS - NameServerDomainName = (record as DnsClient.DnsNameServerResourceRecord)?.NSDomainName?.ToString(); - - // CNAME - CanonicalDomainName = (record as DnsClient.DnsCanonicalNameResourceRecord)?.CanonicalDomainName.ToString(); - - // MX - MailExchangerDomainName = (record as DnsClient.DnsMailExchangeResourceRecord)?.ExchangeDomainName.ToString(); - MailExchangerPreference = (record as DnsClient.DnsMailExchangeResourceRecord)?.Preference; - - // SOA - SoaMasterDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MasterDomainName.ToString(); - SoaResponsibleDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ResponsibleDomainName.ToString(); - SoaSerialNumber = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.SerialNumber; - SoaRefreshInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RefreshInterval; - SoaRetryInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RetryInterval; - SoaExpireInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ExpireInterval; - SoaMinimumTimeToLive = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MinimumTimeToLive; - } - - private DnsRecord() - { - // placeholder - } - - /// - /// Gets the name. - /// - /// - /// The name. - /// - public string Name { get; } - - /// - /// Gets the type. - /// - /// - /// The type. - /// - public DnsRecordType Type { get; } - - /// - /// Gets the class. - /// - /// - /// The class. - /// - public DnsRecordClass Class { get; } - - /// - /// Gets the time to live. - /// - /// - /// The time to live. - /// - public TimeSpan TimeToLive { get; } - - /// - /// Gets the raw data of the record. - /// - /// - /// The data. - /// - public byte[] Data { get; } - - /// - /// Gets the data text bytes in ASCII encoding. - /// - /// - /// The data text. - /// - public string DataText => Data == null ? string.Empty : Encoding.ASCII.GetString(Data); - - /// - /// Gets the name of the pointer domain. - /// - /// - /// The name of the pointer domain. - /// - public string PointerDomainName { get; } - - /// - /// Gets the ip address. - /// - /// - /// The ip address. - /// - public IPAddress IPAddress { get; } - - /// - /// Gets the name of the name server domain. - /// - /// - /// The name of the name server domain. - /// - public string NameServerDomainName { get; } - - /// - /// Gets the name of the canonical domain. - /// - /// - /// The name of the canonical domain. - /// - public string CanonicalDomainName { get; } - - /// - /// Gets the mail exchanger preference. - /// - /// - /// The mail exchanger preference. - /// - public int? MailExchangerPreference { get; } - - /// - /// Gets the name of the mail exchanger domain. - /// - /// - /// The name of the mail exchanger domain. - /// - public string MailExchangerDomainName { get; } - - /// - /// Gets the name of the soa master domain. - /// - /// - /// The name of the soa master domain. - /// - public string SoaMasterDomainName { get; } - - /// - /// Gets the name of the soa responsible domain. - /// - /// - /// The name of the soa responsible domain. - /// - public string SoaResponsibleDomainName { get; } - - /// - /// Gets the soa serial number. - /// - /// - /// The soa serial number. - /// - public long? SoaSerialNumber { get; } - - /// - /// Gets the soa refresh interval. - /// - /// - /// The soa refresh interval. - /// - public TimeSpan? SoaRefreshInterval { get; } - - /// - /// Gets the soa retry interval. - /// - /// - /// The soa retry interval. - /// - public TimeSpan? SoaRetryInterval { get; } - - /// - /// Gets the soa expire interval. - /// - /// - /// The soa expire interval. - /// - public TimeSpan? SoaExpireInterval { get; } - - /// - /// Gets the soa minimum time to live. - /// - /// - /// The soa minimum time to live. - /// - public TimeSpan? SoaMinimumTimeToLive { get; } - } + /// The record. + internal DnsRecord(DnsClient.IDnsResourceRecord record) : this() { + this.Name = record.Name.ToString(); + this.Type = record.Type; + this.Class = record.Class; + this.TimeToLive = record.TimeToLive; + this.Data = record.Data; + + // PTR + this.PointerDomainName = (record as DnsClient.DnsPointerResourceRecord)?.PointerDomainName?.ToString(); + + // A + this.IPAddress = (record as DnsClient.DnsIPAddressResourceRecord)?.IPAddress; + + // NS + this.NameServerDomainName = (record as DnsClient.DnsNameServerResourceRecord)?.NSDomainName?.ToString(); + + // CNAME + this.CanonicalDomainName = (record as DnsClient.DnsCanonicalNameResourceRecord)?.CanonicalDomainName.ToString(); + + // MX + this.MailExchangerDomainName = (record as DnsClient.DnsMailExchangeResourceRecord)?.ExchangeDomainName.ToString(); + this.MailExchangerPreference = (record as DnsClient.DnsMailExchangeResourceRecord)?.Preference; + + // SOA + this.SoaMasterDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MasterDomainName.ToString(); + this.SoaResponsibleDomainName = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ResponsibleDomainName.ToString(); + this.SoaSerialNumber = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.SerialNumber; + this.SoaRefreshInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RefreshInterval; + this.SoaRetryInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.RetryInterval; + this.SoaExpireInterval = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.ExpireInterval; + this.SoaMinimumTimeToLive = (record as DnsClient.DnsStartOfAuthorityResourceRecord)?.MinimumTimeToLive; + } + + private DnsRecord() { + // placeholder + } + + /// + /// Gets the name. + /// + /// + /// The name. + /// + public String Name { + get; + } + + /// + /// Gets the type. + /// + /// + /// The type. + /// + public DnsRecordType Type { + get; + } + + /// + /// Gets the class. + /// + /// + /// The class. + /// + public DnsRecordClass Class { + get; + } + + /// + /// Gets the time to live. + /// + /// + /// The time to live. + /// + public TimeSpan TimeToLive { + get; + } + + /// + /// Gets the raw data of the record. + /// + /// + /// The data. + /// + public Byte[] Data { + get; + } + + /// + /// Gets the data text bytes in ASCII encoding. + /// + /// + /// The data text. + /// + public String DataText => this.Data == null ? String.Empty : Encoding.ASCII.GetString(this.Data); + + /// + /// Gets the name of the pointer domain. + /// + /// + /// The name of the pointer domain. + /// + public String PointerDomainName { + get; + } + + /// + /// Gets the ip address. + /// + /// + /// The ip address. + /// + public IPAddress IPAddress { + get; + } + + /// + /// Gets the name of the name server domain. + /// + /// + /// The name of the name server domain. + /// + public String NameServerDomainName { + get; + } + + /// + /// Gets the name of the canonical domain. + /// + /// + /// The name of the canonical domain. + /// + public String CanonicalDomainName { + get; + } + + /// + /// Gets the mail exchanger preference. + /// + /// + /// The mail exchanger preference. + /// + public Int32? MailExchangerPreference { + get; + } + + /// + /// Gets the name of the mail exchanger domain. + /// + /// + /// The name of the mail exchanger domain. + /// + public String MailExchangerDomainName { + get; + } + + /// + /// Gets the name of the soa master domain. + /// + /// + /// The name of the soa master domain. + /// + public String SoaMasterDomainName { + get; + } + + /// + /// Gets the name of the soa responsible domain. + /// + /// + /// The name of the soa responsible domain. + /// + public String SoaResponsibleDomainName { + get; + } + + /// + /// Gets the soa serial number. + /// + /// + /// The soa serial number. + /// + public Int64? SoaSerialNumber { + get; + } + + /// + /// Gets the soa refresh interval. + /// + /// + /// The soa refresh interval. + /// + public TimeSpan? SoaRefreshInterval { + get; + } + + /// + /// Gets the soa retry interval. + /// + /// + /// The soa retry interval. + /// + public TimeSpan? SoaRetryInterval { + get; + } + + /// + /// Gets the soa expire interval. + /// + /// + /// The soa expire interval. + /// + public TimeSpan? SoaExpireInterval { + get; + } + + /// + /// Gets the soa minimum time to live. + /// + /// + /// The soa minimum time to live. + /// + public TimeSpan? SoaMinimumTimeToLive { + get; + } + } } diff --git a/Swan/Net/Dns/Enums.Dns.cs b/Swan/Net/Dns/Enums.Dns.cs index 8891993..8b1f21c 100644 --- a/Swan/Net/Dns/Enums.Dns.cs +++ b/Swan/Net/Dns/Enums.Dns.cs @@ -1,172 +1,167 @@ // ReSharper disable InconsistentNaming -namespace Swan.Net.Dns -{ +namespace Swan.Net.Dns { + /// + /// Enumerates the different DNS record types. + /// + public enum DnsRecordType { /// - /// Enumerates the different DNS record types. - /// - public enum DnsRecordType - { - /// - /// A records - /// - A = 1, - - /// - /// Nameserver records - /// - NS = 2, - - /// - /// CNAME records - /// - CNAME = 5, - - /// - /// SOA records - /// - SOA = 6, - - /// - /// WKS records - /// - WKS = 11, - - /// - /// PTR records - /// - PTR = 12, - - /// - /// MX records - /// - MX = 15, - - /// - /// TXT records - /// - TXT = 16, - - /// - /// A records fot IPv6 - /// - AAAA = 28, - - /// - /// SRV records - /// - SRV = 33, - - /// - /// ANY records - /// - ANY = 255, - } - + /// A records + /// + A = 1, + /// - /// Enumerates the different DNS record classes. - /// - public enum DnsRecordClass - { - /// - /// IN records - /// - IN = 1, - - /// - /// ANY records - /// - ANY = 255, - } - + /// Nameserver records + /// + NS = 2, + /// - /// Enumerates the different DNS operation codes. - /// - public enum DnsOperationCode - { - /// - /// Query operation - /// - Query = 0, - - /// - /// IQuery operation - /// - IQuery, - - /// - /// Status operation - /// - Status, - - /// - /// Notify operation - /// - Notify = 4, - - /// - /// Update operation - /// - Update, - } - + /// CNAME records + /// + CNAME = 5, + /// - /// Enumerates the different DNS query response codes. - /// - public enum DnsResponseCode - { - /// - /// No error - /// - NoError = 0, - - /// - /// No error - /// - FormatError, - - /// - /// Format error - /// - ServerFailure, - - /// - /// Server failure error - /// - NameError, - - /// - /// Name error - /// - NotImplemented, - - /// - /// Not implemented error - /// - Refused, - - /// - /// Refused error - /// - YXDomain, - - /// - /// YXRR error - /// - YXRRSet, - - /// - /// NXRR Set error - /// - NXRRSet, - - /// - /// Not authorized error - /// - NotAuth, - - /// - /// Not zone error - /// - NotZone, - } + /// SOA records + /// + SOA = 6, + + /// + /// WKS records + /// + WKS = 11, + + /// + /// PTR records + /// + PTR = 12, + + /// + /// MX records + /// + MX = 15, + + /// + /// TXT records + /// + TXT = 16, + + /// + /// A records fot IPv6 + /// + AAAA = 28, + + /// + /// SRV records + /// + SRV = 33, + + /// + /// ANY records + /// + ANY = 255, + } + + /// + /// Enumerates the different DNS record classes. + /// + public enum DnsRecordClass { + /// + /// IN records + /// + IN = 1, + + /// + /// ANY records + /// + ANY = 255, + } + + /// + /// Enumerates the different DNS operation codes. + /// + public enum DnsOperationCode { + /// + /// Query operation + /// + Query = 0, + + /// + /// IQuery operation + /// + IQuery, + + /// + /// Status operation + /// + Status, + + /// + /// Notify operation + /// + Notify = 4, + + /// + /// Update operation + /// + Update, + } + + /// + /// Enumerates the different DNS query response codes. + /// + public enum DnsResponseCode { + /// + /// No error + /// + NoError = 0, + + /// + /// No error + /// + FormatError, + + /// + /// Format error + /// + ServerFailure, + + /// + /// Server failure error + /// + NameError, + + /// + /// Name error + /// + NotImplemented, + + /// + /// Not implemented error + /// + Refused, + + /// + /// Refused error + /// + YXDomain, + + /// + /// YXRR error + /// + YXRRSet, + + /// + /// NXRR Set error + /// + NXRRSet, + + /// + /// Not authorized error + /// + NotAuth, + + /// + /// Not zone error + /// + NotZone, + } } diff --git a/Swan/Net/Eventing.ConnectionListener.cs b/Swan/Net/Eventing.ConnectionListener.cs index 58c4910..a2d53c5 100644 --- a/Swan/Net/Eventing.ConnectionListener.cs +++ b/Swan/Net/Eventing.ConnectionListener.cs @@ -1,158 +1,157 @@ -namespace Swan.Net -{ - using System; - using System.Net; - using System.Net.Sockets; - +#nullable enable +using System; +using System.Net; +using System.Net.Sockets; + +namespace Swan.Net { + /// + /// The event arguments for when connections are accepted. + /// + /// + public class ConnectionAcceptedEventArgs : EventArgs { /// - /// The event arguments for when connections are accepted. + /// Initializes a new instance of the class. /// - /// - public class ConnectionAcceptedEventArgs : EventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The client. - /// client. - public ConnectionAcceptedEventArgs(TcpClient client) - { - Client = client ?? throw new ArgumentNullException(nameof(client)); - } - - /// - /// Gets the client. - /// - /// - /// The client. - /// - public TcpClient Client { get; } - } - + /// The client. + /// client. + public ConnectionAcceptedEventArgs(TcpClient client) => this.Client = client ?? throw new ArgumentNullException(nameof(client)); + /// - /// Occurs before a connection is accepted. Set the Cancel property to true to prevent the connection from being accepted. + /// Gets the client. /// - /// - public class ConnectionAcceptingEventArgs : ConnectionAcceptedEventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The client. - public ConnectionAcceptingEventArgs(TcpClient client) - : base(client) - { - } - - /// - /// Setting Cancel to true rejects the new TcpClient. - /// - /// - /// true if cancel; otherwise, false. - /// - public bool Cancel { get; set; } - } - + /// + /// The client. + /// + public TcpClient Client { + get; + } + } + + /// + /// Occurs before a connection is accepted. Set the Cancel property to true to prevent the connection from being accepted. + /// + /// + public class ConnectionAcceptingEventArgs : ConnectionAcceptedEventArgs { /// - /// Event arguments for when a server listener is started. + /// Initializes a new instance of the class. /// - /// - public class ConnectionListenerStartedEventArgs : EventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The listener end point. - /// listenerEndPoint. - public ConnectionListenerStartedEventArgs(IPEndPoint listenerEndPoint) - { - EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); - } - - /// - /// Gets the end point. - /// - /// - /// The end point. - /// - public IPEndPoint EndPoint { get; } - } - + /// The client. + public ConnectionAcceptingEventArgs(TcpClient client) : base(client) { + } + /// - /// Event arguments for when a server listener fails to start. + /// Setting Cancel to true rejects the new TcpClient. /// - /// - public class ConnectionListenerFailedEventArgs : EventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The listener end point. - /// The ex. - /// - /// listenerEndPoint - /// or - /// ex. - /// - public ConnectionListenerFailedEventArgs(IPEndPoint listenerEndPoint, Exception ex) - { - EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); - Error = ex ?? throw new ArgumentNullException(nameof(ex)); - } - - /// - /// Gets the end point. - /// - /// - /// The end point. - /// - public IPEndPoint EndPoint { get; } - - /// - /// Gets the error. - /// - /// - /// The error. - /// - public Exception Error { get; } - } - + /// + /// true if cancel; otherwise, false. + /// + public Boolean Cancel { + get; set; + } + } + + /// + /// Event arguments for when a server listener is started. + /// + /// + public class ConnectionListenerStartedEventArgs : EventArgs { /// - /// Event arguments for when a server listener stopped. + /// Initializes a new instance of the class. /// - /// - public class ConnectionListenerStoppedEventArgs : EventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The listener end point. - /// The ex. - /// - /// listenerEndPoint - /// or - /// ex. - /// - public ConnectionListenerStoppedEventArgs(IPEndPoint listenerEndPoint, Exception? ex = null) - { - EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); - Error = ex; - } - - /// - /// Gets the end point. - /// - /// - /// The end point. - /// - public IPEndPoint EndPoint { get; } - - /// - /// Gets the error. - /// - /// - /// The error. - /// - public Exception? Error { get; } - } + /// The listener end point. + /// listenerEndPoint. + public ConnectionListenerStartedEventArgs(IPEndPoint listenerEndPoint) => this.EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); + + /// + /// Gets the end point. + /// + /// + /// The end point. + /// + public IPEndPoint EndPoint { + get; + } + } + + /// + /// Event arguments for when a server listener fails to start. + /// + /// + public class ConnectionListenerFailedEventArgs : EventArgs { + /// + /// Initializes a new instance of the class. + /// + /// The listener end point. + /// The ex. + /// + /// listenerEndPoint + /// or + /// ex. + /// + public ConnectionListenerFailedEventArgs(IPEndPoint listenerEndPoint, Exception ex) { + this.EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); + this.Error = ex ?? throw new ArgumentNullException(nameof(ex)); + } + + /// + /// Gets the end point. + /// + /// + /// The end point. + /// + public IPEndPoint EndPoint { + get; + } + + /// + /// Gets the error. + /// + /// + /// The error. + /// + public Exception Error { + get; + } + } + + /// + /// Event arguments for when a server listener stopped. + /// + /// + public class ConnectionListenerStoppedEventArgs : EventArgs { + /// + /// Initializes a new instance of the class. + /// + /// The listener end point. + /// The ex. + /// + /// listenerEndPoint + /// or + /// ex. + /// + public ConnectionListenerStoppedEventArgs(IPEndPoint listenerEndPoint, Exception? ex = null) { + this.EndPoint = listenerEndPoint ?? throw new ArgumentNullException(nameof(listenerEndPoint)); + this.Error = ex; + } + + /// + /// Gets the end point. + /// + /// + /// The end point. + /// + public IPEndPoint EndPoint { + get; + } + + /// + /// Gets the error. + /// + /// + /// The error. + /// + public Exception? Error { + get; + } + } } diff --git a/Swan/Net/Eventing.cs b/Swan/Net/Eventing.cs index ab1fc15..ce8396c 100644 --- a/Swan/Net/Eventing.cs +++ b/Swan/Net/Eventing.cs @@ -1,84 +1,84 @@ -namespace Swan.Net -{ - using System; - using System.Text; - +using System; +using System.Text; + +namespace Swan.Net { + /// + /// The event arguments for connection failure events. + /// + /// + public class ConnectionFailureEventArgs : EventArgs { /// - /// The event arguments for connection failure events. + /// Initializes a new instance of the class. /// - /// - public class ConnectionFailureEventArgs : EventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The ex. - public ConnectionFailureEventArgs(Exception ex) - { - Error = ex; - } - - /// - /// Gets the error. - /// - /// - /// The error. - /// - public Exception Error { get; } - } - + /// The ex. + public ConnectionFailureEventArgs(Exception ex) => this.Error = ex; + /// - /// Event arguments for when data is received. + /// Gets the error. /// - /// - public class ConnectionDataReceivedEventArgs : EventArgs - { - /// - /// Initializes a new instance of the class. - /// - /// The buffer. - /// The trigger. - /// if set to true [more available]. - public ConnectionDataReceivedEventArgs(byte[] buffer, ConnectionDataReceivedTrigger trigger, bool moreAvailable) - { - Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer)); - Trigger = trigger; - HasMoreAvailable = moreAvailable; - } - - /// - /// Gets the buffer. - /// - /// - /// The buffer. - /// - public byte[] Buffer { get; } - - /// - /// Gets the cause as to why this event was thrown. - /// - /// - /// The trigger. - /// - public ConnectionDataReceivedTrigger Trigger { get; } - - /// - /// Gets a value indicating whether the receive buffer has more bytes available. - /// - /// - /// true if this instance has more available; otherwise, false. - /// - public bool HasMoreAvailable { get; } - - /// - /// Gets the string from buffer. - /// - /// The encoding. - /// - /// A that contains the results of decoding the specified sequence of bytes. - /// - /// encoding - public string GetStringFromBuffer(Encoding encoding) - => encoding?.GetString(Buffer).TrimEnd('\r', '\n') ?? throw new ArgumentNullException(nameof(encoding)); - } + /// + /// The error. + /// + public Exception Error { + get; + } + } + + /// + /// Event arguments for when data is received. + /// + /// + public class ConnectionDataReceivedEventArgs : EventArgs { + /// + /// Initializes a new instance of the class. + /// + /// The buffer. + /// The trigger. + /// if set to true [more available]. + public ConnectionDataReceivedEventArgs(Byte[] buffer, ConnectionDataReceivedTrigger trigger, Boolean moreAvailable) { + this.Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer)); + this.Trigger = trigger; + this.HasMoreAvailable = moreAvailable; + } + + /// + /// Gets the buffer. + /// + /// + /// The buffer. + /// + public Byte[] Buffer { + get; + } + + /// + /// Gets the cause as to why this event was thrown. + /// + /// + /// The trigger. + /// + public ConnectionDataReceivedTrigger Trigger { + get; + } + + /// + /// Gets a value indicating whether the receive buffer has more bytes available. + /// + /// + /// true if this instance has more available; otherwise, false. + /// + public Boolean HasMoreAvailable { + get; + } + + /// + /// Gets the string from buffer. + /// + /// The encoding. + /// + /// A that contains the results of decoding the specified sequence of bytes. + /// + /// encoding + public String GetStringFromBuffer(Encoding encoding) => encoding?.GetString(this.Buffer).TrimEnd('\r', '\n') ?? throw new ArgumentNullException(nameof(encoding)); + } } diff --git a/Swan/Net/JsonClient.cs b/Swan/Net/JsonClient.cs index 402ab32..0438973 100644 --- a/Swan/Net/JsonClient.cs +++ b/Swan/Net/JsonClient.cs @@ -1,418 +1,313 @@ -namespace Swan.Net -{ - using Formatters; - using System; - using System.Collections.Generic; - using System.Net.Http; - using System.Net.Http.Headers; - using System.Security; - using System.Text; - using System.Threading; - using System.Threading.Tasks; - +#nullable enable +using Swan.Formatters; +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Net { + /// + /// Represents a HttpClient with extended methods to use with JSON payloads + /// and bearer tokens authentication. + /// + public static class JsonClient { + private const String JsonMimeType = "application/json"; + private const String FormType = "application/x-www-form-urlencoded"; + + private static readonly HttpClient HttpClient = new HttpClient(); + /// - /// Represents a HttpClient with extended methods to use with JSON payloads - /// and bearer tokens authentication. + /// Post a object as JSON with optional authorization token. /// - public static class JsonClient - { - private const string JsonMimeType = "application/json"; - private const string FormType = "application/x-www-form-urlencoded"; - - private static readonly HttpClient HttpClient = new HttpClient(); - - /// - /// Post a object as JSON with optional authorization token. - /// - /// The type of response object. - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested type. - /// - public static async Task Post( - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken cancellationToken = default) - { - var jsonString = await PostString(requestUri, payload, authorization, cancellationToken) - .ConfigureAwait(false); - - return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; - } - - /// - /// Posts the specified URL. - /// - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result as a collection of key/value pairs. - /// - public static async Task?> Post( - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken cancellationToken = default) - { - var jsonString = await PostString(requestUri, payload, authorization, cancellationToken) - .ConfigureAwait(false); - - return string.IsNullOrWhiteSpace(jsonString) - ? default - : Json.Deserialize(jsonString) as IDictionary; - } - - /// - /// Posts the specified URL. - /// - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested string. - /// - /// url. - /// Error POST JSON. - public static Task PostString( - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken cancellationToken = default) - => SendAsync(HttpMethod.Post, requestUri, payload, authorization, cancellationToken); - - /// - /// Puts the specified URL. - /// - /// The type of response object. - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested type. - /// - public static async Task Put( - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken ct = default) - { - var jsonString = await PutString(requestUri, payload, authorization, ct) - .ConfigureAwait(false); - - return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; - } - - /// - /// Puts the specified URL. - /// - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested collection of key/value pairs. - /// - public static async Task?> Put( - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken cancellationToken = default) - { - var response = await Put(requestUri, payload, authorization, cancellationToken) - .ConfigureAwait(false); - - return response as IDictionary; - } - - /// - /// Puts as string. - /// - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested string. - /// - /// url. - /// Error PUT JSON. - public static Task PutString( - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken ct = default) => SendAsync(HttpMethod.Put, requestUri, payload, authorization, ct); - - /// - /// Gets as string. - /// - /// The request URI. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested string. - /// - /// url. - /// Error GET JSON. - public static Task GetString( - Uri requestUri, - string? authorization = null, - CancellationToken ct = default) - => GetString(requestUri, null, authorization, ct); - - /// - /// Gets the string. - /// - /// The URI. - /// The headers. - /// The authorization. - /// The ct. - /// - /// A task with a result of the requested string. - /// - public static async Task GetString( - Uri uri, - IDictionary>? headers, - string? authorization = null, - CancellationToken ct = default) - { - var response = await GetHttpContent(uri, ct, authorization, headers) - .ConfigureAwait(false); - - return await response.ReadAsStringAsync() - .ConfigureAwait(false); - } - - /// - /// Gets the specified URL and return the JSON data as object - /// with optional authorization token. - /// - /// The response type. - /// The request URI. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested type. - /// - public static async Task Get( - Uri requestUri, - string? authorization = null, - CancellationToken ct = default) - { - var jsonString = await GetString(requestUri, authorization, ct) - .ConfigureAwait(false); - - return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; - } - - /// - /// Gets the specified URL and return the JSON data as object - /// with optional authorization token. - /// - /// The response type. - /// The request URI. - /// The headers. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested type. - /// - public static async Task Get( - Uri requestUri, - IDictionary>? headers, - string? authorization = null, - CancellationToken ct = default) - { - var jsonString = await GetString(requestUri, headers, authorization, ct) - .ConfigureAwait(false); - - return !string.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; - } - - /// - /// Gets the binary. - /// - /// The request URI. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested byte array. - /// - /// url. - /// Error GET Binary. - public static async Task GetBinary( - Uri requestUri, - string? authorization = null, - CancellationToken ct = default) - { - var response = await GetHttpContent(requestUri, ct, authorization) - .ConfigureAwait(false); - - return await response.ReadAsByteArrayAsync() - .ConfigureAwait(false); - } - - /// - /// Authenticate against a web server using Bearer Token. - /// - /// The request URI. - /// The username. - /// The password. - /// The cancellation token. - /// - /// A task with a Dictionary with authentication data. - /// - /// url - /// or - /// username. - /// Error Authenticating. - public static async Task?> Authenticate( - Uri requestUri, - string username, - string password, - CancellationToken ct = default) - { - if (string.IsNullOrWhiteSpace(username)) - throw new ArgumentNullException(nameof(username)); - - // ignore empty password for now - var content = $"grant_type=password&username={username}&password={password}"; - using var requestContent = new StringContent(content, Encoding.UTF8, FormType); - var response = await HttpClient.PostAsync(requestUri, requestContent, ct).ConfigureAwait(false); - - if (!response.IsSuccessStatusCode) - throw new SecurityException($"Error Authenticating. Status code: {response.StatusCode}."); - - var jsonPayload = await response.Content.ReadAsStringAsync().ConfigureAwait(false); - - return Json.Deserialize(jsonPayload) as IDictionary; - } - - /// - /// Posts the file. - /// - /// The request URI. - /// The buffer. - /// Name of the file. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested string. - /// - public static Task PostFileString( - Uri requestUri, - byte[] buffer, - string fileName, - string? authorization = null, - CancellationToken ct = default) => - PostString(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct); - - /// - /// Posts the file. - /// - /// The response type. - /// The request URI. - /// The buffer. - /// Name of the file. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested string. - /// - public static Task PostFile( - Uri requestUri, - byte[] buffer, - string fileName, - string? authorization = null, - CancellationToken ct = default) => - Post(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct); - - /// - /// Sends the asynchronous request. - /// - /// The method. - /// The request URI. - /// The payload. - /// The authorization. - /// The cancellation token. - /// - /// A task with a result of the requested string. - /// - /// requestUri. - /// Error {method} JSON. - public static async Task SendAsync( - HttpMethod method, - Uri requestUri, - object payload, - string? authorization = null, - CancellationToken ct = default) - { - using var response = await GetResponse(requestUri, authorization, null, payload, method, ct).ConfigureAwait(false); - if (!response.IsSuccessStatusCode) - { - throw new JsonRequestException( - $"Error {method} JSON", - (int)response.StatusCode, - await response.Content.ReadAsStringAsync().ConfigureAwait(false)); - } - - return await response.Content.ReadAsStringAsync() - .ConfigureAwait(false); - } - - private static async Task GetHttpContent( - Uri uri, - CancellationToken ct, - string? authorization = null, - IDictionary>? headers = null) - { - var response = await GetResponse(uri, authorization, headers, ct: ct) - .ConfigureAwait(false); - - return response.IsSuccessStatusCode - ? response.Content - : throw new JsonRequestException("Error GET", (int)response.StatusCode); - } - - private static async Task GetResponse( - Uri uri, - string? authorization, - IDictionary>? headers, - object? payload = null, - HttpMethod? method = default, - CancellationToken ct = default) - { - if (uri == null) - throw new ArgumentNullException(nameof(uri)); - - using var requestMessage = new HttpRequestMessage(method ?? HttpMethod.Get, uri); - - if (!string.IsNullOrWhiteSpace(authorization)) - { - requestMessage.Headers.Authorization - = new AuthenticationHeaderValue("Bearer", authorization); - } - - if (headers != null) - { - foreach (var header in headers) - requestMessage.Headers.Add(header.Key, header.Value); - } - - if (payload != null && requestMessage.Method != HttpMethod.Get) - { - requestMessage.Content = new StringContent(Json.Serialize(payload), Encoding.UTF8, JsonMimeType); - } - - return await HttpClient.SendAsync(requestMessage, ct) - .ConfigureAwait(false); - } - } + /// The type of response object. + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested type. + /// + public static async Task Post(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) where T : notnull { + String jsonString = await PostString(requestUri, payload, authorization, cancellationToken).ConfigureAwait(false); + + return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; + } + + /// + /// Posts the specified URL. + /// + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result as a collection of key/value pairs. + /// + public static async Task?> Post(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) { + String jsonString = await PostString(requestUri, payload, authorization, cancellationToken).ConfigureAwait(false); + + return String.IsNullOrWhiteSpace(jsonString) ? default : Json.Deserialize(jsonString) as IDictionary; + } + + /// + /// Posts the specified URL. + /// + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested string. + /// + /// url. + /// Error POST JSON. + public static Task PostString(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) => SendAsync(HttpMethod.Post, requestUri, payload, authorization, cancellationToken); + + /// + /// Puts the specified URL. + /// + /// The type of response object. + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested type. + /// + public static async Task Put(Uri requestUri, Object payload, String? authorization = null, CancellationToken ct = default) where T : notnull { + String jsonString = await PutString(requestUri, payload, authorization, ct).ConfigureAwait(false); + + return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; + } + + /// + /// Puts the specified URL. + /// + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested collection of key/value pairs. + /// + public static async Task?> Put(Uri requestUri, Object payload, String? authorization = null, CancellationToken cancellationToken = default) { + Object response = await Put(requestUri, payload, authorization, cancellationToken).ConfigureAwait(false); + + return response as IDictionary; + } + + /// + /// Puts as string. + /// + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested string. + /// + /// url. + /// Error PUT JSON. + public static Task PutString(Uri requestUri, Object payload, String? authorization = null, CancellationToken ct = default) => SendAsync(HttpMethod.Put, requestUri, payload, authorization, ct); + + /// + /// Gets as string. + /// + /// The request URI. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested string. + /// + /// url. + /// Error GET JSON. + public static Task GetString(Uri requestUri, String? authorization = null, CancellationToken ct = default) => GetString(requestUri, null, authorization, ct); + + /// + /// Gets the string. + /// + /// The URI. + /// The headers. + /// The authorization. + /// The ct. + /// + /// A task with a result of the requested string. + /// + public static async Task GetString(Uri uri, IDictionary>? headers, String? authorization = null, CancellationToken ct = default) { + HttpContent response = await GetHttpContent(uri, ct, authorization, headers).ConfigureAwait(false); + + return await response.ReadAsStringAsync().ConfigureAwait(false); + } + + /// + /// Gets the specified URL and return the JSON data as object + /// with optional authorization token. + /// + /// The response type. + /// The request URI. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested type. + /// + public static async Task Get(Uri requestUri, String? authorization = null, CancellationToken ct = default) where T : notnull { + String jsonString = await GetString(requestUri, authorization, ct).ConfigureAwait(false); + + return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; + } + + /// + /// Gets the specified URL and return the JSON data as object + /// with optional authorization token. + /// + /// The response type. + /// The request URI. + /// The headers. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested type. + /// + public static async Task Get(Uri requestUri, IDictionary>? headers, String? authorization = null, CancellationToken ct = default) where T : notnull { + String jsonString = await GetString(requestUri, headers, authorization, ct).ConfigureAwait(false); + + return !String.IsNullOrEmpty(jsonString) ? Json.Deserialize(jsonString) : default; + } + + /// + /// Gets the binary. + /// + /// The request URI. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested byte array. + /// + /// url. + /// Error GET Binary. + public static async Task GetBinary(Uri requestUri, String? authorization = null, CancellationToken ct = default) { + HttpContent response = await GetHttpContent(requestUri, ct, authorization).ConfigureAwait(false); + + return await response.ReadAsByteArrayAsync().ConfigureAwait(false); + } + + /// + /// Authenticate against a web server using Bearer Token. + /// + /// The request URI. + /// The username. + /// The password. + /// The cancellation token. + /// + /// A task with a Dictionary with authentication data. + /// + /// url + /// or + /// username. + /// Error Authenticating. + public static async Task?> Authenticate(Uri requestUri, String username, String password, CancellationToken ct = default) { + if(String.IsNullOrWhiteSpace(username)) { + throw new ArgumentNullException(nameof(username)); + } + + // ignore empty password for now + String content = $"grant_type=password&username={username}&password={password}"; + using StringContent requestContent = new StringContent(content, Encoding.UTF8, FormType); + HttpResponseMessage response = await HttpClient.PostAsync(requestUri, requestContent, ct).ConfigureAwait(false); + + if(!response.IsSuccessStatusCode) { + throw new SecurityException($"Error Authenticating. Status code: {response.StatusCode}."); + } + + String jsonPayload = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + + return Json.Deserialize(jsonPayload) as IDictionary; + } + + /// + /// Posts the file. + /// + /// The request URI. + /// The buffer. + /// Name of the file. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested string. + /// + public static Task PostFileString(Uri requestUri, Byte[] buffer, String fileName, String? authorization = null, CancellationToken ct = default) => PostString(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct); + + /// + /// Posts the file. + /// + /// The response type. + /// The request URI. + /// The buffer. + /// Name of the file. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested string. + /// + public static Task PostFile(Uri requestUri, Byte[] buffer, String fileName, String? authorization = null, CancellationToken ct = default) where T : notnull => Post(requestUri, new { Filename = fileName, Data = buffer }, authorization, ct); + + /// + /// Sends the asynchronous request. + /// + /// The method. + /// The request URI. + /// The payload. + /// The authorization. + /// The cancellation token. + /// + /// A task with a result of the requested string. + /// + /// requestUri. + /// Error {method} JSON. + public static async Task SendAsync(HttpMethod method, Uri requestUri, Object payload, String? authorization = null, CancellationToken ct = default) { + using HttpResponseMessage response = await GetResponse(requestUri, authorization, null, payload, method, ct).ConfigureAwait(false); + if(!response.IsSuccessStatusCode) { + throw new JsonRequestException( + $"Error {method} JSON", + (Int32)response.StatusCode, + await response.Content.ReadAsStringAsync().ConfigureAwait(false)); + } + + return await response.Content.ReadAsStringAsync().ConfigureAwait(false); + } + + private static async Task GetHttpContent(Uri uri, CancellationToken ct, String? authorization = null, IDictionary>? headers = null) { + HttpResponseMessage response = await GetResponse(uri, authorization, headers, ct: ct).ConfigureAwait(false); + + return response.IsSuccessStatusCode ? response.Content : throw new JsonRequestException("Error GET", (Int32)response.StatusCode); + } + + private static async Task GetResponse(Uri uri, String? authorization, IDictionary>? headers, Object? payload = null, HttpMethod? method = default, CancellationToken ct = default) { + if(uri == null) { + throw new ArgumentNullException(nameof(uri)); + } + + using HttpRequestMessage requestMessage = new HttpRequestMessage(method ?? HttpMethod.Get, uri); + + if(!String.IsNullOrWhiteSpace(authorization)) { + requestMessage.Headers.Authorization = new AuthenticationHeaderValue("Bearer", authorization); + } + + if(headers != null) { + foreach(KeyValuePair> header in headers) { + requestMessage.Headers.Add(header.Key, header.Value); + } + } + + if(payload != null && requestMessage.Method != HttpMethod.Get) { + requestMessage.Content = new StringContent(Json.Serialize(payload), Encoding.UTF8, JsonMimeType); + } + + return await HttpClient.SendAsync(requestMessage, ct).ConfigureAwait(false); + } + } } diff --git a/Swan/Net/JsonRequestException.cs b/Swan/Net/JsonRequestException.cs index a2cd373..a925d1f 100644 --- a/Swan/Net/JsonRequestException.cs +++ b/Swan/Net/JsonRequestException.cs @@ -1,47 +1,44 @@ -namespace Swan.Net -{ - using System; - +using System; + +namespace Swan.Net { + /// + /// Represents errors that occurs requesting a JSON file through HTTP. + /// + /// + [Serializable] + public class JsonRequestException : Exception { /// - /// Represents errors that occurs requesting a JSON file through HTTP. + /// Initializes a new instance of the class. /// - /// - [Serializable] - public class JsonRequestException - : Exception - { - /// - /// Initializes a new instance of the class. - /// - /// The message. - /// The HTTP error code. - /// Content of the error. - public JsonRequestException(string message, int httpErrorCode = 500, string errorContent = null) - : base(message) - { - HttpErrorCode = httpErrorCode; - HttpErrorContent = errorContent; - } - - /// - /// Gets the HTTP error code. - /// - /// - /// The HTTP error code. - /// - public int HttpErrorCode { get; } - - /// - /// Gets the content of the HTTP error. - /// - /// - /// The content of the HTTP error. - /// - public string HttpErrorContent { get; } - - /// - public override string ToString() => string.IsNullOrEmpty(HttpErrorContent) - ? $"HTTP Response Status Code {HttpErrorCode} Error Message: {HttpErrorContent}" - : base.ToString(); - } + /// The message. + /// The HTTP error code. + /// Content of the error. + public JsonRequestException(String message, Int32 httpErrorCode = 500, String errorContent = null) : base(message) { + this.HttpErrorCode = httpErrorCode; + this.HttpErrorContent = errorContent; + } + + /// + /// Gets the HTTP error code. + /// + /// + /// The HTTP error code. + /// + public Int32 HttpErrorCode { + get; + } + + /// + /// Gets the content of the HTTP error. + /// + /// + /// The content of the HTTP error. + /// + public String HttpErrorContent { + get; + } + + /// + public override String ToString() => String.IsNullOrEmpty(this.HttpErrorContent) ? $"HTTP Response Status Code {this.HttpErrorCode} Error Message: {this.HttpErrorContent}" : base.ToString(); + } } diff --git a/Swan/Net/Network.cs b/Swan/Net/Network.cs index 613263b..e92ac24 100644 --- a/Swan/Net/Network.cs +++ b/Swan/Net/Network.cs @@ -1,328 +1,289 @@ -namespace Swan.Net -{ - using Net.Dns; - using System; - using System.Collections.Generic; - using System.Linq; - using System.Net; - using System.Net.Http; - using System.Net.NetworkInformation; - using System.Net.Sockets; - using System.Threading; - using System.Threading.Tasks; - +using Swan.Net.Dns; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.NetworkInformation; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Net { + /// + /// Provides miscellaneous network utilities such as a Public IP finder, + /// a DNS client to query DNS records of any kind, and an NTP client. + /// + public static class Network { /// - /// Provides miscellaneous network utilities such as a Public IP finder, - /// a DNS client to query DNS records of any kind, and an NTP client. + /// The DNS default port. /// - public static class Network - { - /// - /// The DNS default port. - /// - public const int DnsDefaultPort = 53; + public const Int32 DnsDefaultPort = 53; + + /// + /// The NTP default port. + /// + public const Int32 NtpDefaultPort = 123; + + /// + /// Gets the name of the host. + /// + /// + /// The name of the host. + /// + public static String HostName => IPGlobalProperties.GetIPGlobalProperties().HostName; + + /// + /// Gets the name of the network domain. + /// + /// + /// The name of the network domain. + /// + public static String DomainName => IPGlobalProperties.GetIPGlobalProperties().DomainName; + + #region IP Addresses and Adapters Information Methods + + /// + /// Gets the active IPv4 interfaces. + /// Only those interfaces with a valid unicast address and a valid gateway will be returned in the collection. + /// + /// + /// A collection of NetworkInterface/IPInterfaceProperties pairs + /// that represents the active IPv4 interfaces. + /// + public static Dictionary GetIPv4Interfaces() { + // zero conf ip address + IPAddress zeroConf = new IPAddress(0); + + NetworkInterface[] adapters = NetworkInterface.GetAllNetworkInterfaces().Where(network => network.OperationalStatus == OperationalStatus.Up && network.NetworkInterfaceType != NetworkInterfaceType.Unknown && network.NetworkInterfaceType != NetworkInterfaceType.Loopback).ToArray(); + + Dictionary result = new Dictionary(); + + foreach(NetworkInterface adapter in adapters) { + IPInterfaceProperties properties = adapter.GetIPProperties(); + if(properties == null || properties.GatewayAddresses.Count == 0 || properties.GatewayAddresses.All(gateway => Equals(gateway.Address, zeroConf)) || properties.UnicastAddresses.Count == 0 || properties.GatewayAddresses.All(address => Equals(address.Address, zeroConf)) || properties.UnicastAddresses.Any(a => a.Address.AddressFamily == AddressFamily.InterNetwork) == false) { + continue; + } + + result[adapter] = properties; + } + + return result; + } + + /// + /// Retrieves the local ip addresses. + /// + /// if set to true [include loopback]. + /// An array of local ip addresses. + public static IPAddress[] GetIPv4Addresses(Boolean includeLoopback = true) => GetIPv4Addresses(NetworkInterfaceType.Unknown, true, includeLoopback); + + /// + /// Retrieves the local ip addresses. + /// + /// Type of the interface. + /// if set to true [skip type filter]. + /// if set to true [include loopback]. + /// An array of local ip addresses. + public static IPAddress[] GetIPv4Addresses(NetworkInterfaceType interfaceType, Boolean skipTypeFilter = false, Boolean includeLoopback = false) { + List addressList = new List(); + NetworkInterface[] interfaces = NetworkInterface.GetAllNetworkInterfaces() + .Where(ni => (skipTypeFilter || ni.NetworkInterfaceType == interfaceType) && ni.OperationalStatus == OperationalStatus.Up).ToArray(); + + foreach(NetworkInterface networkInterface in interfaces) { + IPInterfaceProperties properties = networkInterface.GetIPProperties(); + + if(properties.GatewayAddresses.All(g => g.Address.AddressFamily != AddressFamily.InterNetwork)) { + continue; + } + + addressList.AddRange(properties.UnicastAddresses.Where(i => i.Address.AddressFamily == AddressFamily.InterNetwork).Select(i => i.Address)); + } + + if(includeLoopback || interfaceType == NetworkInterfaceType.Loopback) { + addressList.Add(IPAddress.Loopback); + } + + return addressList.ToArray(); + } + + /// + /// Gets the public IP address using ipify.org. + /// + /// The cancellation token. + /// A public IP address of the result produced by this Task. + public static async Task GetPublicIPAddressAsync(CancellationToken cancellationToken = default) { + using HttpClient client = new HttpClient(); + HttpResponseMessage response = await client.GetAsync("https://api.ipify.org", cancellationToken).ConfigureAwait(false); + return IPAddress.Parse(await response.Content.ReadAsStringAsync().ConfigureAwait(false)); + } + + /// + /// Gets the configured IPv4 DNS servers for the active network interfaces. + /// + /// + /// A collection of NetworkInterface/IPInterfaceProperties pairs + /// that represents the active IPv4 interfaces. + /// + public static IPAddress[] GetIPv4DnsServers() => GetIPv4Interfaces().Select(a => a.Value.DnsAddresses.Where(d => d.AddressFamily == AddressFamily.InterNetwork)).SelectMany(d => d).ToArray(); + + #endregion + + #region DNS and NTP Clients + + /// + /// Gets the DNS host entry (a list of IP addresses) for the domain name. + /// + /// The FQDN. + /// An array of local ip addresses of the result produced by this task. + public static Task GetDnsHostEntryAsync(String fqdn) { + IPAddress dnsServer = GetIPv4DnsServers().FirstOrDefault() ?? IPAddress.Parse("8.8.8.8"); + return GetDnsHostEntryAsync(fqdn, dnsServer, DnsDefaultPort); + } + + /// + /// Gets the DNS host entry (a list of IP addresses) for the domain name. + /// + /// The FQDN. + /// The DNS server. + /// The port. + /// + /// An array of local ip addresses of the result produced by this task. + /// + /// fqdn. + public static async Task GetDnsHostEntryAsync(String fqdn, IPAddress dnsServer, Int32 port) { + if(fqdn == null) { + throw new ArgumentNullException(nameof(fqdn)); + } + + if(fqdn.IndexOf(".", StringComparison.Ordinal) == -1) { + fqdn += "." + IPGlobalProperties.GetIPGlobalProperties().DomainName; + } + + while(true) { + if(!fqdn.EndsWith(".", StringComparison.OrdinalIgnoreCase)) { + break; + } + + fqdn = fqdn[0..^1]; + } + + DnsClient client = new DnsClient(dnsServer, port); + IList result = await client.Lookup(fqdn).ConfigureAwait(false); + return result.ToArray(); + } + + /// + /// Gets the reverse lookup FQDN of the given IP Address. + /// + /// The query. + /// The DNS server. + /// The port. + /// A that represents the current object. + public static Task GetDnsPointerEntryAsync(IPAddress query, IPAddress dnsServer, Int32 port) { + DnsClient client = new DnsClient(dnsServer, port); + return client.Reverse(query); + } + + /// + /// Gets the reverse lookup FQDN of the given IP Address. + /// + /// The query. + /// A that represents the current object. + public static Task GetDnsPointerEntryAsync(IPAddress query) { + DnsClient client = new DnsClient(GetIPv4DnsServers().FirstOrDefault()); + return client.Reverse(query); + } + + /// + /// Queries the DNS server for the specified record type. + /// + /// The query. + /// Type of the record. + /// The DNS server. + /// The port. + /// Queries the DNS server for the specified record type of the result produced by this Task. + public static async Task QueryDnsAsync(String query, DnsRecordType recordType, IPAddress dnsServer, Int32 port) { + if(query == null) { + throw new ArgumentNullException(nameof(query)); + } + + DnsClient client = new DnsClient(dnsServer, port); + DnsClient.DnsClientResponse response = await client.Resolve(query, recordType).ConfigureAwait(false); + return new DnsQueryResult(response); + } + + /// + /// Queries the DNS server for the specified record type. + /// + /// The query. + /// Type of the record. + /// Queries the DNS server for the specified record type of the result produced by this Task. + public static Task QueryDnsAsync(String query, DnsRecordType recordType) => QueryDnsAsync(query, recordType, GetIPv4DnsServers().FirstOrDefault(), DnsDefaultPort); + + /// + /// Gets the UTC time by querying from an NTP server. + /// + /// The NTP server address. + /// The port. + /// The UTC time by querying from an NTP server of the result produced by this Task. + public static async Task GetNetworkTimeUtcAsync(IPAddress ntpServerAddress, Int32 port = NtpDefaultPort) { + if(ntpServerAddress == null) { + throw new ArgumentNullException(nameof(ntpServerAddress)); + } + + // NTP message size - 16 bytes of the digest (RFC 2030) + Byte[] ntpData = new Byte[48]; + + // Setting the Leap Indicator, Version Number and Mode values + ntpData[0] = 0x1B; // LI = 0 (no warning), VN = 3 (IPv4 only), Mode = 3 (Client Mode) + + // The UDP port number assigned to NTP is 123 + IPEndPoint endPoint = new IPEndPoint(ntpServerAddress, port); + Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + - /// - /// The NTP default port. - /// - public const int NtpDefaultPort = 123; + await socket.ConnectAsync(endPoint).ConfigureAwait(false); - /// - /// Gets the name of the host. - /// - /// - /// The name of the host. - /// - public static string HostName => IPGlobalProperties.GetIPGlobalProperties().HostName; - - /// - /// Gets the name of the network domain. - /// - /// - /// The name of the network domain. - /// - public static string DomainName => IPGlobalProperties.GetIPGlobalProperties().DomainName; - - #region IP Addresses and Adapters Information Methods - - /// - /// Gets the active IPv4 interfaces. - /// Only those interfaces with a valid unicast address and a valid gateway will be returned in the collection. - /// - /// - /// A collection of NetworkInterface/IPInterfaceProperties pairs - /// that represents the active IPv4 interfaces. - /// - public static Dictionary GetIPv4Interfaces() - { - // zero conf ip address - var zeroConf = new IPAddress(0); - - var adapters = NetworkInterface.GetAllNetworkInterfaces() - .Where(network => - network.OperationalStatus == OperationalStatus.Up - && network.NetworkInterfaceType != NetworkInterfaceType.Unknown - && network.NetworkInterfaceType != NetworkInterfaceType.Loopback) - .ToArray(); - - var result = new Dictionary(); - - foreach (var adapter in adapters) - { - var properties = adapter.GetIPProperties(); - if (properties == null - || properties.GatewayAddresses.Count == 0 - || properties.GatewayAddresses.All(gateway => Equals(gateway.Address, zeroConf)) - || properties.UnicastAddresses.Count == 0 - || properties.GatewayAddresses.All(address => Equals(address.Address, zeroConf)) - || properties.UnicastAddresses.Any(a => a.Address.AddressFamily == AddressFamily.InterNetwork) == - false) - continue; - - result[adapter] = properties; - } - - return result; - } - - /// - /// Retrieves the local ip addresses. - /// - /// if set to true [include loopback]. - /// An array of local ip addresses. - public static IPAddress[] GetIPv4Addresses(bool includeLoopback = true) => - GetIPv4Addresses(NetworkInterfaceType.Unknown, true, includeLoopback); - - /// - /// Retrieves the local ip addresses. - /// - /// Type of the interface. - /// if set to true [skip type filter]. - /// if set to true [include loopback]. - /// An array of local ip addresses. - public static IPAddress[] GetIPv4Addresses( - NetworkInterfaceType interfaceType, - bool skipTypeFilter = false, - bool includeLoopback = false) - { - var addressList = new List(); - var interfaces = NetworkInterface.GetAllNetworkInterfaces() - .Where(ni => -#if NET461 - ni.IsReceiveOnly == false && -#endif - (skipTypeFilter || ni.NetworkInterfaceType == interfaceType) && - ni.OperationalStatus == OperationalStatus.Up) - .ToArray(); - - foreach (var networkInterface in interfaces) - { - var properties = networkInterface.GetIPProperties(); - - if (properties.GatewayAddresses.All(g => g.Address.AddressFamily != AddressFamily.InterNetwork)) - continue; - - addressList.AddRange(properties.UnicastAddresses - .Where(i => i.Address.AddressFamily == AddressFamily.InterNetwork) - .Select(i => i.Address)); - } - - if (includeLoopback || interfaceType == NetworkInterfaceType.Loopback) - addressList.Add(IPAddress.Loopback); - - return addressList.ToArray(); - } - - /// - /// Gets the public IP address using ipify.org. - /// - /// The cancellation token. - /// A public IP address of the result produced by this Task. - public static async Task GetPublicIPAddressAsync(CancellationToken cancellationToken = default) - { - using var client = new HttpClient(); - var response = await client.GetAsync("https://api.ipify.org", cancellationToken).ConfigureAwait(false); - return IPAddress.Parse(await response.Content.ReadAsStringAsync().ConfigureAwait(false)); - } - - /// - /// Gets the configured IPv4 DNS servers for the active network interfaces. - /// - /// - /// A collection of NetworkInterface/IPInterfaceProperties pairs - /// that represents the active IPv4 interfaces. - /// - public static IPAddress[] GetIPv4DnsServers() - => GetIPv4Interfaces() - .Select(a => a.Value.DnsAddresses.Where(d => d.AddressFamily == AddressFamily.InterNetwork)) - .SelectMany(d => d) - .ToArray(); - - #endregion - - #region DNS and NTP Clients - - /// - /// Gets the DNS host entry (a list of IP addresses) for the domain name. - /// - /// The FQDN. - /// An array of local ip addresses of the result produced by this task. - public static Task GetDnsHostEntryAsync(string fqdn) - { - var dnsServer = GetIPv4DnsServers().FirstOrDefault() ?? IPAddress.Parse("8.8.8.8"); - return GetDnsHostEntryAsync(fqdn, dnsServer, DnsDefaultPort); - } - - /// - /// Gets the DNS host entry (a list of IP addresses) for the domain name. - /// - /// The FQDN. - /// The DNS server. - /// The port. - /// - /// An array of local ip addresses of the result produced by this task. - /// - /// fqdn. - public static async Task GetDnsHostEntryAsync(string fqdn, IPAddress dnsServer, int port) - { - if (fqdn == null) - throw new ArgumentNullException(nameof(fqdn)); - - if (fqdn.IndexOf(".", StringComparison.Ordinal) == -1) - { - fqdn += "." + IPGlobalProperties.GetIPGlobalProperties().DomainName; - } - - while (true) - { - if (!fqdn.EndsWith(".", StringComparison.OrdinalIgnoreCase)) break; - - fqdn = fqdn.Substring(0, fqdn.Length - 1); - } - - var client = new DnsClient(dnsServer, port); - var result = await client.Lookup(fqdn).ConfigureAwait(false); - return result.ToArray(); - } - - /// - /// Gets the reverse lookup FQDN of the given IP Address. - /// - /// The query. - /// The DNS server. - /// The port. - /// A that represents the current object. - public static Task GetDnsPointerEntryAsync(IPAddress query, IPAddress dnsServer, int port) - { - var client = new DnsClient(dnsServer, port); - return client.Reverse(query); - } - - /// - /// Gets the reverse lookup FQDN of the given IP Address. - /// - /// The query. - /// A that represents the current object. - public static Task GetDnsPointerEntryAsync(IPAddress query) - { - var client = new DnsClient(GetIPv4DnsServers().FirstOrDefault()); - return client.Reverse(query); - } - - /// - /// Queries the DNS server for the specified record type. - /// - /// The query. - /// Type of the record. - /// The DNS server. - /// The port. - /// Queries the DNS server for the specified record type of the result produced by this Task. - public static async Task QueryDnsAsync(string query, DnsRecordType recordType, IPAddress dnsServer, int port) - { - if (query == null) - throw new ArgumentNullException(nameof(query)); - - var client = new DnsClient(dnsServer, port); - var response = await client.Resolve(query, recordType).ConfigureAwait(false); - return new DnsQueryResult(response); - } - - /// - /// Queries the DNS server for the specified record type. - /// - /// The query. - /// Type of the record. - /// Queries the DNS server for the specified record type of the result produced by this Task. - public static Task QueryDnsAsync(string query, DnsRecordType recordType) => QueryDnsAsync(query, recordType, GetIPv4DnsServers().FirstOrDefault(), DnsDefaultPort); - - /// - /// Gets the UTC time by querying from an NTP server. - /// - /// The NTP server address. - /// The port. - /// The UTC time by querying from an NTP server of the result produced by this Task. - public static async Task GetNetworkTimeUtcAsync(IPAddress ntpServerAddress, int port = NtpDefaultPort) - { - if (ntpServerAddress == null) - throw new ArgumentNullException(nameof(ntpServerAddress)); - - // NTP message size - 16 bytes of the digest (RFC 2030) - var ntpData = new byte[48]; - - // Setting the Leap Indicator, Version Number and Mode values - ntpData[0] = 0x1B; // LI = 0 (no warning), VN = 3 (IPv4 only), Mode = 3 (Client Mode) - - // The UDP port number assigned to NTP is 123 - var endPoint = new IPEndPoint(ntpServerAddress, port); - var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); - -#if !NET461 - await socket.ConnectAsync(endPoint).ConfigureAwait(false); -#else - socket.Connect(endPoint); -#endif - - socket.ReceiveTimeout = 3000; // Stops code hang if NTP is blocked - socket.Send(ntpData); - socket.Receive(ntpData); - socket.Dispose(); - - // Offset to get to the "Transmit Timestamp" field (time at which the reply - // departed the server for the client, in 64-bit timestamp format." - const byte serverReplyTime = 40; - - // Get the seconds part - ulong intPart = BitConverter.ToUInt32(ntpData, serverReplyTime); - - // Get the seconds fraction - ulong fractPart = BitConverter.ToUInt32(ntpData, serverReplyTime + 4); - - // Convert From big-endian to little-endian to match the platform - if (BitConverter.IsLittleEndian) - { - intPart = intPart.SwapEndianness(); - fractPart = intPart.SwapEndianness(); - } - - var milliseconds = (intPart * 1000) + ((fractPart * 1000) / 0x100000000L); - - // The time is given in UTC - return new DateTime(1900, 1, 1, 0, 0, 0, DateTimeKind.Utc).AddMilliseconds((long) milliseconds); - } - - /// - /// Gets the UTC time by querying from an NTP server. - /// - /// The NTP server, by default pool.ntp.org. - /// The port, by default NTP 123. - /// The UTC time by querying from an NTP server of the result produced by this Task. - public static async Task GetNetworkTimeUtcAsync(string ntpServerName = "pool.ntp.org", - int port = NtpDefaultPort) - { - var addresses = await GetDnsHostEntryAsync(ntpServerName).ConfigureAwait(false); - return await GetNetworkTimeUtcAsync(addresses.First(), port).ConfigureAwait(false); - } - - #endregion - } + + socket.ReceiveTimeout = 3000; // Stops code hang if NTP is blocked + _ = socket.Send(ntpData); + _ = socket.Receive(ntpData); + socket.Dispose(); + + // Offset to get to the "Transmit Timestamp" field (time at which the reply + // departed the server for the client, in 64-bit timestamp format." + const Byte serverReplyTime = 40; + + // Get the seconds part + UInt64 intPart = BitConverter.ToUInt32(ntpData, serverReplyTime); + + // Get the seconds fraction + UInt64 fractPart = BitConverter.ToUInt32(ntpData, serverReplyTime + 4); + + // Convert From big-endian to little-endian to match the platform + if(BitConverter.IsLittleEndian) { + intPart = intPart.SwapEndianness(); + fractPart = intPart.SwapEndianness(); + } + + UInt64 milliseconds = intPart * 1000 + fractPart * 1000 / 0x100000000L; + + // The time is given in UTC + return new DateTime(1900, 1, 1, 0, 0, 0, DateTimeKind.Utc).AddMilliseconds((Int64)milliseconds); + } + + /// + /// Gets the UTC time by querying from an NTP server. + /// + /// The NTP server, by default pool.ntp.org. + /// The port, by default NTP 123. + /// The UTC time by querying from an NTP server of the result produced by this Task. + public static async Task GetNetworkTimeUtcAsync(String ntpServerName = "pool.ntp.org", Int32 port = NtpDefaultPort) { + IPAddress[] addresses = await GetDnsHostEntryAsync(ntpServerName).ConfigureAwait(false); + return await GetNetworkTimeUtcAsync(addresses.First(), port).ConfigureAwait(false); + } + + #endregion + } } diff --git a/Swan/Net/Smtp/Enums.Smtp.cs b/Swan/Net/Smtp/Enums.Smtp.cs index 069bbff..d2c11e7 100644 --- a/Swan/Net/Smtp/Enums.Smtp.cs +++ b/Swan/Net/Smtp/Enums.Smtp.cs @@ -1,166 +1,162 @@ // ReSharper disable InconsistentNaming -namespace Swan.Net.Smtp -{ +namespace Swan.Net.Smtp { + /// + /// Enumerates all of the well-known SMTP command names. + /// + public enum SmtpCommandNames { /// - /// Enumerates all of the well-known SMTP command names. + /// An unknown command /// - public enum SmtpCommandNames - { - /// - /// An unknown command - /// - Unknown, - - /// - /// The helo command - /// - HELO, - - /// - /// The ehlo command - /// - EHLO, - - /// - /// The quit command - /// - QUIT, - - /// - /// The help command - /// - HELP, - - /// - /// The noop command - /// - NOOP, - - /// - /// The rset command - /// - RSET, - - /// - /// The mail command - /// - MAIL, - - /// - /// The data command - /// - DATA, - - /// - /// The send command - /// - SEND, - - /// - /// The soml command - /// - SOML, - - /// - /// The saml command - /// - SAML, - - /// - /// The RCPT command - /// - RCPT, - - /// - /// The vrfy command - /// - VRFY, - - /// - /// The expn command - /// - EXPN, - - /// - /// The starttls command - /// - STARTTLS, - - /// - /// The authentication command - /// - AUTH, - } - + Unknown, + /// - /// Enumerates the reply code severities. + /// The helo command /// - public enum SmtpReplyCodeSeverities - { - /// - /// The unknown severity - /// - Unknown = 0, - - /// - /// The positive completion severity - /// - PositiveCompletion = 200, - - /// - /// The positive intermediate severity - /// - PositiveIntermediate = 300, - - /// - /// The transient negative severity - /// - TransientNegative = 400, - - /// - /// The permanent negative severity - /// - PermanentNegative = 500, - } - + HELO, + /// - /// Enumerates the reply code categories. + /// The ehlo command /// - public enum SmtpReplyCodeCategories - { - /// - /// The unknown category - /// - Unknown = -1, - - /// - /// The syntax category - /// - Syntax = 0, - - /// - /// The information category - /// - Information = 1, - - /// - /// The connections category - /// - Connections = 2, - - /// - /// The unspecified a category - /// - UnspecifiedA = 3, - - /// - /// The unspecified b category - /// - UnspecifiedB = 4, - - /// - /// The system category - /// - System = 5, - } + EHLO, + + /// + /// The quit command + /// + QUIT, + + /// + /// The help command + /// + HELP, + + /// + /// The noop command + /// + NOOP, + + /// + /// The rset command + /// + RSET, + + /// + /// The mail command + /// + MAIL, + + /// + /// The data command + /// + DATA, + + /// + /// The send command + /// + SEND, + + /// + /// The soml command + /// + SOML, + + /// + /// The saml command + /// + SAML, + + /// + /// The RCPT command + /// + RCPT, + + /// + /// The vrfy command + /// + VRFY, + + /// + /// The expn command + /// + EXPN, + + /// + /// The starttls command + /// + STARTTLS, + + /// + /// The authentication command + /// + AUTH, + } + + /// + /// Enumerates the reply code severities. + /// + public enum SmtpReplyCodeSeverities { + /// + /// The unknown severity + /// + Unknown = 0, + + /// + /// The positive completion severity + /// + PositiveCompletion = 200, + + /// + /// The positive intermediate severity + /// + PositiveIntermediate = 300, + + /// + /// The transient negative severity + /// + TransientNegative = 400, + + /// + /// The permanent negative severity + /// + PermanentNegative = 500, + } + + /// + /// Enumerates the reply code categories. + /// + public enum SmtpReplyCodeCategories { + /// + /// The unknown category + /// + Unknown = -1, + + /// + /// The syntax category + /// + Syntax = 0, + + /// + /// The information category + /// + Information = 1, + + /// + /// The connections category + /// + Connections = 2, + + /// + /// The unspecified a category + /// + UnspecifiedA = 3, + + /// + /// The unspecified b category + /// + UnspecifiedB = 4, + + /// + /// The system category + /// + System = 5, + } } \ No newline at end of file diff --git a/Swan/Net/Smtp/SmtpClient.cs b/Swan/Net/Smtp/SmtpClient.cs index bb0bfae..2684df4 100644 --- a/Swan/Net/Smtp/SmtpClient.cs +++ b/Swan/Net/Smtp/SmtpClient.cs @@ -1,388 +1,370 @@ -namespace Swan.Net.Smtp -{ - using System.Threading; - using System; - using System.Linq; - using System.Net; - using System.Net.Sockets; - using System.Security; - using System.Text; - using System.Net.Security; - using System.Threading.Tasks; - using System.Collections.Generic; - using System.Net.Mail; - +#nullable enable +using System.Threading; +using System; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Security; +using System.Text; +using System.Net.Security; +using System.Threading.Tasks; +using System.Collections.Generic; +using System.Net.Mail; + +namespace Swan.Net.Smtp { + /// + /// Represents a basic SMTP client that is capable of submitting messages to an SMTP server. + /// + /// + /// The following code explains how to send a simple e-mail. + /// + /// using System.Net.Mail; + /// + /// class Example + /// { + /// static void Main() + /// { + /// // create a new smtp client using google's smtp server + /// var client = new Swan.Net.Smtp.SmtpClient("smtp.gmail.com", 587); + /// + /// // send an email + /// client.SendMailAsync( + /// new MailMessage("sender@test.com", "recipient@test.cm", "Subject", "Body")); + /// } + /// } + /// + /// + /// The following code demonstrates how to sent an e-mail using a SmtpSessionState: + /// + /// using Swan.Net.Smtp; + /// + /// class Example + /// { + /// static void Main() + /// { + /// // create a new smtp client using google's smtp server + /// var client = new SmtpClient("smtp.gmail.com", 587); + /// + /// // create a new session state with a sender address + /// var session = new SmtpSessionState { SenderAddress = "sender@test.com" }; + /// + /// // add a recipient + /// session.Recipients.Add("recipient@test.cm"); + /// + /// // send + /// client.SendMailAsync(session); + /// } + /// } + /// + /// + /// The following code shows how to send an e-mail with an attachment using MimeKit: + /// + /// using MimeKit; + /// using Swan.Net.Smtp; + /// + /// class Example + /// { + /// static void Main() + /// { + /// // create a new smtp client using google's smtp server + /// var client = new SmtpClient("smtp.gmail.com", 587); + /// + /// // create a new session state with a sender address + /// var session = new SmtpSessionState { SenderAddress = "sender@test.com" }; + /// + /// // add a recipient + /// session.Recipients.Add("recipient@test.cm"); + /// + /// // load a file as an attachment + /// var attachment = new MimePart("image", "gif") + /// { + /// Content = new + /// MimeContent(File.OpenRead("meme.gif"), ContentEncoding.Default), + /// ContentDisposition = + /// new ContentDisposition(ContentDisposition.Attachment), + /// ContentTransferEncoding = ContentEncoding.Base64, + /// FileName = Path.GetFileName("meme.gif") + /// }; + /// + /// // send + /// client.SendMailAsync(session); + /// } + /// } + /// + /// + public class SmtpClient { /// - /// Represents a basic SMTP client that is capable of submitting messages to an SMTP server. + /// Initializes a new instance of the class. /// - /// - /// The following code explains how to send a simple e-mail. - /// - /// using System.Net.Mail; - /// - /// class Example - /// { - /// static void Main() - /// { - /// // create a new smtp client using google's smtp server - /// var client = new Swan.Net.Smtp.SmtpClient("smtp.gmail.com", 587); - /// - /// // send an email - /// client.SendMailAsync( - /// new MailMessage("sender@test.com", "recipient@test.cm", "Subject", "Body")); - /// } - /// } - /// - /// - /// The following code demonstrates how to sent an e-mail using a SmtpSessionState: - /// - /// using Swan.Net.Smtp; - /// - /// class Example - /// { - /// static void Main() - /// { - /// // create a new smtp client using google's smtp server - /// var client = new SmtpClient("smtp.gmail.com", 587); - /// - /// // create a new session state with a sender address - /// var session = new SmtpSessionState { SenderAddress = "sender@test.com" }; - /// - /// // add a recipient - /// session.Recipients.Add("recipient@test.cm"); - /// - /// // send - /// client.SendMailAsync(session); - /// } - /// } - /// - /// - /// The following code shows how to send an e-mail with an attachment using MimeKit: - /// - /// using MimeKit; - /// using Swan.Net.Smtp; - /// - /// class Example - /// { - /// static void Main() - /// { - /// // create a new smtp client using google's smtp server - /// var client = new SmtpClient("smtp.gmail.com", 587); - /// - /// // create a new session state with a sender address - /// var session = new SmtpSessionState { SenderAddress = "sender@test.com" }; - /// - /// // add a recipient - /// session.Recipients.Add("recipient@test.cm"); - /// - /// // load a file as an attachment - /// var attachment = new MimePart("image", "gif") - /// { - /// Content = new - /// MimeContent(File.OpenRead("meme.gif"), ContentEncoding.Default), - /// ContentDisposition = - /// new ContentDisposition(ContentDisposition.Attachment), - /// ContentTransferEncoding = ContentEncoding.Base64, - /// FileName = Path.GetFileName("meme.gif") - /// }; - /// - /// // send - /// client.SendMailAsync(session); - /// } - /// } - /// - /// - public class SmtpClient - { - /// - /// Initializes a new instance of the class. - /// - /// The host. - /// The port. - /// host. - public SmtpClient(string host, int port) - { - Host = host ?? throw new ArgumentNullException(nameof(host)); - Port = port; - ClientHostname = Network.HostName; - } - - /// - /// Gets or sets the credentials. No credentials will be used if set to null. - /// - /// - /// The credentials. - /// - public NetworkCredential Credentials { get; set; } - - /// - /// Gets the host. - /// - /// - /// The host. - /// - public string Host { get; } - - /// - /// Gets the port. - /// - /// - /// The port. - /// - public int Port { get; } - - /// - /// Gets or sets a value indicating whether the SSL is enabled. - /// If set to false, communication between client and server will not be secured. - /// - /// - /// true if [enable SSL]; otherwise, false. - /// - public bool EnableSsl { get; set; } - - /// - /// Gets or sets the name of the client that gets announced to the server. - /// - /// - /// The client hostname. - /// - public string ClientHostname { get; set; } - - /// - /// Sends an email message asynchronously. - /// - /// The message. - /// The session identifier. - /// The callback. - /// The cancellation token. - /// - /// A task that represents the asynchronous of send email operation. - /// - /// message. - public Task SendMailAsync( - MailMessage message, - string? sessionId = null, - RemoteCertificateValidationCallback? callback = null, - CancellationToken cancellationToken = default) - { - if (message == null) - throw new ArgumentNullException(nameof(message)); - - var state = new SmtpSessionState - { - AuthMode = Credentials == null ? string.Empty : SmtpDefinitions.SmtpAuthMethods.Login, - ClientHostname = ClientHostname, - IsChannelSecure = EnableSsl, - SenderAddress = message.From.Address, - }; - - if (Credentials != null) - { - state.Username = Credentials.UserName; - state.Password = Credentials.Password; - } - - foreach (var recipient in message.To) - { - state.Recipients.Add(recipient.Address); - } - - state.DataBuffer.AddRange(message.ToMimeMessage().ToArray()); - - return SendMailAsync(state, sessionId, callback, cancellationToken); - } - - /// - /// Sends an email message using a session state object. - /// Credentials, Enable SSL and Client Hostname are NOT taken from the state object but - /// rather from the properties of this class. - /// - /// The state. - /// The session identifier. - /// The callback. - /// The cancellation token. - /// - /// A task that represents the asynchronous of send email operation. - /// - /// sessionState. - public Task SendMailAsync( - SmtpSessionState sessionState, - string? sessionId = null, - RemoteCertificateValidationCallback? callback = null, - CancellationToken cancellationToken = default) - { - if (sessionState == null) - throw new ArgumentNullException(nameof(sessionState)); - - return SendMailAsync(new[] { sessionState }, sessionId, callback, cancellationToken); - } - - /// - /// Sends an array of email messages using a session state object. - /// Credentials, Enable SSL and Client Hostname are NOT taken from the state object but - /// rather from the properties of this class. - /// - /// The session states. - /// The session identifier. - /// The callback. - /// The cancellation token. - /// - /// A task that represents the asynchronous of send email operation. - /// - /// sessionStates. - /// Could not upgrade the channel to SSL. - /// Defines an SMTP Exceptions class. - public async Task SendMailAsync( - IEnumerable sessionStates, - string? sessionId = null, - RemoteCertificateValidationCallback? callback = null, - CancellationToken cancellationToken = default) - { - if (sessionStates == null) - throw new ArgumentNullException(nameof(sessionStates)); - - using var tcpClient = new TcpClient(); - await tcpClient.ConnectAsync(Host, Port).ConfigureAwait(false); - - using var connection = new Connection(tcpClient, Encoding.UTF8, "\r\n", true, 1000); - var sender = new SmtpSender(sessionId); - - try - { - // Read the greeting message - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - - // EHLO 1 - await SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false); - - // STARTTLS - if (EnableSsl) - { - sender.RequestText = $"{SmtpCommandNames.STARTTLS}"; - - await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - sender.ValidateReply(); - - if (await connection.UpgradeToSecureAsClientAsync(callback: callback).ConfigureAwait(false) == false) - throw new SecurityException("Could not upgrade the channel to SSL."); - } - - // EHLO 2 - await SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false); - - // AUTH - if (Credentials != null) - { - var auth = new ConnectionAuth(connection, sender, Credentials); - await auth.AuthenticateAsync(cancellationToken).ConfigureAwait(false); - } - - foreach (var sessionState in sessionStates) - { - { - // MAIL FROM - sender.RequestText = $"{SmtpCommandNames.MAIL} FROM:<{sessionState.SenderAddress}>"; - - await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - sender.ValidateReply(); - } - - // RCPT TO - foreach (var recipient in sessionState.Recipients) - { - sender.RequestText = $"{SmtpCommandNames.RCPT} TO:<{recipient}>"; - - await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - sender.ValidateReply(); - } - - { - // DATA - sender.RequestText = $"{SmtpCommandNames.DATA}"; - - await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - sender.ValidateReply(); - } - - { - // CONTENT - var dataTerminator = sessionState.DataBuffer - .Skip(sessionState.DataBuffer.Count - 5) - .ToText(); - - sender.RequestText = $"Buffer ({sessionState.DataBuffer.Count} bytes)"; - - await connection.WriteDataAsync(sessionState.DataBuffer.ToArray(), true, cancellationToken).ConfigureAwait(false); - - if (!dataTerminator.EndsWith(SmtpDefinitions.SmtpDataCommandTerminator)) - await connection.WriteTextAsync(SmtpDefinitions.SmtpDataCommandTerminator, cancellationToken).ConfigureAwait(false); - - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - sender.ValidateReply(); - } - } - - { - // QUIT - sender.RequestText = $"{SmtpCommandNames.QUIT}"; - - await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - sender.ValidateReply(); - } - } - catch (Exception ex) - { - throw new SmtpException($"Could not send email - Session ID {sessionId}. {ex.Message}\r\n Last Request: {sender.RequestText}\r\n Last Reply: {sender.ReplyText}"); - } - } - - private async Task SendEhlo(SmtpSender sender, Connection connection, CancellationToken cancellationToken) - { - sender.RequestText = $"{SmtpCommandNames.EHLO} {ClientHostname}"; - - await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); - - do - { - sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); - } - while (!sender.IsReplyOk); - - sender.ValidateReply(); - } - - private class ConnectionAuth - { - private readonly SmtpSender _sender; - private readonly Connection _connection; - private readonly NetworkCredential _credentials; - - public ConnectionAuth(Connection connection, SmtpSender sender, NetworkCredential credentials) - { - _connection = connection; - _sender = sender; - _credentials = credentials; - } - - public async Task AuthenticateAsync(CancellationToken ct) - { - _sender.RequestText = - $"{SmtpCommandNames.AUTH} {SmtpDefinitions.SmtpAuthMethods.Login} {Convert.ToBase64String(Encoding.UTF8.GetBytes(_credentials.UserName))}"; - - await _connection.WriteLineAsync(_sender.RequestText, ct).ConfigureAwait(false); - _sender.ReplyText = await _connection.ReadLineAsync(ct).ConfigureAwait(false); - _sender.ValidateReply(); - _sender.RequestText = Convert.ToBase64String(Encoding.UTF8.GetBytes(_credentials.Password)); - - await _connection.WriteLineAsync(_sender.RequestText, ct).ConfigureAwait(false); - _sender.ReplyText = await _connection.ReadLineAsync(ct).ConfigureAwait(false); - _sender.ValidateReply(); - } - } - } + /// The host. + /// The port. + /// host. + public SmtpClient(String host, Int32 port) { + this.Host = host ?? throw new ArgumentNullException(nameof(host)); + this.Port = port; + this.ClientHostname = Network.HostName; + } + + /// + /// Gets or sets the credentials. No credentials will be used if set to null. + /// + /// + /// The credentials. + /// + public NetworkCredential? Credentials { + get; set; + } + + /// + /// Gets the host. + /// + /// + /// The host. + /// + public String Host { + get; + } + + /// + /// Gets the port. + /// + /// + /// The port. + /// + public Int32 Port { + get; + } + + /// + /// Gets or sets a value indicating whether the SSL is enabled. + /// If set to false, communication between client and server will not be secured. + /// + /// + /// true if [enable SSL]; otherwise, false. + /// + public Boolean EnableSsl { + get; set; + } + + /// + /// Gets or sets the name of the client that gets announced to the server. + /// + /// + /// The client hostname. + /// + public String ClientHostname { + get; set; + } + + + /// + /// Sends an email message asynchronously. + /// + /// The message. + /// The session identifier. + /// The callback. + /// The cancellation token. + /// + /// A task that represents the asynchronous of send email operation. + /// + /// message. + [System.Diagnostics.CodeAnalysis.SuppressMessage("Codequalität", "IDE0067:Objekte verwerfen, bevor Bereich verloren geht", Justification = "")] + public Task SendMailAsync(MailMessage message, String? sessionId = null, RemoteCertificateValidationCallback? callback = null, CancellationToken cancellationToken = default) { + if(message == null) { + throw new ArgumentNullException(nameof(message)); + } + + SmtpSessionState state = new SmtpSessionState { + AuthMode = this.Credentials == null ? String.Empty : SmtpDefinitions.SmtpAuthMethods.Login, + ClientHostname = ClientHostname, + IsChannelSecure = EnableSsl, + SenderAddress = message.From.Address, + }; + + if(this.Credentials != null) { + state.Username = this.Credentials.UserName; + state.Password = this.Credentials.Password; + } + + foreach(MailAddress recipient in message.To) { + state.Recipients.Add(recipient.Address); + } + + state.DataBuffer.AddRange(message.ToMimeMessage().ToArray()); + + return this.SendMailAsync(state, sessionId, callback, cancellationToken); + } + + /// + /// Sends an email message using a session state object. + /// Credentials, Enable SSL and Client Hostname are NOT taken from the state object but + /// rather from the properties of this class. + /// + /// The state. + /// The session identifier. + /// The callback. + /// The cancellation token. + /// + /// A task that represents the asynchronous of send email operation. + /// + /// sessionState. + public Task SendMailAsync(SmtpSessionState sessionState, String? sessionId = null, RemoteCertificateValidationCallback? callback = null, CancellationToken cancellationToken = default) { + if(sessionState == null) { + throw new ArgumentNullException(nameof(sessionState)); + } + + return this.SendMailAsync(new[] { sessionState }, sessionId, callback, cancellationToken); + } + + /// + /// Sends an array of email messages using a session state object. + /// Credentials, Enable SSL and Client Hostname are NOT taken from the state object but + /// rather from the properties of this class. + /// + /// The session states. + /// The session identifier. + /// The callback. + /// The cancellation token. + /// + /// A task that represents the asynchronous of send email operation. + /// + /// sessionStates. + /// Could not upgrade the channel to SSL. + /// Defines an SMTP Exceptions class. + public async Task SendMailAsync(IEnumerable sessionStates, String? sessionId = null, RemoteCertificateValidationCallback? callback = null, CancellationToken cancellationToken = default) { + if(sessionStates == null) { + throw new ArgumentNullException(nameof(sessionStates)); + } + + using TcpClient tcpClient = new TcpClient(); + await tcpClient.ConnectAsync(this.Host, this.Port).ConfigureAwait(false); + + using Connection connection = new Connection(tcpClient, Encoding.UTF8, "\r\n", true, 1000); + SmtpSender sender = new SmtpSender(sessionId); + + try { + // Read the greeting message + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + + // EHLO 1 + await this.SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false); + + // STARTTLS + if(this.EnableSsl) { + sender.RequestText = $"{SmtpCommandNames.STARTTLS}"; + + await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + sender.ValidateReply(); + + if(await connection.UpgradeToSecureAsClientAsync(callback: callback).ConfigureAwait(false) == false) { + throw new SecurityException("Could not upgrade the channel to SSL."); + } + } + + // EHLO 2 + await this.SendEhlo(sender, connection, cancellationToken).ConfigureAwait(false); + + // AUTH + if(this.Credentials != null) { + ConnectionAuth auth = new ConnectionAuth(connection, sender, this.Credentials); + await auth.AuthenticateAsync(cancellationToken).ConfigureAwait(false); + } + + foreach(SmtpSessionState sessionState in sessionStates) { + { + // MAIL FROM + sender.RequestText = $"{SmtpCommandNames.MAIL} FROM:<{sessionState.SenderAddress}>"; + + await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + sender.ValidateReply(); + } + + // RCPT TO + foreach(String recipient in sessionState.Recipients) { + sender.RequestText = $"{SmtpCommandNames.RCPT} TO:<{recipient}>"; + + await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + sender.ValidateReply(); + } + + { + // DATA + sender.RequestText = $"{SmtpCommandNames.DATA}"; + + await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + sender.ValidateReply(); + } + + { + // CONTENT + String dataTerminator = sessionState.DataBuffer.Skip(sessionState.DataBuffer.Count - 5).ToText(); + + sender.RequestText = $"Buffer ({sessionState.DataBuffer.Count} bytes)"; + + await connection.WriteDataAsync(sessionState.DataBuffer.ToArray(), true, cancellationToken).ConfigureAwait(false); + + if(!dataTerminator.EndsWith(SmtpDefinitions.SmtpDataCommandTerminator)) { + await connection.WriteTextAsync(SmtpDefinitions.SmtpDataCommandTerminator, cancellationToken).ConfigureAwait(false); + } + + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + sender.ValidateReply(); + } + } + + { + // QUIT + sender.RequestText = $"{SmtpCommandNames.QUIT}"; + + await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + sender.ValidateReply(); + } + } catch(Exception ex) { + throw new SmtpException($"Could not send email - Session ID {sessionId}. {ex.Message}\r\n Last Request: {sender.RequestText}\r\n Last Reply: {sender.ReplyText}"); + } + } + + private async Task SendEhlo(SmtpSender sender, Connection connection, CancellationToken cancellationToken) { + sender.RequestText = $"{SmtpCommandNames.EHLO} {this.ClientHostname}"; + + await connection.WriteLineAsync(sender.RequestText, cancellationToken).ConfigureAwait(false); + + do { + sender.ReplyText = await connection.ReadLineAsync(cancellationToken).ConfigureAwait(false); + } + while(!sender.IsReplyOk); + + sender.ValidateReply(); + } + + private class ConnectionAuth { + private readonly SmtpSender _sender; + private readonly Connection _connection; + private readonly NetworkCredential _credentials; + + public ConnectionAuth(Connection connection, SmtpSender sender, NetworkCredential credentials) { + this._connection = connection; + this._sender = sender; + this._credentials = credentials; + } + + public async Task AuthenticateAsync(CancellationToken ct) { + this._sender.RequestText = $"{SmtpCommandNames.AUTH} {SmtpDefinitions.SmtpAuthMethods.Login} {Convert.ToBase64String(Encoding.UTF8.GetBytes(this._credentials.UserName))}"; + + await this._connection.WriteLineAsync(this._sender.RequestText, ct).ConfigureAwait(false); + this._sender.ReplyText = await this._connection.ReadLineAsync(ct).ConfigureAwait(false); + this._sender.ValidateReply(); + this._sender.RequestText = Convert.ToBase64String(Encoding.UTF8.GetBytes(this._credentials.Password)); + + await this._connection.WriteLineAsync(this._sender.RequestText, ct).ConfigureAwait(false); + this._sender.ReplyText = await this._connection.ReadLineAsync(ct).ConfigureAwait(false); + this._sender.ValidateReply(); + } + } + } } diff --git a/Swan/Net/Smtp/SmtpDefinitions.cs b/Swan/Net/Smtp/SmtpDefinitions.cs index 6b8fdad..b8a3717 100644 --- a/Swan/Net/Smtp/SmtpDefinitions.cs +++ b/Swan/Net/Smtp/SmtpDefinitions.cs @@ -1,29 +1,28 @@ -namespace Swan.Net.Smtp -{ +using System; + +namespace Swan.Net.Smtp { + /// + /// Contains useful constants and definitions. + /// + public static class SmtpDefinitions { /// - /// Contains useful constants and definitions. + /// The string sequence that delimits the end of the DATA command. /// - public static class SmtpDefinitions - { - /// - /// The string sequence that delimits the end of the DATA command. - /// - public const string SmtpDataCommandTerminator = "\r\n.\r\n"; - - /// - /// Lists the AUTH methods supported by default. - /// - public static class SmtpAuthMethods - { - /// - /// The plain method. - /// - public const string Plain = "PLAIN"; - - /// - /// The login method. - /// - public const string Login = "LOGIN"; - } - } + public const String SmtpDataCommandTerminator = "\r\n.\r\n"; + + /// + /// Lists the AUTH methods supported by default. + /// + public static class SmtpAuthMethods { + /// + /// The plain method. + /// + public const String Plain = "PLAIN"; + + /// + /// The login method. + /// + public const String Login = "LOGIN"; + } + } } diff --git a/Swan/Net/Smtp/SmtpSender.cs b/Swan/Net/Smtp/SmtpSender.cs index 61be247..5a3e819 100644 --- a/Swan/Net/Smtp/SmtpSender.cs +++ b/Swan/Net/Smtp/SmtpSender.cs @@ -1,60 +1,53 @@ -namespace Swan.Net.Smtp -{ - using Logging; - using System; - using System.Linq; - using System.Net.Mail; - - /// - /// Use this class to store the sender session data. - /// - internal class SmtpSender - { - private readonly string _sessionId; - private string _requestText; - - public SmtpSender(string sessionId) - { - _sessionId = sessionId; - } - - public string RequestText - { - get => _requestText; - set - { - _requestText = value; - $" TX {_requestText}".Trace(typeof(SmtpClient), _sessionId); - } - } - - public string ReplyText { get; set; } - - public bool IsReplyOk => ReplyText.StartsWith("250 ", StringComparison.OrdinalIgnoreCase); - - public void ValidateReply() - { - if (ReplyText == null) - throw new SmtpException("There was no response from the server"); - - try - { - var response = SmtpServerReply.Parse(ReplyText); - $" RX {ReplyText} - {response.IsPositive}".Trace(typeof(SmtpClient), _sessionId); - - if (response.IsPositive) return; - - var responseContent = response.Content.Any() - ? string.Join(";", response.Content.ToArray()) - : string.Empty; - - throw new SmtpException((SmtpStatusCode)response.ReplyCode, responseContent); - } - catch (Exception ex) - { - if (!(ex is SmtpException)) - throw new SmtpException($"Could not parse server response: {ReplyText}"); - } - } - } +using Swan.Logging; +using System; +using System.Linq; +using System.Net.Mail; + +namespace Swan.Net.Smtp { + /// + /// Use this class to store the sender session data. + /// + internal class SmtpSender { + private readonly String _sessionId; + private String _requestText; + + public SmtpSender(String sessionId) => this._sessionId = sessionId; + + public String RequestText { + get => this._requestText; + set { + this._requestText = value; + $" TX {this._requestText}".Trace(typeof(SmtpClient), this._sessionId); + } + } + + public String ReplyText { + get; set; + } + + public Boolean IsReplyOk => this.ReplyText.StartsWith("250 ", StringComparison.OrdinalIgnoreCase); + + public void ValidateReply() { + if(this.ReplyText == null) { + throw new SmtpException("There was no response from the server"); + } + + try { + SmtpServerReply response = SmtpServerReply.Parse(this.ReplyText); + $" RX {this.ReplyText} - {response.IsPositive}".Trace(typeof(SmtpClient), this._sessionId); + + if(response.IsPositive) { + return; + } + + String responseContent = response.Content.Any() ? String.Join(";", response.Content.ToArray()) : String.Empty; + + throw new SmtpException((SmtpStatusCode)response.ReplyCode, responseContent); + } catch(Exception ex) { + if(!(ex is SmtpException)) { + throw new SmtpException($"Could not parse server response: {this.ReplyText}"); + } + } + } + } } \ No newline at end of file diff --git a/Swan/Net/Smtp/SmtpServerReply.cs b/Swan/Net/Smtp/SmtpServerReply.cs index ae4b8c2..d08388d 100644 --- a/Swan/Net/Smtp/SmtpServerReply.cs +++ b/Swan/Net/Smtp/SmtpServerReply.cs @@ -1,243 +1,256 @@ -namespace Swan.Net.Smtp -{ - using System; - using System.Collections.Generic; - using System.Globalization; - using System.Linq; - using System.Text; - +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text; + +namespace Swan.Net.Smtp { + /// + /// Represents an SMTP server response object. + /// + public class SmtpServerReply { + #region Constructors + /// - /// Represents an SMTP server response object. + /// Initializes a new instance of the class. /// - public class SmtpServerReply - { - #region Constructors - - /// - /// Initializes a new instance of the class. - /// - /// The response code. - /// The status code. - /// The content. - public SmtpServerReply(int responseCode, string statusCode, params string[] content) - { - Content = new List(); - ReplyCode = responseCode; - EnhancedStatusCode = statusCode; - Content.AddRange(content); - IsValid = responseCode >= 200 && responseCode < 600; - ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown; - ReplyCodeCategory = SmtpReplyCodeCategories.Unknown; - - if (!IsValid) return; - if (responseCode >= 200) ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveCompletion; - if (responseCode >= 300) ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveIntermediate; - if (responseCode >= 400) ReplyCodeSeverity = SmtpReplyCodeSeverities.TransientNegative; - if (responseCode >= 500) ReplyCodeSeverity = SmtpReplyCodeSeverities.PermanentNegative; - if (responseCode >= 600) ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown; - - if (int.TryParse(responseCode.ToString(CultureInfo.InvariantCulture).Substring(1, 1), out var middleDigit)) - { - if (middleDigit >= 0 && middleDigit <= 5) - ReplyCodeCategory = (SmtpReplyCodeCategories) middleDigit; - } - } - - /// - /// Initializes a new instance of the class. - /// - public SmtpServerReply() - : this(0, string.Empty, string.Empty) - { - // placeholder - } - - /// - /// Initializes a new instance of the class. - /// - /// The response code. - /// The status code. - /// The content. - public SmtpServerReply(int responseCode, string statusCode, string content) - : this(responseCode, statusCode, new[] {content}) - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The response code. - /// The content. - public SmtpServerReply(int responseCode, string content) - : this(responseCode, string.Empty, content) - { - } - - #endregion - - #region Pre-built responses (https://tools.ietf.org/html/rfc5321#section-4.2.2) - - /// - /// Gets the command unrecognized reply. - /// - public static SmtpServerReply CommandUnrecognized => - new SmtpServerReply(500, "Syntax error, command unrecognized"); - - /// - /// Gets the syntax error arguments reply. - /// - public static SmtpServerReply SyntaxErrorArguments => - new SmtpServerReply(501, "Syntax error in parameters or arguments"); - - /// - /// Gets the command not implemented reply. - /// - public static SmtpServerReply CommandNotImplemented => new SmtpServerReply(502, "Command not implemented"); - - /// - /// Gets the bad sequence of commands reply. - /// - public static SmtpServerReply BadSequenceOfCommands => new SmtpServerReply(503, "Bad sequence of commands"); - - /// - /// Gets the protocol violation reply. - /// = - public static SmtpServerReply ProtocolViolation => - new SmtpServerReply(451, "Requested action aborted: error in processing"); - - /// - /// Gets the system status bye reply. - /// - public static SmtpServerReply SystemStatusBye => - new SmtpServerReply(221, "Service closing transmission channel"); - - /// - /// Gets the system status help reply. - /// = - public static SmtpServerReply SystemStatusHelp => new SmtpServerReply(221, "Refer to RFC 5321"); - - /// - /// Gets the bad syntax command empty reply. - /// - public static SmtpServerReply BadSyntaxCommandEmpty => new SmtpServerReply(400, "Error: bad syntax"); - - /// - /// Gets the OK reply. - /// - public static SmtpServerReply Ok => new SmtpServerReply(250, "OK"); - - /// - /// Gets the authorization required reply. - /// - public static SmtpServerReply AuthorizationRequired => new SmtpServerReply(530, "Authorization Required"); - - #endregion - - #region Properties - - /// - /// Gets the response severity. - /// - public SmtpReplyCodeSeverities ReplyCodeSeverity { get; } - - /// - /// Gets the response category. - /// - public SmtpReplyCodeCategories ReplyCodeCategory { get; } - - /// - /// Gets the numeric response code. - /// - public int ReplyCode { get; } - - /// - /// Gets the enhanced status code. - /// - public string EnhancedStatusCode { get; } - - /// - /// Gets the content. - /// - public List Content { get; } - - /// - /// Returns true if the response code is between 200 and 599. - /// - public bool IsValid { get; } - - /// - /// Gets a value indicating whether this instance is positive. - /// - public bool IsPositive => ReplyCode >= 200 && ReplyCode <= 399; - - #endregion - - #region Methods - - /// - /// Parses the specified text into a Server Reply for thorough analysis. - /// - /// The text. - /// A new instance of SMTP server response object. - public static SmtpServerReply Parse(string text) - { - var lines = text.Split(new[] {"\r\n"}, StringSplitOptions.RemoveEmptyEntries); - if (lines.Length == 0) return new SmtpServerReply(); - - var lastLineParts = lines.Last().Split(new[] {" "}, StringSplitOptions.RemoveEmptyEntries); - var enhancedStatusCode = string.Empty; - int.TryParse(lastLineParts[0], out var responseCode); - if (lastLineParts.Length > 1) - { - if (lastLineParts[1].Split('.').Length == 3) - enhancedStatusCode = lastLineParts[1]; - } - - var content = new List(); - - for (var i = 0; i < lines.Length; i++) - { - var splitChar = i == lines.Length - 1 ? " " : "-"; - - var lineParts = lines[i].Split(new[] {splitChar}, 2, StringSplitOptions.None); - var lineContent = lineParts.Last(); - if (string.IsNullOrWhiteSpace(enhancedStatusCode) == false) - lineContent = lineContent.Replace(enhancedStatusCode, string.Empty).Trim(); - - content.Add(lineContent); - } - - return new SmtpServerReply(responseCode, enhancedStatusCode, content.ToArray()); - } - - /// - /// Returns a that represents this instance. - /// - /// - /// A that represents this instance. - /// - public override string ToString() - { - var responseCodeText = ReplyCode.ToString(CultureInfo.InvariantCulture); - var statusCodeText = string.IsNullOrWhiteSpace(EnhancedStatusCode) - ? string.Empty - : $" {EnhancedStatusCode.Trim()}"; - if (Content.Count == 0) return $"{responseCodeText}{statusCodeText}"; - - var builder = new StringBuilder(); - - for (var i = 0; i < Content.Count; i++) - { - var isLastLine = i == Content.Count - 1; - - builder.Append(isLastLine - ? $"{responseCodeText}{statusCodeText} {Content[i]}" - : $"{responseCodeText}-{Content[i]}\r\n"); - } - - return builder.ToString(); - } - - #endregion - } + /// The response code. + /// The status code. + /// The content. + public SmtpServerReply(Int32 responseCode, String statusCode, params String[] content) { + this.Content = new List(); + this.ReplyCode = responseCode; + this.EnhancedStatusCode = statusCode; + this.Content.AddRange(content); + this.IsValid = responseCode >= 200 && responseCode < 600; + this.ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown; + this.ReplyCodeCategory = SmtpReplyCodeCategories.Unknown; + + if(!this.IsValid) { + return; + } + + if(responseCode >= 200) { + this.ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveCompletion; + } + + if(responseCode >= 300) { + this.ReplyCodeSeverity = SmtpReplyCodeSeverities.PositiveIntermediate; + } + + if(responseCode >= 400) { + this.ReplyCodeSeverity = SmtpReplyCodeSeverities.TransientNegative; + } + + if(responseCode >= 500) { + this.ReplyCodeSeverity = SmtpReplyCodeSeverities.PermanentNegative; + } + + if(responseCode >= 600) { + this.ReplyCodeSeverity = SmtpReplyCodeSeverities.Unknown; + } + + if(Int32.TryParse(responseCode.ToString(CultureInfo.InvariantCulture).Substring(1, 1), out Int32 middleDigit)) { + if(middleDigit >= 0 && middleDigit <= 5) { + this.ReplyCodeCategory = (SmtpReplyCodeCategories)middleDigit; + } + } + } + + /// + /// Initializes a new instance of the class. + /// + public SmtpServerReply() : this(0, String.Empty, String.Empty) { + // placeholder + } + + /// + /// Initializes a new instance of the class. + /// + /// The response code. + /// The status code. + /// The content. + public SmtpServerReply(Int32 responseCode, String statusCode, String content) : this(responseCode, statusCode, new[] { content }) { + } + + /// + /// Initializes a new instance of the class. + /// + /// The response code. + /// The content. + public SmtpServerReply(Int32 responseCode, String content) : this(responseCode, String.Empty, content) { + } + + #endregion + + #region Pre-built responses (https://tools.ietf.org/html/rfc5321#section-4.2.2) + + /// + /// Gets the command unrecognized reply. + /// + public static SmtpServerReply CommandUnrecognized => new SmtpServerReply(500, "Syntax error, command unrecognized"); + + /// + /// Gets the syntax error arguments reply. + /// + public static SmtpServerReply SyntaxErrorArguments => new SmtpServerReply(501, "Syntax error in parameters or arguments"); + + /// + /// Gets the command not implemented reply. + /// + public static SmtpServerReply CommandNotImplemented => new SmtpServerReply(502, "Command not implemented"); + + /// + /// Gets the bad sequence of commands reply. + /// + public static SmtpServerReply BadSequenceOfCommands => new SmtpServerReply(503, "Bad sequence of commands"); + + /// + /// Gets the protocol violation reply. + /// = + public static SmtpServerReply ProtocolViolation => new SmtpServerReply(451, "Requested action aborted: error in processing"); + + /// + /// Gets the system status bye reply. + /// + public static SmtpServerReply SystemStatusBye => new SmtpServerReply(221, "Service closing transmission channel"); + + /// + /// Gets the system status help reply. + /// = + public static SmtpServerReply SystemStatusHelp => new SmtpServerReply(221, "Refer to RFC 5321"); + + /// + /// Gets the bad syntax command empty reply. + /// + public static SmtpServerReply BadSyntaxCommandEmpty => new SmtpServerReply(400, "Error: bad syntax"); + + /// + /// Gets the OK reply. + /// + public static SmtpServerReply Ok => new SmtpServerReply(250, "OK"); + + /// + /// Gets the authorization required reply. + /// + public static SmtpServerReply AuthorizationRequired => new SmtpServerReply(530, "Authorization Required"); + + #endregion + + #region Properties + + /// + /// Gets the response severity. + /// + public SmtpReplyCodeSeverities ReplyCodeSeverity { + get; + } + + /// + /// Gets the response category. + /// + public SmtpReplyCodeCategories ReplyCodeCategory { + get; + } + + /// + /// Gets the numeric response code. + /// + public Int32 ReplyCode { + get; + } + + /// + /// Gets the enhanced status code. + /// + public String EnhancedStatusCode { + get; + } + + /// + /// Gets the content. + /// + public List Content { + get; + } + + /// + /// Returns true if the response code is between 200 and 599. + /// + public Boolean IsValid { + get; + } + + /// + /// Gets a value indicating whether this instance is positive. + /// + public Boolean IsPositive => this.ReplyCode >= 200 && this.ReplyCode <= 399; + + #endregion + + #region Methods + + /// + /// Parses the specified text into a Server Reply for thorough analysis. + /// + /// The text. + /// A new instance of SMTP server response object. + public static SmtpServerReply Parse(String text) { + String[] lines = text.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries); + if(lines.Length == 0) { + return new SmtpServerReply(); + } + + String[] lastLineParts = lines.Last().Split(new[] { " " }, StringSplitOptions.RemoveEmptyEntries); + String enhancedStatusCode = String.Empty; + _ = Int32.TryParse(lastLineParts[0], out Int32 responseCode); + if(lastLineParts.Length > 1) { + if(lastLineParts[1].Split('.').Length == 3) { + enhancedStatusCode = lastLineParts[1]; + } + } + + List content = new List(); + + for(Int32 i = 0; i < lines.Length; i++) { + String splitChar = i == lines.Length - 1 ? " " : "-"; + + String[] lineParts = lines[i].Split(new[] { splitChar }, 2, StringSplitOptions.None); + String lineContent = lineParts.Last(); + if(String.IsNullOrWhiteSpace(enhancedStatusCode) == false) { + lineContent = lineContent.Replace(enhancedStatusCode, String.Empty).Trim(); + } + + content.Add(lineContent); + } + + return new SmtpServerReply(responseCode, enhancedStatusCode, content.ToArray()); + } + + /// + /// Returns a that represents this instance. + /// + /// + /// A that represents this instance. + /// + public override String ToString() { + String responseCodeText = this.ReplyCode.ToString(CultureInfo.InvariantCulture); + String statusCodeText = String.IsNullOrWhiteSpace(this.EnhancedStatusCode) ? String.Empty : $" {this.EnhancedStatusCode.Trim()}"; + if(this.Content.Count == 0) { + return $"{responseCodeText}{statusCodeText}"; + } + + StringBuilder builder = new StringBuilder(); + + for(Int32 i = 0; i < this.Content.Count; i++) { + Boolean isLastLine = i == this.Content.Count - 1; + + _ = builder.Append(isLastLine ? $"{responseCodeText}{statusCodeText} {this.Content[i]}" : $"{responseCodeText}-{this.Content[i]}\r\n"); + } + + return builder.ToString(); + } + + #endregion + } } \ No newline at end of file diff --git a/Swan/Net/Smtp/SmtpSessionState.cs b/Swan/Net/Smtp/SmtpSessionState.cs index 9788362..6db1efe 100644 --- a/Swan/Net/Smtp/SmtpSessionState.cs +++ b/Swan/Net/Smtp/SmtpSessionState.cs @@ -1,158 +1,179 @@ -namespace Swan.Net.Smtp -{ - using System.Collections.Generic; - +using System.Collections.Generic; +using System; + +namespace Swan.Net.Smtp { + /// + /// Represents the state of an SMTP session associated with a client. + /// + public class SmtpSessionState { /// - /// Represents the state of an SMTP session associated with a client. + /// Initializes a new instance of the class. /// - public class SmtpSessionState - { - /// - /// Initializes a new instance of the class. - /// - public SmtpSessionState() - { - DataBuffer = new List(); - Reset(true); - ResetAuthentication(); - } - - #region Properties - - /// - /// Gets the contents of the data buffer. - /// - public List DataBuffer { get; protected set; } - - /// - /// Gets or sets a value indicating whether this instance has initiated. - /// - public bool HasInitiated { get; set; } - - /// - /// Gets or sets a value indicating whether the current session supports extensions. - /// - public bool SupportsExtensions { get; set; } - - /// - /// Gets or sets the client hostname. - /// - public string ClientHostname { get; set; } - - /// - /// Gets or sets a value indicating whether the session is currently receiving DATA. - /// - public bool IsInDataMode { get; set; } - - /// - /// Gets or sets the sender address. - /// - public string SenderAddress { get; set; } - - /// - /// Gets the recipients. - /// - public List Recipients { get; } = new List(); - - /// - /// Gets or sets the extended data supporting any additional field for storage by a responder implementation. - /// - public object ExtendedData { get; set; } - - #endregion - - #region AUTH State - - /// - /// Gets or sets a value indicating whether this instance is in authentication mode. - /// - public bool IsInAuthMode { get; set; } - - /// - /// Gets or sets the username. - /// - public string Username { get; set; } - - /// - /// Gets or sets the password. - /// - public string Password { get; set; } - - /// - /// Gets a value indicating whether this instance has provided username. - /// - public bool HasProvidedUsername => string.IsNullOrWhiteSpace(Username) == false; - - /// - /// Gets or sets a value indicating whether this instance is authenticated. - /// - public bool IsAuthenticated { get; set; } - - /// - /// Gets or sets the authentication mode. - /// - public string AuthMode { get; set; } - - /// - /// Gets or sets a value indicating whether this instance is channel secure. - /// - public bool IsChannelSecure { get; set; } - - /// - /// Resets the authentication state. - /// - public void ResetAuthentication() - { - Username = string.Empty; - Password = string.Empty; - AuthMode = string.Empty; - IsInAuthMode = false; - IsAuthenticated = false; - } - - #endregion - - #region Methods - - /// - /// Resets the data mode to false, clears the recipients, the sender address and the data buffer. - /// - public void ResetEmail() - { - IsInDataMode = false; - Recipients.Clear(); - SenderAddress = string.Empty; - DataBuffer.Clear(); - } - - /// - /// Resets the state table entirely. - /// - /// if set to true [clear extension data]. - public void Reset(bool clearExtensionData) - { - HasInitiated = false; - SupportsExtensions = false; - ClientHostname = string.Empty; - ResetEmail(); - - if (clearExtensionData) - ExtendedData = null; - } - - /// - /// Creates a new object that is a copy of the current instance. - /// - /// A clone. - public virtual SmtpSessionState Clone() - { - var clonedState = this.CopyPropertiesToNew(new[] {nameof(DataBuffer)}); - clonedState.DataBuffer.AddRange(DataBuffer); - clonedState.Recipients.AddRange(Recipients); - - return clonedState; - } - - #endregion - } + public SmtpSessionState() { + this.DataBuffer = new List(); + this.Reset(true); + this.ResetAuthentication(); + } + + #region Properties + + /// + /// Gets the contents of the data buffer. + /// + public List DataBuffer { + get; protected set; + } + + /// + /// Gets or sets a value indicating whether this instance has initiated. + /// + public Boolean HasInitiated { + get; set; + } + + /// + /// Gets or sets a value indicating whether the current session supports extensions. + /// + public Boolean SupportsExtensions { + get; set; + } + + /// + /// Gets or sets the client hostname. + /// + public String ClientHostname { + get; set; + } + + /// + /// Gets or sets a value indicating whether the session is currently receiving DATA. + /// + public Boolean IsInDataMode { + get; set; + } + + /// + /// Gets or sets the sender address. + /// + public String SenderAddress { + get; set; + } + + /// + /// Gets the recipients. + /// + public List Recipients { get; } = new List(); + + /// + /// Gets or sets the extended data supporting any additional field for storage by a responder implementation. + /// + public Object ExtendedData { + get; set; + } + + #endregion + + #region AUTH State + + /// + /// Gets or sets a value indicating whether this instance is in authentication mode. + /// + public Boolean IsInAuthMode { + get; set; + } + + /// + /// Gets or sets the username. + /// + public String Username { + get; set; + } + + /// + /// Gets or sets the password. + /// + public String Password { + get; set; + } + + /// + /// Gets a value indicating whether this instance has provided username. + /// + public Boolean HasProvidedUsername => String.IsNullOrWhiteSpace(this.Username) == false; + + /// + /// Gets or sets a value indicating whether this instance is authenticated. + /// + public Boolean IsAuthenticated { + get; set; + } + + /// + /// Gets or sets the authentication mode. + /// + public String AuthMode { + get; set; + } + + /// + /// Gets or sets a value indicating whether this instance is channel secure. + /// + public Boolean IsChannelSecure { + get; set; + } + + /// + /// Resets the authentication state. + /// + public void ResetAuthentication() { + this.Username = String.Empty; + this.Password = String.Empty; + this.AuthMode = String.Empty; + this.IsInAuthMode = false; + this.IsAuthenticated = false; + } + + #endregion + + #region Methods + + /// + /// Resets the data mode to false, clears the recipients, the sender address and the data buffer. + /// + public void ResetEmail() { + this.IsInDataMode = false; + this.Recipients.Clear(); + this.SenderAddress = String.Empty; + this.DataBuffer.Clear(); + } + + /// + /// Resets the state table entirely. + /// + /// if set to true [clear extension data]. + public void Reset(Boolean clearExtensionData) { + this.HasInitiated = false; + this.SupportsExtensions = false; + this.ClientHostname = String.Empty; + this.ResetEmail(); + + if(clearExtensionData) { + this.ExtendedData = null; + } + } + + /// + /// Creates a new object that is a copy of the current instance. + /// + /// A clone. + public virtual SmtpSessionState Clone() { + SmtpSessionState clonedState = this.CopyPropertiesToNew(new[] { nameof(this.DataBuffer) }); + clonedState.DataBuffer.AddRange(this.DataBuffer); + clonedState.Recipients.AddRange(this.Recipients); + + return clonedState; + } + + #endregion + } } \ No newline at end of file diff --git a/Swan/ProcessResult.cs b/Swan/ProcessResult.cs index c21e68f..9703050 100644 --- a/Swan/ProcessResult.cs +++ b/Swan/ProcessResult.cs @@ -1,46 +1,51 @@ -namespace Swan -{ +using System; + +namespace Swan { + /// + /// Represents the text of the standard output and standard error + /// of a process, including its exit code. + /// + public class ProcessResult { /// - /// Represents the text of the standard output and standard error - /// of a process, including its exit code. + /// Initializes a new instance of the class. /// - public class ProcessResult - { - /// - /// Initializes a new instance of the class. - /// - /// The exit code. - /// The standard output. - /// The standard error. - public ProcessResult(int exitCode, string standardOutput, string standardError) - { - ExitCode = exitCode; - StandardOutput = standardOutput; - StandardError = standardError; - } - - /// - /// Gets the exit code. - /// - /// - /// The exit code. - /// - public int ExitCode { get; } - - /// - /// Gets the text of the standard output. - /// - /// - /// The standard output. - /// - public string StandardOutput { get; } - - /// - /// Gets the text of the standard error. - /// - /// - /// The standard error. - /// - public string StandardError { get; } - } + /// The exit code. + /// The standard output. + /// The standard error. + public ProcessResult(Int32 exitCode, String standardOutput, String standardError) { + this.ExitCode = exitCode; + this.StandardOutput = standardOutput; + this.StandardError = standardError; + } + + /// + /// Gets the exit code. + /// + /// + /// The exit code. + /// + public Int32 ExitCode { + get; + } + + /// + /// Gets the text of the standard output. + /// + /// + /// The standard output. + /// + public String StandardOutput { + get; + } + + /// + /// Gets the text of the standard error. + /// + /// + /// The standard error. + /// + public String StandardError { + get; + } + } } \ No newline at end of file diff --git a/Swan/ProcessRunner.cs b/Swan/ProcessRunner.cs index db00888..a429c34 100644 --- a/Swan/ProcessRunner.cs +++ b/Swan/ProcessRunner.cs @@ -1,443 +1,353 @@ -namespace Swan -{ - using System; - using System.Diagnostics; - using System.IO; - using System.Linq; - using System.Text; - using System.Threading; - using System.Threading.Tasks; - +#nullable enable +using System; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan { + /// + /// Provides methods to help create external processes, and efficiently capture the + /// standard error and standard output streams. + /// + public static class ProcessRunner { /// - /// Provides methods to help create external processes, and efficiently capture the - /// standard error and standard output streams. + /// Defines a delegate to handle binary data reception from the standard + /// output or standard error streams from a process. /// - public static class ProcessRunner - { - /// - /// Defines a delegate to handle binary data reception from the standard - /// output or standard error streams from a process. - /// - /// The process data. - /// The process. - public delegate void ProcessDataReceivedCallback(byte[] processData, Process process); - - /// - /// Runs the process asynchronously and if the exit code is 0, - /// returns all of the standard output text. If the exit code is something other than 0 - /// it returns the contents of standard error. - /// This method is meant to be used for programs that output a relatively small amount of text. - /// - /// The filename. - /// The arguments. - /// The working directory. - /// The cancellation token. - /// The type of the result produced by this Task. - /// - /// The following code explains how to run an external process using the - /// method. - /// - /// class Example - /// { - /// using System.Threading.Tasks; - /// using Swan; - /// - /// static async Task Main() - /// { - /// // execute a process and save its output - /// var data = await ProcessRunner. - /// GetProcessOutputAsync("dotnet", "--help"); - /// - /// // print the output - /// data.WriteLine(); - /// } - /// } - /// - /// - public static async Task GetProcessOutputAsync( - string filename, - string arguments = "", - string? workingDirectory = null, - CancellationToken cancellationToken = default) - { - var result = await GetProcessResultAsync(filename, arguments, workingDirectory, cancellationToken: cancellationToken).ConfigureAwait(false); - return result.ExitCode == 0 ? result.StandardOutput : result.StandardError; - } - - /// - /// Runs the process asynchronously and if the exit code is 0, - /// returns all of the standard output text. If the exit code is something other than 0 - /// it returns the contents of standard error. - /// This method is meant to be used for programs that output a relatively small amount - /// of text using a different encoder. - /// - /// The filename. - /// The arguments. - /// The encoding. - /// The cancellation token. - /// - /// The type of the result produced by this Task. - /// - public static async Task GetProcessEncodedOutputAsync( - string filename, - string arguments = "", - Encoding? encoding = null, - CancellationToken cancellationToken = default) - { - var result = await GetProcessResultAsync(filename, arguments, null, encoding, cancellationToken).ConfigureAwait(false); - return result.ExitCode == 0 ? result.StandardOutput : result.StandardError; - } - - /// - /// Executes a process asynchronously and returns the text of the standard output and standard error streams - /// along with the exit code. This method is meant to be used for programs that output a relatively small - /// amount of text. - /// - /// The filename. - /// The arguments. - /// The cancellation token. - /// - /// Text of the standard output and standard error streams along with the exit code as a instance. - /// - /// filename. - public static Task GetProcessResultAsync( - string filename, - string arguments = "", - CancellationToken cancellationToken = default) => - GetProcessResultAsync(filename, arguments, null, Definitions.CurrentAnsiEncoding, cancellationToken); - - /// - /// Executes a process asynchronously and returns the text of the standard output and standard error streams - /// along with the exit code. This method is meant to be used for programs that output a relatively small - /// amount of text. - /// - /// The filename. - /// The arguments. - /// The working directory. - /// The encoding. - /// The cancellation token. - /// - /// Text of the standard output and standard error streams along with the exit code as a instance. - /// - /// filename. - /// - /// The following code describes how to run an external process using the method. - /// - /// class Example - /// { - /// using System.Threading.Tasks; - /// using Swan; - /// - /// static async Task Main() - /// { - /// // Execute a process asynchronously - /// var data = await ProcessRunner.GetProcessResultAsync("dotnet", "--help"); - /// - /// // print out the exit code - /// $"{data.ExitCode}".WriteLine(); - /// - /// // print out the output - /// data.StandardOutput.WriteLine(); - /// // and the error if exists - /// data.StandardError.Error(); - /// } - /// } - /// - public static async Task GetProcessResultAsync( - string filename, - string arguments, - string? workingDirectory, - Encoding? encoding = null, - CancellationToken cancellationToken = default) - { - if (filename == null) - throw new ArgumentNullException(nameof(filename)); - - if (encoding == null) - encoding = Definitions.CurrentAnsiEncoding; - - var standardOutputBuilder = new StringBuilder(); - var standardErrorBuilder = new StringBuilder(); - - var processReturn = await RunProcessAsync( - filename, - arguments, - workingDirectory, - (data, proc) => standardOutputBuilder.Append(encoding.GetString(data)), - (data, proc) => standardErrorBuilder.Append(encoding.GetString(data)), - encoding, - true, - cancellationToken) - .ConfigureAwait(false); - - return new ProcessResult(processReturn, standardOutputBuilder.ToString(), standardErrorBuilder.ToString()); - } - - /// - /// Runs an external process asynchronously, providing callbacks to - /// capture binary data from the standard error and standard output streams. - /// The callbacks contain a reference to the process so you can respond to output or - /// error streams by writing to the process' input stream. - /// The exit code (return value) will be -1 for forceful termination of the process. - /// - /// The filename. - /// The arguments. - /// The working directory. - /// The on output data. - /// The on error data. - /// The encoding. - /// if set to true the next data callback will wait until the current one completes. - /// The cancellation token. - /// - /// Value type will be -1 for forceful termination of the process. - /// - public static Task RunProcessAsync( - string filename, - string arguments, - string? workingDirectory, - ProcessDataReceivedCallback onOutputData, - ProcessDataReceivedCallback onErrorData, - Encoding encoding, - bool syncEvents = true, - CancellationToken cancellationToken = default) - { - if (filename == null) - throw new ArgumentNullException(nameof(filename)); - - return Task.Run(() => - { - // Setup the process and its corresponding start info - var process = new Process - { - EnableRaisingEvents = false, - StartInfo = new ProcessStartInfo - { - Arguments = arguments, - CreateNoWindow = true, - FileName = filename, - RedirectStandardError = true, - StandardErrorEncoding = encoding, - RedirectStandardOutput = true, - StandardOutputEncoding = encoding, - UseShellExecute = false, -#if NET461 - WindowStyle = ProcessWindowStyle.Hidden, -#endif - }, - }; - - if (!string.IsNullOrWhiteSpace(workingDirectory)) - process.StartInfo.WorkingDirectory = workingDirectory; - - // Launch the process and discard any buffered data for standard error and standard output - process.Start(); - process.StandardError.DiscardBufferedData(); - process.StandardOutput.DiscardBufferedData(); - - // Launch the asynchronous stream reading tasks - var readTasks = new Task[2]; - readTasks[0] = CopyStreamAsync( - process, - process.StandardOutput.BaseStream, - onOutputData, - syncEvents, - cancellationToken); - readTasks[1] = CopyStreamAsync( - process, - process.StandardError.BaseStream, - onErrorData, - syncEvents, - cancellationToken); - - try - { - // Wait for all tasks to complete - Task.WaitAll(readTasks, cancellationToken); - } - catch (TaskCanceledException) - { - // ignore - } - finally - { - // Wait for the process to exit - while (cancellationToken.IsCancellationRequested == false) - { - if (process.HasExited || process.WaitForExit(5)) - break; - } - - // Forcefully kill the process if it do not exit - try - { - if (process.HasExited == false) - process.Kill(); - } - catch - { - // swallow - } - } - - try - { - // Retrieve and return the exit code. - // -1 signals error - return process.HasExited ? process.ExitCode : -1; - } - catch - { - return -1; - } - }, cancellationToken); - } - - /// - /// Runs an external process asynchronously, providing callbacks to - /// capture binary data from the standard error and standard output streams. - /// The callbacks contain a reference to the process so you can respond to output or - /// error streams by writing to the process' input stream. - /// The exit code (return value) will be -1 for forceful termination of the process. - /// - /// The filename. - /// The arguments. - /// The on output data. - /// The on error data. - /// if set to true the next data callback will wait until the current one completes. - /// The cancellation token. - /// Value type will be -1 for forceful termination of the process. - /// - /// The following example illustrates how to run an external process using the - /// - /// method. - /// - /// class Example - /// { - /// using System.Diagnostics; - /// using System.Text; - /// using System.Threading.Tasks; - /// using Swan; - /// - /// static async Task Main() - /// { - /// // Execute a process asynchronously - /// var data = await ProcessRunner - /// .RunProcessAsync("dotnet", "--help", Print, Print); - /// - /// // flush all messages - /// Terminal.Flush(); - /// } - /// - /// // a callback to print both output or errors - /// static void Print(byte[] data, Process proc) => - /// Encoding.GetEncoding(0).GetString(data).WriteLine(); - /// } - /// - /// - public static Task RunProcessAsync( - string filename, - string arguments, - ProcessDataReceivedCallback onOutputData, - ProcessDataReceivedCallback onErrorData, - bool syncEvents = true, - CancellationToken cancellationToken = default) - => RunProcessAsync( - filename, - arguments, - null, - onOutputData, - onErrorData, - Definitions.CurrentAnsiEncoding, - syncEvents, - cancellationToken); - - /// - /// Copies the stream asynchronously. - /// - /// The process. - /// The source stream. - /// The on data callback. - /// if set to true [synchronize events]. - /// The cancellation token. - /// Total copies stream. - private static Task CopyStreamAsync( - Process process, - Stream baseStream, - ProcessDataReceivedCallback onDataCallback, - bool syncEvents, - CancellationToken ct) => - Task.Run(async () => - { - // define some state variables - var swapBuffer = new byte[2048]; // the buffer to copy data from one stream to the next - ulong totalCount = 0; // the total amount of bytes read - var hasExited = false; - - while (ct.IsCancellationRequested == false) - { - try - { - // Check if process is no longer valid - // if this condition holds, simply read the last bits of data available. - int readCount; // the bytes read in any given event - if (process.HasExited || process.WaitForExit(1)) - { - while (true) - { - try - { - readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct); - - if (readCount > 0) - { - totalCount += (ulong) readCount; - onDataCallback?.Invoke(swapBuffer.Skip(0).Take(readCount).ToArray(), process); - } - else - { - hasExited = true; - break; - } - } - catch - { - hasExited = true; - break; - } - } - } - - if (hasExited) break; - - // Try reading from the stream. < 0 means no read occurred. - readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct).ConfigureAwait(false); - - // When no read is done, we need to let is rest for a bit - if (readCount <= 0) - { - await Task.Delay(1, ct).ConfigureAwait(false); // do not hog CPU cycles doing nothing. - continue; - } - - totalCount += (ulong) readCount; - if (onDataCallback == null) continue; - - // Create the buffer to pass to the callback - var eventBuffer = swapBuffer.Skip(0).Take(readCount).ToArray(); - - // Create the data processing callback invocation - var eventTask = Task.Run(() => onDataCallback.Invoke(eventBuffer, process), ct); - - // wait for the event to process before the next read occurs - if (syncEvents) eventTask.Wait(ct); - } - catch - { - break; - } - } - - return totalCount; - }, ct); - } + /// The process data. + /// The process. + public delegate void ProcessDataReceivedCallback(Byte[] processData, Process process); + + /// + /// Runs the process asynchronously and if the exit code is 0, + /// returns all of the standard output text. If the exit code is something other than 0 + /// it returns the contents of standard error. + /// This method is meant to be used for programs that output a relatively small amount of text. + /// + /// The filename. + /// The arguments. + /// The working directory. + /// The cancellation token. + /// The type of the result produced by this Task. + /// + /// The following code explains how to run an external process using the + /// method. + /// + /// class Example + /// { + /// using System.Threading.Tasks; + /// using Swan; + /// + /// static async Task Main() + /// { + /// // execute a process and save its output + /// var data = await ProcessRunner. + /// GetProcessOutputAsync("dotnet", "--help"); + /// + /// // print the output + /// data.WriteLine(); + /// } + /// } + /// + /// + public static async Task GetProcessOutputAsync(String filename, String arguments = "", String? workingDirectory = null, CancellationToken cancellationToken = default) { + ProcessResult result = await GetProcessResultAsync(filename, arguments, workingDirectory, cancellationToken: cancellationToken).ConfigureAwait(false); + return result.ExitCode == 0 ? result.StandardOutput : result.StandardError; + } + + /// + /// Runs the process asynchronously and if the exit code is 0, + /// returns all of the standard output text. If the exit code is something other than 0 + /// it returns the contents of standard error. + /// This method is meant to be used for programs that output a relatively small amount + /// of text using a different encoder. + /// + /// The filename. + /// The arguments. + /// The encoding. + /// The cancellation token. + /// + /// The type of the result produced by this Task. + /// + public static async Task GetProcessEncodedOutputAsync(String filename, String arguments = "", Encoding? encoding = null, CancellationToken cancellationToken = default) { + ProcessResult result = await GetProcessResultAsync(filename, arguments, null, encoding, cancellationToken).ConfigureAwait(false); + return result.ExitCode == 0 ? result.StandardOutput : result.StandardError; + } + + /// + /// Executes a process asynchronously and returns the text of the standard output and standard error streams + /// along with the exit code. This method is meant to be used for programs that output a relatively small + /// amount of text. + /// + /// The filename. + /// The arguments. + /// The cancellation token. + /// + /// Text of the standard output and standard error streams along with the exit code as a instance. + /// + /// filename. + public static Task GetProcessResultAsync(String filename, String arguments = "", CancellationToken cancellationToken = default) => GetProcessResultAsync(filename, arguments, null, Definitions.CurrentAnsiEncoding, cancellationToken); + + /// + /// Executes a process asynchronously and returns the text of the standard output and standard error streams + /// along with the exit code. This method is meant to be used for programs that output a relatively small + /// amount of text. + /// + /// The filename. + /// The arguments. + /// The working directory. + /// The encoding. + /// The cancellation token. + /// + /// Text of the standard output and standard error streams along with the exit code as a instance. + /// + /// filename. + /// + /// The following code describes how to run an external process using the method. + /// + /// class Example + /// { + /// using System.Threading.Tasks; + /// using Swan; + /// + /// static async Task Main() + /// { + /// // Execute a process asynchronously + /// var data = await ProcessRunner.GetProcessResultAsync("dotnet", "--help"); + /// + /// // print out the exit code + /// $"{data.ExitCode}".WriteLine(); + /// + /// // print out the output + /// data.StandardOutput.WriteLine(); + /// // and the error if exists + /// data.StandardError.Error(); + /// } + /// } + /// + public static async Task GetProcessResultAsync(String filename, String arguments, String? workingDirectory, Encoding? encoding = null, CancellationToken cancellationToken = default) { + if(filename == null) { + throw new ArgumentNullException(nameof(filename)); + } + + if(encoding == null) { + encoding = Definitions.CurrentAnsiEncoding; + } + + StringBuilder standardOutputBuilder = new StringBuilder(); + StringBuilder standardErrorBuilder = new StringBuilder(); + + Int32 processReturn = await RunProcessAsync(filename, arguments, workingDirectory, (data, proc) => standardOutputBuilder.Append(encoding.GetString(data)), (data, proc) => standardErrorBuilder.Append(encoding.GetString(data)), encoding, true, cancellationToken).ConfigureAwait(false); + + return new ProcessResult(processReturn, standardOutputBuilder.ToString(), standardErrorBuilder.ToString()); + } + + /// + /// Runs an external process asynchronously, providing callbacks to + /// capture binary data from the standard error and standard output streams. + /// The callbacks contain a reference to the process so you can respond to output or + /// error streams by writing to the process' input stream. + /// The exit code (return value) will be -1 for forceful termination of the process. + /// + /// The filename. + /// The arguments. + /// The working directory. + /// The on output data. + /// The on error data. + /// The encoding. + /// if set to true the next data callback will wait until the current one completes. + /// The cancellation token. + /// + /// Value type will be -1 for forceful termination of the process. + /// + public static Task RunProcessAsync(String filename, String arguments, String? workingDirectory, ProcessDataReceivedCallback onOutputData, ProcessDataReceivedCallback? onErrorData, Encoding encoding, Boolean syncEvents = true, CancellationToken cancellationToken = default) { + if(filename == null) { + throw new ArgumentNullException(nameof(filename)); + } + + return Task.Run(() => { + // Setup the process and its corresponding start info + Process process = new Process { + EnableRaisingEvents = false, + StartInfo = new ProcessStartInfo { + Arguments = arguments, + CreateNoWindow = true, + FileName = filename, + RedirectStandardError = true, + StandardErrorEncoding = encoding, + RedirectStandardOutput = true, + StandardOutputEncoding = encoding, + UseShellExecute = false, + }, + }; + + if(!String.IsNullOrWhiteSpace(workingDirectory)) { + process.StartInfo.WorkingDirectory = workingDirectory; + } + + // Launch the process and discard any buffered data for standard error and standard output + _ = process.Start(); + process.StandardError.DiscardBufferedData(); + process.StandardOutput.DiscardBufferedData(); + + // Launch the asynchronous stream reading tasks + Task[] readTasks = new Task[2]; + readTasks[0] = CopyStreamAsync(process, process.StandardOutput.BaseStream, onOutputData, syncEvents, cancellationToken); + readTasks[1] = CopyStreamAsync(process, process.StandardError.BaseStream, onErrorData, syncEvents, cancellationToken); + + try { + // Wait for all tasks to complete + Task.WaitAll(readTasks, cancellationToken); + } catch(TaskCanceledException) { + // ignore + } finally { + // Wait for the process to exit + while(cancellationToken.IsCancellationRequested == false) { + if(process.HasExited || process.WaitForExit(5)) { + break; + } + } + + // Forcefully kill the process if it do not exit + try { + if(process.HasExited == false) { + process.Kill(); + } + } catch { + // swallow + } + } + + try { + // Retrieve and return the exit code. + // -1 signals error + return process.HasExited ? process.ExitCode : -1; + } catch { + return -1; + } + }, cancellationToken); + } + + /// + /// Runs an external process asynchronously, providing callbacks to + /// capture binary data from the standard error and standard output streams. + /// The callbacks contain a reference to the process so you can respond to output or + /// error streams by writing to the process' input stream. + /// The exit code (return value) will be -1 for forceful termination of the process. + /// + /// The filename. + /// The arguments. + /// The on output data. + /// The on error data. + /// if set to true the next data callback will wait until the current one completes. + /// The cancellation token. + /// Value type will be -1 for forceful termination of the process. + /// + /// The following example illustrates how to run an external process using the + /// + /// method. + /// + /// class Example + /// { + /// using System.Diagnostics; + /// using System.Text; + /// using System.Threading.Tasks; + /// using Swan; + /// + /// static async Task Main() + /// { + /// // Execute a process asynchronously + /// var data = await ProcessRunner + /// .RunProcessAsync("dotnet", "--help", Print, Print); + /// + /// // flush all messages + /// Terminal.Flush(); + /// } + /// + /// // a callback to print both output or errors + /// static void Print(byte[] data, Process proc) => + /// Encoding.GetEncoding(0).GetString(data).WriteLine(); + /// } + /// + /// + public static Task RunProcessAsync(String filename, String arguments, ProcessDataReceivedCallback onOutputData, ProcessDataReceivedCallback? onErrorData, Boolean syncEvents = true, CancellationToken cancellationToken = default) => RunProcessAsync(filename, arguments, null, onOutputData, onErrorData, Definitions.CurrentAnsiEncoding, syncEvents, cancellationToken); + + /// + /// Copies the stream asynchronously. + /// + /// The process. + /// The source stream. + /// The on data callback. + /// if set to true [synchronize events]. + /// The cancellation token. + /// Total copies stream. + private static Task CopyStreamAsync(Process process, Stream baseStream, ProcessDataReceivedCallback? onDataCallback, Boolean syncEvents, CancellationToken ct) => Task.Run(async () => { + // define some state variables + Byte[] swapBuffer = new Byte[2048]; // the buffer to copy data from one stream to the next + UInt64 totalCount = 0; // the total amount of bytes read + Boolean hasExited = false; + + while(ct.IsCancellationRequested == false) { + try { + // Check if process is no longer valid + // if this condition holds, simply read the last bits of data available. + Int32 readCount; // the bytes read in any given event + if(process.HasExited || process.WaitForExit(1)) { + while(true) { + try { + readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct); + + if(readCount > 0) { + totalCount += (UInt64)readCount; + onDataCallback?.Invoke(swapBuffer.Skip(0).Take(readCount).ToArray(), process); + } else { + hasExited = true; + break; + } + } catch { + hasExited = true; + break; + } + } + } + + if(hasExited) { + break; + } + + // Try reading from the stream. < 0 means no read occurred. + readCount = await baseStream.ReadAsync(swapBuffer, 0, swapBuffer.Length, ct).ConfigureAwait(false); + + // When no read is done, we need to let is rest for a bit + if(readCount <= 0) { + await Task.Delay(1, ct).ConfigureAwait(false); // do not hog CPU cycles doing nothing. + continue; + } + + totalCount += (UInt64)readCount; + if(onDataCallback == null) { + continue; + } + + // Create the buffer to pass to the callback + Byte[] eventBuffer = swapBuffer.Skip(0).Take(readCount).ToArray(); + + // Create the data processing callback invocation + Task eventTask = Task.Run(() => onDataCallback.Invoke(eventBuffer, process), ct); + + // wait for the event to process before the next read occurs + if(syncEvents) { + eventTask.Wait(ct); + } + } catch { + break; + } + } + + return totalCount; + }, ct); + } } diff --git a/Swan/Services/ServiceBase.cs b/Swan/Services/ServiceBase.cs index e790c14..aa44e65 100644 --- a/Swan/Services/ServiceBase.cs +++ b/Swan/Services/ServiceBase.cs @@ -1,92 +1,98 @@ using System; -#if !NET461 -namespace Swan.Services -{ + +namespace Swan.Services { + /// + /// Mimic a Windows ServiceBase class. Useful to keep compatibility with applications + /// running as services in OS different to Windows. + /// + [Obsolete("This abstract class will be removed in version 3.0")] + public abstract class ServiceBase { /// - /// Mimic a Windows ServiceBase class. Useful to keep compatibility with applications - /// running as services in OS different to Windows. + /// Gets or sets a value indicating whether the service can be stopped once it has started. /// - [Obsolete("This abstract class will be removed in version 3.0")] - public abstract class ServiceBase - { - /// - /// Gets or sets a value indicating whether the service can be stopped once it has started. - /// - /// - /// true if this instance can stop; otherwise, false. - /// - public bool CanStop { get; set; } = true; - - /// - /// Gets or sets a value indicating whether the service should be notified when the system is shutting down. - /// - /// - /// true if this instance can shutdown; otherwise, false. - /// - public bool CanShutdown { get; set; } - - /// - /// Gets or sets a value indicating whether the service can be paused and resumed. - /// - /// - /// true if this instance can pause and continue; otherwise, false. - /// - public bool CanPauseAndContinue { get; set; } - - /// - /// Gets or sets the exit code. - /// - /// - /// The exit code. - /// - public int ExitCode { get; set; } - - /// - /// Indicates whether to report Start, Stop, Pause, and Continue commands in the event log. - /// - /// - /// true if [automatic log]; otherwise, false. - /// - public bool AutoLog { get; set; } - - /// - /// Gets or sets the name of the service. - /// - /// - /// The name of the service. - /// - public string ServiceName { get; set; } - - /// - /// Stops the executing service. - /// - public void Stop() - { - if (!CanStop) return; - - CanStop = false; - OnStop(); - } - - /// - /// When implemented in a derived class, executes when a Start command is sent to the service by the Service Control Manager (SCM) - /// or when the operating system starts (for a service that starts automatically). Specifies actions to take when the service starts. - /// - /// The arguments. - protected virtual void OnStart(string[] args) - { - // do nothing - } - - /// - /// When implemented in a derived class, executes when a Stop command is sent to the service by the Service Control Manager (SCM). - /// Specifies actions to take when a service stops running. - /// - protected virtual void OnStop() - { - // do nothing - } - } + /// + /// true if this instance can stop; otherwise, false. + /// + public Boolean CanStop { get; set; } = true; + + /// + /// Gets or sets a value indicating whether the service should be notified when the system is shutting down. + /// + /// + /// true if this instance can shutdown; otherwise, false. + /// + public Boolean CanShutdown { + get; set; + } + + /// + /// Gets or sets a value indicating whether the service can be paused and resumed. + /// + /// + /// true if this instance can pause and continue; otherwise, false. + /// + public Boolean CanPauseAndContinue { + get; set; + } + + /// + /// Gets or sets the exit code. + /// + /// + /// The exit code. + /// + public Int32 ExitCode { + get; set; + } + + /// + /// Indicates whether to report Start, Stop, Pause, and Continue commands in the event log. + /// + /// + /// true if [automatic log]; otherwise, false. + /// + public Boolean AutoLog { + get; set; + } + + /// + /// Gets or sets the name of the service. + /// + /// + /// The name of the service. + /// + public String ServiceName { + get; set; + } + + /// + /// Stops the executing service. + /// + public void Stop() { + if(!this.CanStop) { + return; + } + + this.CanStop = false; + this.OnStop(); + } + + /// + /// When implemented in a derived class, executes when a Start command is sent to the service by the Service Control Manager (SCM) + /// or when the operating system starts (for a service that starts automatically). Specifies actions to take when the service starts. + /// + /// The arguments. + protected virtual void OnStart(String[] args) { + // do nothing + } + + /// + /// When implemented in a derived class, executes when a Stop command is sent to the service by the Service Control Manager (SCM). + /// Specifies actions to take when a service stops running. + /// + protected virtual void OnStop() { + // do nothing + } + } } -#endif diff --git a/Swan/Threading/DelayProvider.cs b/Swan/Threading/DelayProvider.cs index f1d1c23..17ec77c 100644 --- a/Swan/Threading/DelayProvider.cs +++ b/Swan/Threading/DelayProvider.cs @@ -1,141 +1,136 @@ -namespace Swan.Threading -{ - using System; - using System.Diagnostics; - using System.Threading; - using System.Threading.Tasks; - +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Threading { + /// + /// Represents logic providing several delay mechanisms. + /// + /// + /// The following example shows how to implement delay mechanisms. + /// + /// using Swan.Threading; + /// + /// public class Example + /// { + /// public static void Main() + /// { + /// // using the ThreadSleep strategy + /// using (var delay = new DelayProvider(DelayProvider.DelayStrategy.ThreadSleep)) + /// { + /// // retrieve how much time was delayed + /// var time = delay.WaitOne(); + /// } + /// } + /// } + /// + /// + public sealed class DelayProvider : IDisposable { + private readonly Object _syncRoot = new Object(); + private readonly Stopwatch _delayStopwatch = new Stopwatch(); + + private Boolean _isDisposed; + private IWaitEvent _delayEvent; + /// - /// Represents logic providing several delay mechanisms. + /// Initializes a new instance of the class. /// - /// - /// The following example shows how to implement delay mechanisms. - /// - /// using Swan.Threading; - /// - /// public class Example - /// { - /// public static void Main() - /// { - /// // using the ThreadSleep strategy - /// using (var delay = new DelayProvider(DelayProvider.DelayStrategy.ThreadSleep)) - /// { - /// // retrieve how much time was delayed - /// var time = delay.WaitOne(); - /// } - /// } - /// } - /// - /// - public sealed class DelayProvider : IDisposable - { - private readonly object _syncRoot = new object(); - private readonly Stopwatch _delayStopwatch = new Stopwatch(); - - private bool _isDisposed; - private IWaitEvent _delayEvent; - - /// - /// Initializes a new instance of the class. - /// - /// The strategy. - public DelayProvider(DelayStrategy strategy = DelayStrategy.TaskDelay) - { - Strategy = strategy; - } - - /// - /// Enumerates the different ways of providing delays. - /// - public enum DelayStrategy - { - /// - /// Using the Thread.Sleep(15) mechanism. - /// - ThreadSleep, - - /// - /// Using the Task.Delay(1).Wait mechanism. - /// - TaskDelay, - - /// - /// Using a wait event that completes in a background ThreadPool thread. - /// - ThreadPool, - } - - /// - /// Gets the selected delay strategy. - /// - public DelayStrategy Strategy { get; } - - /// - /// Creates the smallest possible, synchronous delay based on the selected strategy. - /// - /// The elapsed time of the delay. - public TimeSpan WaitOne() - { - lock (_syncRoot) - { - if (_isDisposed) return TimeSpan.Zero; - - _delayStopwatch.Restart(); - - switch (Strategy) - { - case DelayStrategy.ThreadSleep: - DelaySleep(); - break; - case DelayStrategy.TaskDelay: - DelayTask(); - break; - case DelayStrategy.ThreadPool: - DelayThreadPool(); - break; - } - - return _delayStopwatch.Elapsed; - } - } - - #region Dispose Pattern - - /// - public void Dispose() - { - lock (_syncRoot) - { - if (_isDisposed) return; - _isDisposed = true; - - _delayEvent?.Dispose(); - } - } - - #endregion - - #region Private Delay Mechanisms - - private static void DelaySleep() => Thread.Sleep(15); - - private static void DelayTask() => Task.Delay(1).Wait(); - - private void DelayThreadPool() - { - if (_delayEvent == null) - _delayEvent = WaitEventFactory.Create(isCompleted: true, useSlim: true); - - _delayEvent.Begin(); - ThreadPool.QueueUserWorkItem(s => - { - DelaySleep(); - _delayEvent.Complete(); - }); - - _delayEvent.Wait(); - } - - #endregion - } + /// The strategy. + public DelayProvider(DelayStrategy strategy = DelayStrategy.TaskDelay) => this.Strategy = strategy; + + /// + /// Enumerates the different ways of providing delays. + /// + public enum DelayStrategy { + /// + /// Using the Thread.Sleep(15) mechanism. + /// + ThreadSleep, + + /// + /// Using the Task.Delay(1).Wait mechanism. + /// + TaskDelay, + + /// + /// Using a wait event that completes in a background ThreadPool thread. + /// + ThreadPool, + } + + /// + /// Gets the selected delay strategy. + /// + public DelayStrategy Strategy { + get; + } + + /// + /// Creates the smallest possible, synchronous delay based on the selected strategy. + /// + /// The elapsed time of the delay. + public TimeSpan WaitOne() { + lock(this._syncRoot) { + if(this._isDisposed) { + return TimeSpan.Zero; + } + + this._delayStopwatch.Restart(); + + switch(this.Strategy) { + case DelayStrategy.ThreadSleep: + DelaySleep(); + break; + case DelayStrategy.TaskDelay: + DelayTask(); + break; + case DelayStrategy.ThreadPool: + this.DelayThreadPool(); + break; + } + + return this._delayStopwatch.Elapsed; + } + } + + #region Dispose Pattern + + /// + public void Dispose() { + lock(this._syncRoot) { + if(this._isDisposed) { + return; + } + + this._isDisposed = true; + + this._delayEvent?.Dispose(); + } + } + + #endregion + + #region Private Delay Mechanisms + + private static void DelaySleep() => Thread.Sleep(15); + + private static void DelayTask() => Task.Delay(1).Wait(); + + private void DelayThreadPool() { + if(this._delayEvent == null) { + this._delayEvent = WaitEventFactory.Create(isCompleted: true, useSlim: true); + } + + this._delayEvent.Begin(); + _ = ThreadPool.QueueUserWorkItem(s => { + DelaySleep(); + this._delayEvent.Complete(); + }); + + this._delayEvent.Wait(); + } + + #endregion + } } \ No newline at end of file diff --git a/Swan/Threading/ThreadWorkerBase.cs b/Swan/Threading/ThreadWorkerBase.cs index 527a8f9..743dd0d 100644 --- a/Swan/Threading/ThreadWorkerBase.cs +++ b/Swan/Threading/ThreadWorkerBase.cs @@ -1,292 +1,251 @@ -namespace Swan.Threading -{ - using System; - using System.Threading; - using System.Threading.Tasks; - +namespace Swan.Threading { + using System; + using System.Threading; + using System.Threading.Tasks; + + /// + /// Provides a base implementation for application workers + /// that perform continuous, long-running tasks. This class + /// provides the ability to perform fine-grained control on these tasks. + /// + /// + public abstract class ThreadWorkerBase : WorkerBase { + private readonly Object _syncLock = new Object(); + private readonly Thread _thread; + /// - /// Provides a base implementation for application workers - /// that perform continuous, long-running tasks. This class - /// provides the ability to perform fine-grained control on these tasks. + /// Initializes a new instance of the class. /// - /// - public abstract class ThreadWorkerBase : WorkerBase - { - private readonly object _syncLock = new object(); - private readonly Thread _thread; - - /// - /// Initializes a new instance of the class. - /// - /// The name. - /// The thread priority. - /// The interval of cycle execution. - /// The cycle delay provide implementation. - protected ThreadWorkerBase(string name, ThreadPriority priority, TimeSpan period, IWorkerDelayProvider delayProvider) - : base(name, period) - { - DelayProvider = delayProvider; - _thread = new Thread(RunWorkerLoop) - { - IsBackground = true, - Priority = priority, - Name = name, - }; - } - - /// - /// Initializes a new instance of the class. - /// - /// The name. - /// The execution interval. - protected ThreadWorkerBase(string name, TimeSpan period) - : this(name, ThreadPriority.Normal, period, WorkerDelayProvider.Default) - { - // placeholder - } - - /// - /// Provides an implementation on a cycle delay provider. - /// - protected IWorkerDelayProvider DelayProvider { get; } - - /// - public override Task StartAsync() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Paused || WorkerState == WorkerState.Waiting) - return ResumeAsync(); - - if (WorkerState != WorkerState.Created) - return Task.FromResult(WorkerState); - - if (IsStopRequested) - return Task.FromResult(WorkerState); - - var task = QueueStateChange(StateChangeRequest.Start); - _thread.Start(); - return task; - } - } - - /// - public override Task PauseAsync() - { - lock (_syncLock) - { - if (WorkerState != WorkerState.Running && WorkerState != WorkerState.Waiting) - return Task.FromResult(WorkerState); - - return IsStopRequested ? Task.FromResult(WorkerState) : QueueStateChange(StateChangeRequest.Pause); - } - } - - /// - public override Task ResumeAsync() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Created) - return StartAsync(); - - if (WorkerState != WorkerState.Paused && WorkerState != WorkerState.Waiting) - return Task.FromResult(WorkerState); - - return IsStopRequested ? Task.FromResult(WorkerState) : QueueStateChange(StateChangeRequest.Resume); - } - } - - /// - public override Task StopAsync() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Stopped || WorkerState == WorkerState.Created) - { - WorkerState = WorkerState.Stopped; - return Task.FromResult(WorkerState); - } - - return QueueStateChange(StateChangeRequest.Stop); - } - } - - /// - /// Suspends execution queues a new new cycle for execution. The delay is given in - /// milliseconds. When overridden in a derived class the wait handle will be set - /// whenever an interrupt is received. - /// - /// The remaining delay to wait for in the cycle. - /// Contains a reference to a task with the scheduled period delay. - /// The cancellation token to cancel waiting. - protected virtual void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) => - DelayProvider?.ExecuteCycleDelay(wantedDelay, delayTask, token); - - /// - protected override void OnDisposing() - { - lock (_syncLock) - { - if ((_thread.ThreadState & ThreadState.Unstarted) != ThreadState.Unstarted) - _thread.Join(); - } - } - - /// - /// Implements worker control, execution and delay logic in a loop. - /// - private void RunWorkerLoop() - { - while (WorkerState != WorkerState.Stopped && !IsDisposing && !IsDisposed) - { - CycleStopwatch.Restart(); - var interruptToken = CycleCancellation.Token; - var period = Period.TotalMilliseconds >= int.MaxValue ? -1 : Convert.ToInt32(Math.Floor(Period.TotalMilliseconds)); - var delayTask = Task.Delay(period, interruptToken); - var initialWorkerState = WorkerState; - - // Lock the cycle and capture relevant state valid for this cycle - CycleCompletedEvent.Reset(); - - // Process the tasks that are awaiting - if (ProcessStateChangeRequests()) - continue; - - try - { - if (initialWorkerState == WorkerState.Waiting && - !interruptToken.IsCancellationRequested) - { - // Mark the state as Running - WorkerState = WorkerState.Running; - - // Call the execution logic - ExecuteCycleLogic(interruptToken); - } - } - catch (Exception ex) - { - OnCycleException(ex); - } - finally - { - // Update the state - WorkerState = initialWorkerState == WorkerState.Paused + /// The name. + /// The thread priority. + /// The interval of cycle execution. + /// The cycle delay provide implementation. + protected ThreadWorkerBase(String name, ThreadPriority priority, TimeSpan period, IWorkerDelayProvider delayProvider) : base(name, period) { + this.DelayProvider = delayProvider; + this._thread = new Thread(this.RunWorkerLoop) { + IsBackground = true, + Priority = priority, + Name = name, + }; + } + + /// + /// Initializes a new instance of the class. + /// + /// The name. + /// The execution interval. + protected ThreadWorkerBase(String name, TimeSpan period) : this(name, ThreadPriority.Normal, period, WorkerDelayProvider.Default) { + // placeholder + } + + /// + /// Provides an implementation on a cycle delay provider. + /// + protected IWorkerDelayProvider DelayProvider { + get; + } + + /// + public override Task StartAsync() { + lock(this._syncLock) { + if(this.WorkerState == WorkerState.Paused || this.WorkerState == WorkerState.Waiting) { + return this.ResumeAsync(); + } + + if(this.WorkerState != WorkerState.Created) { + return Task.FromResult(this.WorkerState); + } + + if(this.IsStopRequested) { + return Task.FromResult(this.WorkerState); + } + + Task task = this.QueueStateChange(StateChangeRequest.Start); + this._thread.Start(); + return task; + } + } + + /// + public override Task PauseAsync() { + lock(this._syncLock) { + return this.WorkerState != WorkerState.Running && this.WorkerState != WorkerState.Waiting ? Task.FromResult(this.WorkerState) : this.IsStopRequested ? Task.FromResult(this.WorkerState) : this.QueueStateChange(StateChangeRequest.Pause); + } + } + + /// + public override Task ResumeAsync() { + lock(this._syncLock) { + return this.WorkerState == WorkerState.Created ? this.StartAsync() : this.WorkerState != WorkerState.Paused && this.WorkerState != WorkerState.Waiting ? Task.FromResult(this.WorkerState) : this.IsStopRequested ? Task.FromResult(this.WorkerState) : this.QueueStateChange(StateChangeRequest.Resume); + } + } + + /// + public override Task StopAsync() { + lock(this._syncLock) { + if(this.WorkerState == WorkerState.Stopped || this.WorkerState == WorkerState.Created) { + this.WorkerState = WorkerState.Stopped; + return Task.FromResult(this.WorkerState); + } + + return this.QueueStateChange(StateChangeRequest.Stop); + } + } + + /// + /// Suspends execution queues a new new cycle for execution. The delay is given in + /// milliseconds. When overridden in a derived class the wait handle will be set + /// whenever an interrupt is received. + /// + /// The remaining delay to wait for in the cycle. + /// Contains a reference to a task with the scheduled period delay. + /// The cancellation token to cancel waiting. + protected virtual void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) => + this.DelayProvider?.ExecuteCycleDelay(wantedDelay, delayTask, token); + + /// + protected override void OnDisposing() { + lock(this._syncLock) { + if((this._thread.ThreadState & ThreadState.Unstarted) != ThreadState.Unstarted) { + this._thread.Join(); + } + } + } + + /// + /// Implements worker control, execution and delay logic in a loop. + /// + private void RunWorkerLoop() { + while(this.WorkerState != WorkerState.Stopped && !this.IsDisposing && !this.IsDisposed) { + this.CycleStopwatch.Restart(); + CancellationToken interruptToken = this.CycleCancellation.Token; + Int32 period = this.Period.TotalMilliseconds >= Int32.MaxValue ? -1 : Convert.ToInt32(Math.Floor(this.Period.TotalMilliseconds)); + Task delayTask = Task.Delay(period, interruptToken); + WorkerState initialWorkerState = this.WorkerState; + + // Lock the cycle and capture relevant state valid for this cycle + this.CycleCompletedEvent.Reset(); + + // Process the tasks that are awaiting + if(this.ProcessStateChangeRequests()) { + continue; + } + + try { + if(initialWorkerState == WorkerState.Waiting && + !interruptToken.IsCancellationRequested) { + // Mark the state as Running + this.WorkerState = WorkerState.Running; + + // Call the execution logic + this.ExecuteCycleLogic(interruptToken); + } + } catch(Exception ex) { + this.OnCycleException(ex); + } finally { + // Update the state + this.WorkerState = initialWorkerState == WorkerState.Paused ? WorkerState.Paused - : WorkerState.Waiting; - - // Signal the cycle has been completed so new cycles can be executed - CycleCompletedEvent.Set(); - - if (!interruptToken.IsCancellationRequested) - { - var cycleDelay = ComputeCycleDelay(initialWorkerState); - if (cycleDelay == Timeout.Infinite) - delayTask = Task.Delay(Timeout.Infinite, interruptToken); - - ExecuteCycleDelay( + : WorkerState.Waiting; + + // Signal the cycle has been completed so new cycles can be executed + this.CycleCompletedEvent.Set(); + + if(!interruptToken.IsCancellationRequested) { + Int32 cycleDelay = this.ComputeCycleDelay(initialWorkerState); + if(cycleDelay == Timeout.Infinite) { + delayTask = Task.Delay(Timeout.Infinite, interruptToken); + } + + this.ExecuteCycleDelay( cycleDelay, delayTask, - CycleCancellation.Token); - } - } - } - - ClearStateChangeRequests(); - WorkerState = WorkerState.Stopped; - } - - /// - /// Queues a transition in worker state for processing. Returns a task that can be awaited - /// when the operation completes. - /// - /// The request. - /// The awaitable task. - private Task QueueStateChange(StateChangeRequest request) - { - lock (_syncLock) - { - if (StateChangeTask != null) - return StateChangeTask; - - var waitingTask = new Task(() => - { - StateChangedEvent.Wait(); - lock (_syncLock) - { - StateChangeTask = null; - return WorkerState; - } - }); - - StateChangeTask = waitingTask; - StateChangedEvent.Reset(); - StateChangeRequests[request] = true; - waitingTask.Start(); - CycleCancellation.Cancel(); - - return waitingTask; - } - } - - /// - /// Processes the state change request by checking pending events and scheduling - /// cycle execution accordingly. The is also updated. - /// - /// Returns true if the execution should be terminated. false otherwise. - private bool ProcessStateChangeRequests() - { - lock (_syncLock) - { - var hasRequest = false; - var currentState = WorkerState; - - // Update the state in the given priority - if (StateChangeRequests[StateChangeRequest.Stop] || IsDisposing || IsDisposed) - { - hasRequest = true; - WorkerState = WorkerState.Stopped; - } - else if (StateChangeRequests[StateChangeRequest.Pause]) - { - hasRequest = true; - WorkerState = WorkerState.Paused; - } - else if (StateChangeRequests[StateChangeRequest.Start] || StateChangeRequests[StateChangeRequest.Resume]) - { - hasRequest = true; - WorkerState = WorkerState.Waiting; - } - - // Signals all state changes to continue - // as a command has been handled. - if (hasRequest) - { - ClearStateChangeRequests(); - OnStateChangeProcessed(currentState, WorkerState); - } - - return hasRequest; - } - } - - /// - /// Signals all state change requests to set. - /// - private void ClearStateChangeRequests() - { - lock (_syncLock) - { - // Mark all events as completed - StateChangeRequests[StateChangeRequest.Start] = false; - StateChangeRequests[StateChangeRequest.Pause] = false; - StateChangeRequests[StateChangeRequest.Resume] = false; - StateChangeRequests[StateChangeRequest.Stop] = false; - - StateChangedEvent.Set(); - CycleCompletedEvent.Set(); - } - } - } + this.CycleCancellation.Token); + } + } + } + + this.ClearStateChangeRequests(); + this.WorkerState = WorkerState.Stopped; + } + + /// + /// Queues a transition in worker state for processing. Returns a task that can be awaited + /// when the operation completes. + /// + /// The request. + /// The awaitable task. + private Task QueueStateChange(StateChangeRequest request) { + lock(this._syncLock) { + if(this.StateChangeTask != null) { + return this.StateChangeTask; + } + + Task waitingTask = new Task(() => { + this.StateChangedEvent.Wait(); + lock(this._syncLock) { + this.StateChangeTask = null; + return this.WorkerState; + } + }); + + this.StateChangeTask = waitingTask; + this.StateChangedEvent.Reset(); + this.StateChangeRequests[request] = true; + waitingTask.Start(); + this.CycleCancellation.Cancel(); + + return waitingTask; + } + } + + /// + /// Processes the state change request by checking pending events and scheduling + /// cycle execution accordingly. The is also updated. + /// + /// Returns true if the execution should be terminated. false otherwise. + private Boolean ProcessStateChangeRequests() { + lock(this._syncLock) { + Boolean hasRequest = false; + WorkerState currentState = this.WorkerState; + + // Update the state in the given priority + if(this.StateChangeRequests[StateChangeRequest.Stop] || this.IsDisposing || this.IsDisposed) { + hasRequest = true; + this.WorkerState = WorkerState.Stopped; + } else if(this.StateChangeRequests[StateChangeRequest.Pause]) { + hasRequest = true; + this.WorkerState = WorkerState.Paused; + } else if(this.StateChangeRequests[StateChangeRequest.Start] || this.StateChangeRequests[StateChangeRequest.Resume]) { + hasRequest = true; + this.WorkerState = WorkerState.Waiting; + } + + // Signals all state changes to continue + // as a command has been handled. + if(hasRequest) { + this.ClearStateChangeRequests(); + this.OnStateChangeProcessed(currentState, this.WorkerState); + } + + return hasRequest; + } + } + + /// + /// Signals all state change requests to set. + /// + private void ClearStateChangeRequests() { + lock(this._syncLock) { + // Mark all events as completed + this.StateChangeRequests[StateChangeRequest.Start] = false; + this.StateChangeRequests[StateChangeRequest.Pause] = false; + this.StateChangeRequests[StateChangeRequest.Resume] = false; + this.StateChangeRequests[StateChangeRequest.Stop] = false; + + this.StateChangedEvent.Set(); + this.CycleCompletedEvent.Set(); + } + } + } } diff --git a/Swan/Threading/TimerWorkerBase.cs b/Swan/Threading/TimerWorkerBase.cs index 2175d7c..2719262 100644 --- a/Swan/Threading/TimerWorkerBase.cs +++ b/Swan/Threading/TimerWorkerBase.cs @@ -1,328 +1,300 @@ -namespace Swan.Threading -{ - using System; - using System.Threading; - using System.Threading.Tasks; - +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Threading { + /// + /// Provides a base implementation for application workers. + /// + /// + public abstract class TimerWorkerBase : WorkerBase { + private readonly Object _syncLock = new Object(); + private readonly Timer _timer; + private Boolean _isTimerAlive = true; + /// - /// Provides a base implementation for application workers. + /// Initializes a new instance of the class. /// - /// - public abstract class TimerWorkerBase : WorkerBase - { - private readonly object _syncLock = new object(); - private readonly Timer _timer; - private bool _isTimerAlive = true; - - /// - /// Initializes a new instance of the class. - /// - /// The name. - /// The execution interval. - protected TimerWorkerBase(string name, TimeSpan period) - : base(name, period) - { - // Instantiate the timer that will be used to schedule cycles - _timer = new Timer( - ExecuteTimerCallback, - this, - Timeout.Infinite, - Timeout.Infinite); - } - - /// - public override Task StartAsync() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Paused || WorkerState == WorkerState.Waiting) - return ResumeAsync(); - - if (WorkerState != WorkerState.Created) - return Task.FromResult(WorkerState); - - if (IsStopRequested) - return Task.FromResult(WorkerState); - - var task = QueueStateChange(StateChangeRequest.Start); - Interrupt(); - return task; - } - } - - /// - public override Task PauseAsync() - { - lock (_syncLock) - { - if (WorkerState != WorkerState.Running && WorkerState != WorkerState.Waiting) - return Task.FromResult(WorkerState); - - if (IsStopRequested) - return Task.FromResult(WorkerState); - - var task = QueueStateChange(StateChangeRequest.Pause); - Interrupt(); - return task; - } - } - - /// - public override Task ResumeAsync() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Created) - return StartAsync(); - - if (WorkerState != WorkerState.Paused && WorkerState != WorkerState.Waiting) - return Task.FromResult(WorkerState); - - if (IsStopRequested) - return Task.FromResult(WorkerState); - - var task = QueueStateChange(StateChangeRequest.Resume); - Interrupt(); - return task; - } - } - - /// - public override Task StopAsync() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Stopped || WorkerState == WorkerState.Created) - { - WorkerState = WorkerState.Stopped; - return Task.FromResult(WorkerState); - } - - var task = QueueStateChange(StateChangeRequest.Stop); - Interrupt(); - return task; - } - } - - /// - /// Schedules a new cycle for execution. The delay is given in - /// milliseconds. Passing a delay of 0 means a new cycle should be executed - /// immediately. - /// - /// The delay. - protected void ScheduleCycle(int delay) - { - lock (_syncLock) - { - if (!_isTimerAlive) return; - _timer.Change(delay, Timeout.Infinite); - } - } - - /// - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - - lock (_syncLock) - { - if (!_isTimerAlive) return; - _isTimerAlive = false; - _timer.Dispose(); - } - } - - /// - /// Cancels the current token and schedules a new cycle immediately. - /// - private void Interrupt() - { - lock (_syncLock) - { - if (WorkerState == WorkerState.Stopped) - return; - - CycleCancellation.Cancel(); - ScheduleCycle(0); - } - } - - /// - /// Executes the worker cycle control logic. - /// This includes processing state change requests, - /// the execution of use cycle code, - /// and the scheduling of new cycles. - /// - private void ExecuteWorkerCycle() - { - CycleStopwatch.Restart(); - - lock (_syncLock) - { - if (IsDisposing || IsDisposed) - { - WorkerState = WorkerState.Stopped; - - // Cancel any awaiters - try { StateChangedEvent.Set(); } - catch { /* Ignore */ } - - return; - } - - // Prevent running another instance of the cycle - if (CycleCompletedEvent.IsSet == false) return; - - // Lock the cycle and capture relevant state valid for this cycle - CycleCompletedEvent.Reset(); - } - - var interruptToken = CycleCancellation.Token; - var initialWorkerState = WorkerState; - - // Process the tasks that are awaiting - if (ProcessStateChangeRequests()) - return; - - try - { - if (initialWorkerState == WorkerState.Waiting && - !interruptToken.IsCancellationRequested) - { - // Mark the state as Running - WorkerState = WorkerState.Running; - - // Call the execution logic - ExecuteCycleLogic(interruptToken); - } - } - catch (Exception ex) - { - OnCycleException(ex); - } - finally - { - // Update the state - WorkerState = initialWorkerState == WorkerState.Paused + /// The name. + /// The execution interval. + protected TimerWorkerBase(String name, TimeSpan period) : base(name, period) => + // Instantiate the timer that will be used to schedule cycles + this._timer = new Timer(this.ExecuteTimerCallback, this, Timeout.Infinite, Timeout.Infinite); + + /// + public override Task StartAsync() { + lock(this._syncLock) { + if(this.WorkerState == WorkerState.Paused || this.WorkerState == WorkerState.Waiting) { + return this.ResumeAsync(); + } + + if(this.WorkerState != WorkerState.Created) { + return Task.FromResult(this.WorkerState); + } + + if(this.IsStopRequested) { + return Task.FromResult(this.WorkerState); + } + + Task task = this.QueueStateChange(StateChangeRequest.Start); + this.Interrupt(); + return task; + } + } + + /// + public override Task PauseAsync() { + lock(this._syncLock) { + if(this.WorkerState != WorkerState.Running && this.WorkerState != WorkerState.Waiting) { + return Task.FromResult(this.WorkerState); + } + + if(this.IsStopRequested) { + return Task.FromResult(this.WorkerState); + } + + Task task = this.QueueStateChange(StateChangeRequest.Pause); + this.Interrupt(); + return task; + } + } + + /// + public override Task ResumeAsync() { + lock(this._syncLock) { + if(this.WorkerState == WorkerState.Created) { + return this.StartAsync(); + } + + if(this.WorkerState != WorkerState.Paused && this.WorkerState != WorkerState.Waiting) { + return Task.FromResult(this.WorkerState); + } + + if(this.IsStopRequested) { + return Task.FromResult(this.WorkerState); + } + + Task task = this.QueueStateChange(StateChangeRequest.Resume); + this.Interrupt(); + return task; + } + } + + /// + public override Task StopAsync() { + lock(this._syncLock) { + if(this.WorkerState == WorkerState.Stopped || this.WorkerState == WorkerState.Created) { + this.WorkerState = WorkerState.Stopped; + return Task.FromResult(this.WorkerState); + } + + Task task = this.QueueStateChange(StateChangeRequest.Stop); + this.Interrupt(); + return task; + } + } + + /// + /// Schedules a new cycle for execution. The delay is given in + /// milliseconds. Passing a delay of 0 means a new cycle should be executed + /// immediately. + /// + /// The delay. + protected void ScheduleCycle(Int32 delay) { + lock(this._syncLock) { + if(!this._isTimerAlive) { + return; + } + + _ = this._timer.Change(delay, Timeout.Infinite); + } + } + + /// + protected override void Dispose(Boolean disposing) { + base.Dispose(disposing); + + lock(this._syncLock) { + if(!this._isTimerAlive) { + return; + } + + this._isTimerAlive = false; + this._timer.Dispose(); + } + } + + /// + /// Cancels the current token and schedules a new cycle immediately. + /// + private void Interrupt() { + lock(this._syncLock) { + if(this.WorkerState == WorkerState.Stopped) { + return; + } + + this.CycleCancellation.Cancel(); + this.ScheduleCycle(0); + } + } + + /// + /// Executes the worker cycle control logic. + /// This includes processing state change requests, + /// the execution of use cycle code, + /// and the scheduling of new cycles. + /// + private void ExecuteWorkerCycle() { + this.CycleStopwatch.Restart(); + + lock(this._syncLock) { + if(this.IsDisposing || this.IsDisposed) { + this.WorkerState = WorkerState.Stopped; + + // Cancel any awaiters + try { + this.StateChangedEvent.Set(); + } catch { /* Ignore */ } + + return; + } + + // Prevent running another instance of the cycle + if(this.CycleCompletedEvent.IsSet == false) { + return; + } + + // Lock the cycle and capture relevant state valid for this cycle + this.CycleCompletedEvent.Reset(); + } + + CancellationToken interruptToken = this.CycleCancellation.Token; + WorkerState initialWorkerState = this.WorkerState; + + // Process the tasks that are awaiting + if(this.ProcessStateChangeRequests()) { + return; + } + + try { + if(initialWorkerState == WorkerState.Waiting && + !interruptToken.IsCancellationRequested) { + // Mark the state as Running + this.WorkerState = WorkerState.Running; + + // Call the execution logic + this.ExecuteCycleLogic(interruptToken); + } + } catch(Exception ex) { + this.OnCycleException(ex); + } finally { + // Update the state + this.WorkerState = initialWorkerState == WorkerState.Paused ? WorkerState.Paused - : WorkerState.Waiting; - - lock (_syncLock) - { - // Signal the cycle has been completed so new cycles can be executed - CycleCompletedEvent.Set(); - - // Schedule a new cycle - ScheduleCycle(!interruptToken.IsCancellationRequested - ? ComputeCycleDelay(initialWorkerState) - : 0); - } - } - } - - /// - /// Represents the callback that is executed when the ticks. - /// - /// The state -- this contains the worker. - private void ExecuteTimerCallback(object state) => ExecuteWorkerCycle(); - - /// - /// Queues a transition in worker state for processing. Returns a task that can be awaited - /// when the operation completes. - /// - /// The request. - /// The awaitable task. - private Task QueueStateChange(StateChangeRequest request) - { - lock (_syncLock) - { - if (StateChangeTask != null) - return StateChangeTask; - - var waitingTask = new Task(() => - { - StateChangedEvent.Wait(); - lock (_syncLock) - { - StateChangeTask = null; - return WorkerState; - } - }); - - StateChangeTask = waitingTask; - StateChangedEvent.Reset(); - StateChangeRequests[request] = true; - waitingTask.Start(); - CycleCancellation.Cancel(); - - return waitingTask; - } - } - - /// - /// Processes the state change queue by checking pending events and scheduling - /// cycle execution accordingly. The is also updated. - /// - /// Returns true if the execution should be terminated. false otherwise. - private bool ProcessStateChangeRequests() - { - lock (_syncLock) - { - var currentState = WorkerState; - var hasRequest = false; - var schedule = 0; - - // Update the state according to request priority - if (StateChangeRequests[StateChangeRequest.Stop] || IsDisposing || IsDisposed) - { - hasRequest = true; - WorkerState = WorkerState.Stopped; - schedule = StateChangeRequests[StateChangeRequest.Stop] ? Timeout.Infinite : 0; - } - else if (StateChangeRequests[StateChangeRequest.Pause]) - { - hasRequest = true; - WorkerState = WorkerState.Paused; - schedule = Timeout.Infinite; - } - else if (StateChangeRequests[StateChangeRequest.Start] || StateChangeRequests[StateChangeRequest.Resume]) - { - hasRequest = true; - WorkerState = WorkerState.Waiting; - } - - // Signals all state changes to continue - // as a command has been handled. - if (hasRequest) - { - ClearStateChangeRequests(schedule, currentState, WorkerState); - } - - return hasRequest; - } - } - - /// - /// Signals all state change requests to set. - /// - /// The cycle schedule. - /// The previous worker state. - /// The new worker state. - private void ClearStateChangeRequests(int schedule, WorkerState oldState, WorkerState newState) - { - lock (_syncLock) - { - // Mark all events as completed - StateChangeRequests[StateChangeRequest.Start] = false; - StateChangeRequests[StateChangeRequest.Pause] = false; - StateChangeRequests[StateChangeRequest.Resume] = false; - StateChangeRequests[StateChangeRequest.Stop] = false; - - StateChangedEvent.Set(); - CycleCompletedEvent.Set(); - OnStateChangeProcessed(oldState, newState); - ScheduleCycle(schedule); - } - } - } + : WorkerState.Waiting; + + lock(this._syncLock) { + // Signal the cycle has been completed so new cycles can be executed + this.CycleCompletedEvent.Set(); + + // Schedule a new cycle + this.ScheduleCycle(!interruptToken.IsCancellationRequested + ? this.ComputeCycleDelay(initialWorkerState) + : 0); + } + } + } + + /// + /// Represents the callback that is executed when the ticks. + /// + /// The state -- this contains the worker. + private void ExecuteTimerCallback(Object state) => this.ExecuteWorkerCycle(); + + /// + /// Queues a transition in worker state for processing. Returns a task that can be awaited + /// when the operation completes. + /// + /// The request. + /// The awaitable task. + private Task QueueStateChange(StateChangeRequest request) { + lock(this._syncLock) { + if(this.StateChangeTask != null) { + return this.StateChangeTask; + } + + Task waitingTask = new Task(() => { + this.StateChangedEvent.Wait(); + lock(this._syncLock) { + this.StateChangeTask = null; + return this.WorkerState; + } + }); + + this.StateChangeTask = waitingTask; + this.StateChangedEvent.Reset(); + this.StateChangeRequests[request] = true; + waitingTask.Start(); + this.CycleCancellation.Cancel(); + + return waitingTask; + } + } + + /// + /// Processes the state change queue by checking pending events and scheduling + /// cycle execution accordingly. The is also updated. + /// + /// Returns true if the execution should be terminated. false otherwise. + private Boolean ProcessStateChangeRequests() { + lock(this._syncLock) { + WorkerState currentState = this.WorkerState; + Boolean hasRequest = false; + Int32 schedule = 0; + + // Update the state according to request priority + if(this.StateChangeRequests[StateChangeRequest.Stop] || this.IsDisposing || this.IsDisposed) { + hasRequest = true; + this.WorkerState = WorkerState.Stopped; + schedule = this.StateChangeRequests[StateChangeRequest.Stop] ? Timeout.Infinite : 0; + } else if(this.StateChangeRequests[StateChangeRequest.Pause]) { + hasRequest = true; + this.WorkerState = WorkerState.Paused; + schedule = Timeout.Infinite; + } else if(this.StateChangeRequests[StateChangeRequest.Start] || this.StateChangeRequests[StateChangeRequest.Resume]) { + hasRequest = true; + this.WorkerState = WorkerState.Waiting; + } + + // Signals all state changes to continue + // as a command has been handled. + if(hasRequest) { + this.ClearStateChangeRequests(schedule, currentState, this.WorkerState); + } + + return hasRequest; + } + } + + /// + /// Signals all state change requests to set. + /// + /// The cycle schedule. + /// The previous worker state. + /// The new worker state. + private void ClearStateChangeRequests(Int32 schedule, WorkerState oldState, WorkerState newState) { + lock(this._syncLock) { + // Mark all events as completed + this.StateChangeRequests[StateChangeRequest.Start] = false; + this.StateChangeRequests[StateChangeRequest.Pause] = false; + this.StateChangeRequests[StateChangeRequest.Resume] = false; + this.StateChangeRequests[StateChangeRequest.Stop] = false; + + this.StateChangedEvent.Set(); + this.CycleCompletedEvent.Set(); + this.OnStateChangeProcessed(oldState, newState); + this.ScheduleCycle(schedule); + } + } + } } diff --git a/Swan/Threading/WorkerBase.cs b/Swan/Threading/WorkerBase.cs index ac3e681..0ce0766 100644 --- a/Swan/Threading/WorkerBase.cs +++ b/Swan/Threading/WorkerBase.cs @@ -1,240 +1,233 @@ -namespace Swan.Threading -{ - using System; - using System.Collections.Generic; - using System.Diagnostics; - using System.Threading; - using System.Threading.Tasks; - +#nullable enable +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Threading { + /// + /// Provides base infrastructure for Timer and Thread workers. + /// + /// + public abstract class WorkerBase : IWorker, IDisposable { + // Since these are API property backers, we use interlocked to read from them + // to avoid deadlocked reads + private readonly Object _syncLock = new Object(); + + private readonly AtomicBoolean _isDisposed = new AtomicBoolean(); + private readonly AtomicBoolean _isDisposing = new AtomicBoolean(); + private readonly AtomicEnum _workerState = new AtomicEnum(WorkerState.Created); + private readonly AtomicTimeSpan _timeSpan; + /// - /// Provides base infrastructure for Timer and Thread workers. + /// Initializes a new instance of the class. /// - /// - public abstract class WorkerBase : IWorker, IDisposable - { - // Since these are API property backers, we use interlocked to read from them - // to avoid deadlocked reads - private readonly object _syncLock = new object(); - - private readonly AtomicBoolean _isDisposed = new AtomicBoolean(); - private readonly AtomicBoolean _isDisposing = new AtomicBoolean(); - private readonly AtomicEnum _workerState = new AtomicEnum(WorkerState.Created); - private readonly AtomicTimeSpan _timeSpan; - - /// - /// Initializes a new instance of the class. - /// - /// The name. - /// The execution interval. - protected WorkerBase(string name, TimeSpan period) - { - Name = name; - _timeSpan = new AtomicTimeSpan(period); - - StateChangeRequests = new Dictionary(5) - { - [StateChangeRequest.Start] = false, - [StateChangeRequest.Pause] = false, - [StateChangeRequest.Resume] = false, - [StateChangeRequest.Stop] = false, - }; - } - - /// - /// Enumerates all the different state change requests. - /// - protected enum StateChangeRequest - { - /// - /// No state change request. - /// - None, - - /// - /// Start state change request - /// - Start, - - /// - /// Pause state change request - /// - Pause, - - /// - /// Resume state change request - /// - Resume, - - /// - /// Stop state change request - /// - Stop, - } - - /// - public string Name { get; } - - /// - public TimeSpan Period - { - get => _timeSpan.Value; - set => _timeSpan.Value = value; - } - - /// - public WorkerState WorkerState - { - get => _workerState.Value; - protected set => _workerState.Value = value; - } - - /// - public bool IsDisposed - { - get => _isDisposed.Value; - protected set => _isDisposed.Value = value; - } - - /// - public bool IsDisposing - { - get => _isDisposing.Value; - protected set => _isDisposing.Value = value; - } - - /// - /// Gets the default period of 15 milliseconds which is the default precision for timers. - /// - protected static TimeSpan DefaultPeriod { get; } = TimeSpan.FromMilliseconds(15); - - /// - /// Gets a value indicating whether stop has been requested. - /// This is useful to prevent more requests from being issued. - /// - protected bool IsStopRequested => StateChangeRequests[StateChangeRequest.Stop]; - - /// - /// Gets the cycle stopwatch. - /// - protected Stopwatch CycleStopwatch { get; } = new Stopwatch(); - - /// - /// Gets the state change requests. - /// - protected Dictionary StateChangeRequests { get; } - - /// - /// Gets the cycle completed event. - /// - protected ManualResetEventSlim CycleCompletedEvent { get; } = new ManualResetEventSlim(true); - - /// - /// Gets the state changed event. - /// - protected ManualResetEventSlim StateChangedEvent { get; } = new ManualResetEventSlim(true); - - /// - /// Gets the cycle logic cancellation owner. - /// - protected CancellationTokenOwner CycleCancellation { get; } = new CancellationTokenOwner(); - - /// - /// Gets or sets the state change task. - /// - protected Task? StateChangeTask { get; set; } - - /// - public abstract Task StartAsync(); - - /// - public abstract Task PauseAsync(); - - /// - public abstract Task ResumeAsync(); - - /// - public abstract Task StopAsync(); - - /// - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - /// - /// Releases unmanaged and - optionally - managed resources. - /// - /// true to release both managed and unmanaged resources; false to release only unmanaged resources. - protected virtual void Dispose(bool disposing) - { - lock (_syncLock) - { - if (IsDisposed || IsDisposing) return; - IsDisposing = true; - } - - // This also ensures the state change queue gets cleared - StopAsync().Wait(); - StateChangedEvent.Set(); - CycleCompletedEvent.Set(); - - OnDisposing(); - - CycleStopwatch.Stop(); - StateChangedEvent.Dispose(); - CycleCompletedEvent.Dispose(); - CycleCancellation.Dispose(); - - IsDisposed = true; - IsDisposing = false; - } - - /// - /// Handles the cycle logic exceptions. - /// - /// The exception that was thrown. - protected abstract void OnCycleException(Exception ex); - - /// - /// Represents the user defined logic to be executed on a single worker cycle. - /// Check the cancellation token continuously if you need responsive interrupts. - /// - /// The cancellation token. - protected abstract void ExecuteCycleLogic(CancellationToken cancellationToken); - - /// - /// This method is called automatically when is called. - /// Makes sure you release all resources within this call. - /// - protected abstract void OnDisposing(); - - /// - /// Called when a state change request is processed. - /// - /// The state before the change. - /// The new state. - protected virtual void OnStateChangeProcessed(WorkerState previousState, WorkerState newState) - { - // placeholder - } - - /// - /// Computes the cycle delay. - /// - /// Initial state of the worker. - /// The number of milliseconds to delay for. - protected int ComputeCycleDelay(WorkerState initialWorkerState) - { - var elapsedMillis = CycleStopwatch.ElapsedMilliseconds; - var period = Period; - var periodMillis = period.TotalMilliseconds; - var delayMillis = periodMillis - elapsedMillis; - - if (initialWorkerState == WorkerState.Paused || period == TimeSpan.MaxValue || delayMillis >= int.MaxValue) - return Timeout.Infinite; - - return elapsedMillis >= periodMillis ? 0 : Convert.ToInt32(Math.Floor(delayMillis)); - } - } + /// The name. + /// The execution interval. + protected WorkerBase(String name, TimeSpan period) { + this.Name = name; + this._timeSpan = new AtomicTimeSpan(period); + + this.StateChangeRequests = new Dictionary(5) { + [StateChangeRequest.Start] = false, + [StateChangeRequest.Pause] = false, + [StateChangeRequest.Resume] = false, + [StateChangeRequest.Stop] = false, + }; + } + + /// + /// Enumerates all the different state change requests. + /// + protected enum StateChangeRequest { + /// + /// No state change request. + /// + None, + + /// + /// Start state change request + /// + Start, + + /// + /// Pause state change request + /// + Pause, + + /// + /// Resume state change request + /// + Resume, + + /// + /// Stop state change request + /// + Stop, + } + + /// + public String Name { + get; + } + + /// + public TimeSpan Period { + get => this._timeSpan.Value; + set => this._timeSpan.Value = value; + } + + /// + public WorkerState WorkerState { + get => this._workerState.Value; + protected set => this._workerState.Value = value; + } + + /// + public Boolean IsDisposed { + get => this._isDisposed.Value; + protected set => this._isDisposed.Value = value; + } + + /// + public Boolean IsDisposing { + get => this._isDisposing.Value; + protected set => this._isDisposing.Value = value; + } + + /// + /// Gets the default period of 15 milliseconds which is the default precision for timers. + /// + protected static TimeSpan DefaultPeriod { get; } = TimeSpan.FromMilliseconds(15); + + /// + /// Gets a value indicating whether stop has been requested. + /// This is useful to prevent more requests from being issued. + /// + protected Boolean IsStopRequested => this.StateChangeRequests[StateChangeRequest.Stop]; + + /// + /// Gets the cycle stopwatch. + /// + protected Stopwatch CycleStopwatch { get; } = new Stopwatch(); + + /// + /// Gets the state change requests. + /// + protected Dictionary StateChangeRequests { + get; + } + + /// + /// Gets the cycle completed event. + /// + protected ManualResetEventSlim CycleCompletedEvent { get; } = new ManualResetEventSlim(true); + + /// + /// Gets the state changed event. + /// + protected ManualResetEventSlim StateChangedEvent { get; } = new ManualResetEventSlim(true); + + /// + /// Gets the cycle logic cancellation owner. + /// + protected CancellationTokenOwner CycleCancellation { get; } = new CancellationTokenOwner(); + + /// + /// Gets or sets the state change task. + /// + protected Task? StateChangeTask { + get; set; + } + + /// + public abstract Task StartAsync(); + + /// + public abstract Task PauseAsync(); + + /// + public abstract Task ResumeAsync(); + + /// + public abstract Task StopAsync(); + + /// + public void Dispose() { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Releases unmanaged and - optionally - managed resources. + /// + /// true to release both managed and unmanaged resources; false to release only unmanaged resources. + protected virtual void Dispose(Boolean disposing) { + lock(this._syncLock) { + if(this.IsDisposed || this.IsDisposing) { + return; + } + + this.IsDisposing = true; + } + + // This also ensures the state change queue gets cleared + this.StopAsync().Wait(); + this.StateChangedEvent.Set(); + this.CycleCompletedEvent.Set(); + + this.OnDisposing(); + + this.CycleStopwatch.Stop(); + this.StateChangedEvent.Dispose(); + this.CycleCompletedEvent.Dispose(); + this.CycleCancellation.Dispose(); + + this.IsDisposed = true; + this.IsDisposing = false; + } + + /// + /// Handles the cycle logic exceptions. + /// + /// The exception that was thrown. + protected abstract void OnCycleException(Exception ex); + + /// + /// Represents the user defined logic to be executed on a single worker cycle. + /// Check the cancellation token continuously if you need responsive interrupts. + /// + /// The cancellation token. + protected abstract void ExecuteCycleLogic(CancellationToken cancellationToken); + + /// + /// This method is called automatically when is called. + /// Makes sure you release all resources within this call. + /// + protected abstract void OnDisposing(); + + /// + /// Called when a state change request is processed. + /// + /// The state before the change. + /// The new state. + protected virtual void OnStateChangeProcessed(WorkerState previousState, WorkerState newState) { + // placeholder + } + + /// + /// Computes the cycle delay. + /// + /// Initial state of the worker. + /// The number of milliseconds to delay for. + protected Int32 ComputeCycleDelay(WorkerState initialWorkerState) { + Int64 elapsedMillis = this.CycleStopwatch.ElapsedMilliseconds; + TimeSpan period = this.Period; + Double periodMillis = period.TotalMilliseconds; + Double delayMillis = periodMillis - elapsedMillis; + + return initialWorkerState == WorkerState.Paused || period == TimeSpan.MaxValue || delayMillis >= Int32.MaxValue ? Timeout.Infinite : elapsedMillis >= periodMillis ? 0 : Convert.ToInt32(Math.Floor(delayMillis)); + } + } } diff --git a/Swan/Threading/WorkerDelayProvider.cs b/Swan/Threading/WorkerDelayProvider.cs index 32e99d2..764e4d4 100644 --- a/Swan/Threading/WorkerDelayProvider.cs +++ b/Swan/Threading/WorkerDelayProvider.cs @@ -1,151 +1,146 @@ -namespace Swan.Threading -{ - using System; - using System.Diagnostics; - using System.Threading; - using System.Threading.Tasks; - +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Swan.Threading { + /// + /// Represents a class that implements delay logic for thread workers. + /// + public static class WorkerDelayProvider { /// - /// Represents a class that implements delay logic for thread workers. + /// Gets the default delay provider. /// - public static class WorkerDelayProvider - { - /// - /// Gets the default delay provider. - /// - public static IWorkerDelayProvider Default => TokenTimeout; - - /// - /// Provides a delay implementation which simply waits on the task and cancels on - /// the cancellation token. - /// - public static IWorkerDelayProvider Token => new TokenCancellableDelay(); - - /// - /// Provides a delay implementation which waits on the task and cancels on both, - /// the cancellation token and a wanted delay timeout. - /// - public static IWorkerDelayProvider TokenTimeout => new TokenTimeoutCancellableDelay(); - - /// - /// Provides a delay implementation which uses short sleep intervals of 5ms. - /// - public static IWorkerDelayProvider TokenSleep => new TokenSleepDelay(); - - /// - /// Provides a delay implementation which uses short delay intervals of 5ms and - /// a wait on the delay task in the final loop. - /// - public static IWorkerDelayProvider SteppedToken => new SteppedTokenDelay(); - - private class TokenCancellableDelay : IWorkerDelayProvider - { - public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) - { - if (wantedDelay == 0 || wantedDelay < -1) - return; - - // for wanted delays of less than 30ms it is not worth - // passing a timeout or a token as it only adds unnecessary - // overhead. - if (wantedDelay <= 30) - { - try { delayTask.Wait(token); } - catch { /* ignore */ } - return; - } - - // only wait on the cancellation token - // or until the task completes normally - try { delayTask.Wait(token); } - catch { /* ignore */ } - } - } - - private class TokenTimeoutCancellableDelay : IWorkerDelayProvider - { - public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) - { - if (wantedDelay == 0 || wantedDelay < -1) - return; - - // for wanted delays of less than 30ms it is not worth - // passing a timeout or a token as it only adds unnecessary - // overhead. - if (wantedDelay <= 30) - { - try { delayTask.Wait(token); } - catch { /* ignore */ } - return; - } - - try { delayTask.Wait(wantedDelay, token); } - catch { /* ignore */ } - } - } - - private class TokenSleepDelay : IWorkerDelayProvider - { - private readonly Stopwatch _elapsedWait = new Stopwatch(); - - public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) - { - _elapsedWait.Restart(); - - if (wantedDelay == 0 || wantedDelay < -1) - return; - - while (!token.IsCancellationRequested) - { - Thread.Sleep(5); - - if (wantedDelay != Timeout.Infinite && _elapsedWait.ElapsedMilliseconds >= wantedDelay) - break; - } - } - } - - private class SteppedTokenDelay : IWorkerDelayProvider - { - private const int StepMilliseconds = 15; - private readonly Stopwatch _elapsedWait = new Stopwatch(); - - public void ExecuteCycleDelay(int wantedDelay, Task delayTask, CancellationToken token) - { - _elapsedWait.Restart(); - - if (wantedDelay == 0 || wantedDelay < -1) - return; - - if (wantedDelay == Timeout.Infinite) - { - try { delayTask.Wait(wantedDelay, token); } - catch { /* Ignore cancelled tasks */ } - return; - } - - while (!token.IsCancellationRequested) - { - var remainingWaitTime = wantedDelay - Convert.ToInt32(_elapsedWait.ElapsedMilliseconds); - - // Exit for no remaining wait time - if (remainingWaitTime <= 0) - break; - - if (remainingWaitTime >= StepMilliseconds) - { - Task.Delay(StepMilliseconds, token).Wait(token); - } - else - { - try { delayTask.Wait(remainingWaitTime); } - catch { /* ignore cancellation of task exception */ } - } - - if (_elapsedWait.ElapsedMilliseconds >= wantedDelay) - break; - } - } - } - } + public static IWorkerDelayProvider Default => TokenTimeout; + + /// + /// Provides a delay implementation which simply waits on the task and cancels on + /// the cancellation token. + /// + public static IWorkerDelayProvider Token => new TokenCancellableDelay(); + + /// + /// Provides a delay implementation which waits on the task and cancels on both, + /// the cancellation token and a wanted delay timeout. + /// + public static IWorkerDelayProvider TokenTimeout => new TokenTimeoutCancellableDelay(); + + /// + /// Provides a delay implementation which uses short sleep intervals of 5ms. + /// + public static IWorkerDelayProvider TokenSleep => new TokenSleepDelay(); + + /// + /// Provides a delay implementation which uses short delay intervals of 5ms and + /// a wait on the delay task in the final loop. + /// + public static IWorkerDelayProvider SteppedToken => new SteppedTokenDelay(); + + private class TokenCancellableDelay : IWorkerDelayProvider { + public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) { + if(wantedDelay == 0 || wantedDelay < -1) { + return; + } + + // for wanted delays of less than 30ms it is not worth + // passing a timeout or a token as it only adds unnecessary + // overhead. + if(wantedDelay <= 30) { + try { + delayTask.Wait(token); + } catch { /* ignore */ } + return; + } + + // only wait on the cancellation token + // or until the task completes normally + try { + delayTask.Wait(token); + } catch { /* ignore */ } + } + } + + private class TokenTimeoutCancellableDelay : IWorkerDelayProvider { + public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) { + if(wantedDelay == 0 || wantedDelay < -1) { + return; + } + + // for wanted delays of less than 30ms it is not worth + // passing a timeout or a token as it only adds unnecessary + // overhead. + if(wantedDelay <= 30) { + try { + delayTask.Wait(token); + } catch { /* ignore */ } + return; + } + + try { + _ = delayTask.Wait(wantedDelay, token); + } catch { /* ignore */ } + } + } + + private class TokenSleepDelay : IWorkerDelayProvider { + private readonly Stopwatch _elapsedWait = new Stopwatch(); + + public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) { + this._elapsedWait.Restart(); + + if(wantedDelay == 0 || wantedDelay < -1) { + return; + } + + while(!token.IsCancellationRequested) { + Thread.Sleep(5); + + if(wantedDelay != Timeout.Infinite && this._elapsedWait.ElapsedMilliseconds >= wantedDelay) { + break; + } + } + } + } + + private class SteppedTokenDelay : IWorkerDelayProvider { + private const Int32 StepMilliseconds = 15; + private readonly Stopwatch _elapsedWait = new Stopwatch(); + + public void ExecuteCycleDelay(Int32 wantedDelay, Task delayTask, CancellationToken token) { + this._elapsedWait.Restart(); + + if(wantedDelay == 0 || wantedDelay < -1) { + return; + } + + if(wantedDelay == Timeout.Infinite) { + try { + _ = delayTask.Wait(wantedDelay, token); + } catch { /* Ignore cancelled tasks */ } + return; + } + + while(!token.IsCancellationRequested) { + Int32 remainingWaitTime = wantedDelay - Convert.ToInt32(this._elapsedWait.ElapsedMilliseconds); + + // Exit for no remaining wait time + if(remainingWaitTime <= 0) { + break; + } + + if(remainingWaitTime >= StepMilliseconds) { + Task.Delay(StepMilliseconds, token).Wait(token); + } else { + try { + _ = delayTask.Wait(remainingWaitTime); + } catch { /* ignore cancellation of task exception */ } + } + + if(this._elapsedWait.ElapsedMilliseconds >= wantedDelay) { + break; + } + } + } + } + } } diff --git a/Swan/ViewModelBase.cs b/Swan/ViewModelBase.cs index 7601871..15c0b2e 100644 --- a/Swan/ViewModelBase.cs +++ b/Swan/ViewModelBase.cs @@ -1,124 +1,118 @@ -using System.Collections.Concurrent; +using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.ComponentModel; using System.Linq; using System.Runtime.CompilerServices; using System.Threading.Tasks; -namespace Swan -{ +namespace Swan { + /// + /// A base class for implementing models that fire notifications when their properties change. + /// This class is ideal for implementing MVVM driven UIs. + /// + /// + public abstract class ViewModelBase : INotifyPropertyChanged { + private readonly ConcurrentDictionary _queuedNotifications = new ConcurrentDictionary(); + private readonly Boolean _useDeferredNotifications; + /// - /// A base class for implementing models that fire notifications when their properties change. - /// This class is ideal for implementing MVVM driven UIs. + /// Initializes a new instance of the class. /// - /// - public abstract class ViewModelBase : INotifyPropertyChanged - { - private readonly ConcurrentDictionary _queuedNotifications = new ConcurrentDictionary(); - private readonly bool _useDeferredNotifications; - - /// - /// Initializes a new instance of the class. - /// - /// Set to true to use deferred notifications in the background. - protected ViewModelBase(bool useDeferredNotifications) - { - _useDeferredNotifications = useDeferredNotifications; - } - - /// - /// Initializes a new instance of the class. - /// - protected ViewModelBase() - : this(false) - { - // placeholder - } - - /// - public event PropertyChangedEventHandler PropertyChanged; - - /// Checks if a property already matches a desired value. Sets the property and - /// notifies listeners only when necessary. - /// Type of the property. - /// Reference to a property with both getter and setter. - /// Desired value for the property. - /// Name of the property used to notify listeners. This - /// value is optional and can be provided automatically when invoked from compilers that - /// support CallerMemberName. - /// An array of property names to notify in addition to notifying the changes on the current property name. - /// True if the value was changed, false if the existing value matched the - /// desired value. - protected bool SetProperty(ref T storage, T value, [CallerMemberName] string propertyName = "", string[] notifyAlso = null) - { - if (EqualityComparer.Default.Equals(storage, value)) - return false; - - storage = value; - NotifyPropertyChanged(propertyName, notifyAlso); - return true; - } - - /// - /// Notifies one or more properties changed. - /// - /// The property names. - protected void NotifyPropertyChanged(params string[] propertyNames) => NotifyPropertyChanged(null, propertyNames); - - /// - /// Notifies one or more properties changed. - /// - /// The main property. - /// The auxiliary properties. - private void NotifyPropertyChanged(string mainProperty, string[] auxiliaryProperties) - { - // Queue property notification - if (string.IsNullOrWhiteSpace(mainProperty) == false) - _queuedNotifications[mainProperty] = true; - - // Set the state for notification properties - if (auxiliaryProperties != null) - { - foreach (var property in auxiliaryProperties) - { - if (string.IsNullOrWhiteSpace(property) == false) - _queuedNotifications[property] = true; - } - } - - // Depending on operation mode, either fire the notifications in the background - // or fire them immediately - if (_useDeferredNotifications) - Task.Run(NotifyQueuedProperties); - else - NotifyQueuedProperties(); - } - - /// - /// Notifies the queued properties and resets the property name to a non-queued stated. - /// - private void NotifyQueuedProperties() - { - // get a snapshot of property names. - var propertyNames = _queuedNotifications.Keys.ToArray(); - - // Iterate through the properties - foreach (var property in propertyNames) - { - // don't notify if we don't have a change - if (!_queuedNotifications[property]) continue; - - // notify and reset queued state to false - try { OnPropertyChanged(property); } - finally { _queuedNotifications[property] = false; } - } - } - - /// - /// Called when a property changes its backing value. - /// - /// Name of the property. - private void OnPropertyChanged(string propertyName) => - PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName ?? string.Empty)); - } + /// Set to true to use deferred notifications in the background. + protected ViewModelBase(Boolean useDeferredNotifications) => this._useDeferredNotifications = useDeferredNotifications; + + /// + /// Initializes a new instance of the class. + /// + protected ViewModelBase() : this(false) { + // placeholder + } + + /// + public event PropertyChangedEventHandler PropertyChanged; + + /// Checks if a property already matches a desired value. Sets the property and + /// notifies listeners only when necessary. + /// Type of the property. + /// Reference to a property with both getter and setter. + /// Desired value for the property. + /// Name of the property used to notify listeners. This + /// value is optional and can be provided automatically when invoked from compilers that + /// support CallerMemberName. + /// An array of property names to notify in addition to notifying the changes on the current property name. + /// True if the value was changed, false if the existing value matched the + /// desired value. + protected Boolean SetProperty(ref T storage, T value, [CallerMemberName] String propertyName = "", String[] notifyAlso = null) { + if(EqualityComparer.Default.Equals(storage, value)) { + return false; + } + + storage = value; + this.NotifyPropertyChanged(propertyName, notifyAlso); + return true; + } + + /// + /// Notifies one or more properties changed. + /// + /// The property names. + protected void NotifyPropertyChanged(params String[] propertyNames) => this.NotifyPropertyChanged(null, propertyNames); + + /// + /// Notifies one or more properties changed. + /// + /// The main property. + /// The auxiliary properties. + private void NotifyPropertyChanged(String mainProperty, String[] auxiliaryProperties) { + // Queue property notification + if(String.IsNullOrWhiteSpace(mainProperty) == false) { + this._queuedNotifications[mainProperty] = true; + } + + // Set the state for notification properties + if(auxiliaryProperties != null) { + foreach(String property in auxiliaryProperties) { + if(String.IsNullOrWhiteSpace(property) == false) { + this._queuedNotifications[property] = true; + } + } + } + + // Depending on operation mode, either fire the notifications in the background + // or fire them immediately + if(this._useDeferredNotifications) { + _ = Task.Run(this.NotifyQueuedProperties); + } else { + this.NotifyQueuedProperties(); + } + } + + /// + /// Notifies the queued properties and resets the property name to a non-queued stated. + /// + private void NotifyQueuedProperties() { + // get a snapshot of property names. + String[] propertyNames = this._queuedNotifications.Keys.ToArray(); + + // Iterate through the properties + foreach(String property in propertyNames) { + // don't notify if we don't have a change + if(!this._queuedNotifications[property]) { + continue; + } + + // notify and reset queued state to false + try { + this.OnPropertyChanged(property); + } finally { this._queuedNotifications[property] = false; } + } + } + + /// + /// Called when a property changes its backing value. + /// + /// Name of the property. + private void OnPropertyChanged(String propertyName) => PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName ?? String.Empty)); + } }