利用AspectInjector实现AOP篡改方法返回值

发布时间：2023-08-06 19:30:11　作者：一克猫

AspectInjector

AspectInjector 是一个开源的轻量级 AOP 框架，能够满足大多数场景。但由于该框架对异步方法的注入不是很方便，故在此记录解决方案。

封装通用基类

/// <summary>
/// Base class for AspectInjector "around" aspects. Every intercepted call is routed through the
/// <see cref="BaseUniversalWrapperAttribute"/> triggers applied to the method, using
/// <c>WrapSync&lt;T&gt;</c> for ordinary methods and <c>WrapAsync&lt;T&gt;</c> for Task-returning methods.
/// </summary>
/// <remarks>
/// The wrapper chain is compiled into a delegate once per (method, return type, aspect type) and cached.
/// </remarks>
public abstract class BaseUniversalWrapperAspect
    {
        private delegate object Handler(Func<object[], object> next, object[] args, AspectEventArgs eventArgs);

        // The aspect type is part of the key: an aspect only receives its own triggers and the compiled
        // handler embeds those trigger instances, so different aspects on one method must not share a handler.
        private static readonly ConcurrentDictionary<(MethodBase, Type, Type), Lazy<Handler>> _delegateCache =
            new ConcurrentDictionary<(MethodBase, Type, Type), Lazy<Handler>>();

        private static readonly MethodInfo _asyncGenericHandler =
            typeof(BaseUniversalWrapperAttribute).GetMethod(nameof(BaseUniversalWrapperAttribute.WrapAsync), BindingFlags.NonPublic | BindingFlags.Instance);

        private static readonly MethodInfo _syncGenericHandler =
            typeof(BaseUniversalWrapperAttribute).GetMethod(nameof(BaseUniversalWrapperAttribute.WrapSync), BindingFlags.NonPublic | BindingFlags.Instance);

        private static readonly MethodInfo _voidTaskAdapter =
            typeof(BaseUniversalWrapperAspect).GetMethod(nameof(AdaptVoidTask), BindingFlags.NonPublic | BindingFlags.Static);

        // Internal runtime type used as the result type for non-generic Task methods; it lives next to Task.
        private static readonly Type _voidTaskResult = typeof(Task).Assembly.GetType("System.Threading.Tasks.VoidTaskResult");

        /// <summary>
        /// Builds the event args for one intercepted call and runs it through the cached wrapper chain.
        /// </summary>
        /// <param name="instance">Instance the call was made on, as supplied by the advice.</param>
        /// <param name="type">Type supplied by the advice.</param>
        /// <param name="method">Intercepted method; part of the handler cache key.</param>
        /// <param name="target">Innermost callable that performs the actual work.</param>
        /// <param name="name">Name exposed to wrappers via <see cref="AspectEventArgs.Name"/>.</param>
        /// <param name="args">Arguments forwarded to <paramref name="target"/>.</param>
        /// <param name="returnType">Declared return type; selects the sync or async wrapper path.</param>
        /// <param name="triggers">Trigger attributes; those deriving from <see cref="BaseUniversalWrapperAttribute"/> become wrappers.</param>
        /// <returns>The (possibly wrapped) result, boxed as <see cref="object"/>.</returns>
        protected object BaseHandle(
            object instance,
            Type type,
            MethodBase method,
            Func<object[], object> target,
            string name,
            object[] args,
            Type returnType,
            Attribute[] triggers)
        {
            var eventArgs = new AspectEventArgs
            {
                Instance = instance,
                Type = type,
                Method = method,
                Name = name,
                Args = args,
                ReturnType = returnType,
                Triggers = triggers
            };

            var handler = GetMethodHandler(method, returnType, triggers);
            return handler(target, args, eventArgs);
        }

        /// <summary>
        /// Compiles a delegate that calls the target through every wrapper. The loop wraps inside-out,
        /// so the last wrapper in <paramref name="wrappers"/> is the outermost (first to run).
        /// </summary>
        private Handler CreateMethodHandler(Type returnType, IReadOnlyList<BaseUniversalWrapperAttribute> wrappers)
        {
            var targetParam = Expression.Parameter(typeof(Func<object[], object>), "orig");
            var eventArgsParam = Expression.Parameter(typeof(AspectEventArgs), "event");
            var innerArgs = Expression.Parameter(typeof(object[]), "args");
            Expression invokeTarget = Expression.Invoke(targetParam, innerArgs);

            MethodInfo wrapperMethod;
            Expression innerBody;

            if (typeof(Task).IsAssignableFrom(returnType))
            {
                if (returnType.IsConstructedGenericType)
                {
                    var resultType = returnType.GenericTypeArguments[0];
                    wrapperMethod = _asyncGenericHandler.MakeGenericMethod(resultType);
                    innerBody = Expression.Convert(invokeTarget, typeof(Task<>).MakeGenericType(resultType));
                }
                else
                {
                    if (_voidTaskResult == null)
                        throw new NotSupportedException("System.Threading.Tasks.VoidTaskResult is not available on this runtime.");

                    wrapperMethod = _asyncGenericHandler.MakeGenericMethod(_voidTaskResult);
                    // A plain Task (e.g. Task.CompletedTask, Task.WhenAll) is not a Task<VoidTaskResult>,
                    // so adapt it rather than cast it.
                    innerBody = Expression.Call(_voidTaskAdapter.MakeGenericMethod(_voidTaskResult), invokeTarget);
                }
            }
            else
            {
                // void methods flow through the chain as object (the target returns null for them).
                var resultType = returnType == typeof(void) ? typeof(object) : returnType;
                wrapperMethod = _syncGenericHandler.MakeGenericMethod(resultType);
                innerBody = Expression.Convert(invokeTarget, resultType);
            }

            LambdaExpression next = Expression.Lambda(innerBody, innerArgs);

            foreach (var wrapper in wrappers)
            {
                var argsParam = Expression.Parameter(typeof(object[]), "args");
                next = Expression.Lambda(Expression.Call(Expression.Constant(wrapper), wrapperMethod, next, argsParam, eventArgsParam), argsParam);
            }

            var origArgs = Expression.Parameter(typeof(object[]), "orig_args");
            var handler = Expression.Lambda<Handler>(Expression.Convert(Expression.Invoke(next, origArgs), typeof(object)), targetParam, origArgs, eventArgsParam);

            return handler.Compile();
        }

        /// <summary>
        /// Returns the cached handler for this method/return type/aspect, compiling it on first use.
        /// </summary>
        private Handler GetMethodHandler(MethodBase method, Type returnType, Attribute[] triggers)
        {
            var key = (method, returnType, GetType());

            // Fast path: once cached, no wrapper array or factory closure is allocated per call.
            if (!_delegateCache.TryGetValue(key, out var lazyHandler))
            {
                var wrappers = triggers.OfType<BaseUniversalWrapperAttribute>().ToArray();
                lazyHandler = _delegateCache.GetOrAdd(key, new Lazy<Handler>(() => CreateMethodHandler(returnType, wrappers)));
            }

            return lazyHandler.Value;
        }

        /// <summary>
        /// Converts the result of a non-generic Task target into the <see cref="Task{TResult}"/> shape
        /// that <c>WrapAsync&lt;T&gt;</c> expects.
        /// </summary>
        private static Task<T> AdaptVoidTask<T>(object result)
        {
            switch (result)
            {
                case Task<T> typed:
                    // Tasks produced by async methods already have this shape.
                    return typed;
                case Task plain:
                    return AwaitPlainTaskAsync<T>(plain);
                default:
                    // null passes through; any other value fails with the same cast error as before.
                    return (Task<T>)result;
            }
        }

        /// <summary>Awaits a plain task and completes with <c>default(T)</c>, propagating faults and cancellation.</summary>
        private static async Task<T> AwaitPlainTaskAsync<T>(Task task)
        {
            await task.ConfigureAwait(false);
            return default(T);
        }
    }
/// <summary>
/// Base type for attributes that take part in a wrapper chain around an intercepted method.
/// The default implementations do nothing beyond invoking the next step of the chain.
/// </summary>
public abstract class BaseUniversalWrapperAttribute : Attribute
    {
        /// <summary>Wraps a synchronous call; by default just forwards <paramref name="args"/> to <paramref name="target"/>.</summary>
        protected internal virtual T WrapSync<T>(Func<object[], T> target, object[] args, AspectEventArgs eventArgs) =>
            target.Invoke(args);

        /// <summary>Wraps a Task-returning call; by default returns the task produced by <paramref name="target"/> unchanged.</summary>
        protected internal virtual Task<T> WrapAsync<T>(Func<object[], Task<T>> target, object[] args, AspectEventArgs eventArgs) =>
            target.Invoke(args);
    }
/// <summary>
/// Wrapper attribute that exposes simple hook points around the intercepted call:
/// <see cref="OnBefore"/>, <see cref="OnAfter"/> and <see cref="OnException{T}"/>.
/// </summary>
public abstract class BaseMethodPointsAspectAttribute : BaseUniversalWrapperAttribute
    {
        /// <summary>Runs the hooks around a synchronous call.</summary>
        protected internal sealed override T WrapSync<T>(Func<object[], T> target, object[] args, AspectEventArgs eventArgs)
        {
            OnBefore(eventArgs);

            try
            {
                var result = base.WrapSync(target, args, eventArgs);
                // OnAfter sits inside the try, so an exception it throws is also routed to OnException.
                OnAfter(eventArgs);

                return result;
            }
            catch (Exception exception)
            {
                return OnException<T>(eventArgs, exception);
            }
        }

        /// <summary>Runs the hooks around an asynchronous call; OnAfter runs once the awaited task completes.</summary>
        protected internal sealed override async Task<T> WrapAsync<T>(Func<object[], Task<T>> target, object[] args, AspectEventArgs eventArgs)
        {
            OnBefore(eventArgs);

            try
            {
                var result = await target(args);
                OnAfter(eventArgs);

                return result;
            }
            catch (Exception exception)
            {
                return OnException<T>(eventArgs, exception);
            }
        }

        /// <summary>Called before the wrapped call. Default: no-op.</summary>
        protected virtual void OnBefore(AspectEventArgs eventArgs)
        {
        }

        /// <summary>Called after the wrapped call completes successfully. Default: no-op.</summary>
        protected virtual void OnAfter(AspectEventArgs eventArgs)
        {
        }

        /// <summary>
        /// Called when the wrapped call (or <see cref="OnAfter"/>) throws. Overrides may return a
        /// replacement result; the default rethrows the exception.
        /// </summary>
        protected virtual T OnException<T>(AspectEventArgs eventArgs, Exception exception)
        {
            // Rethrow preserving the original stack trace; `throw exception;` would reset it to this frame.
            System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(exception).Throw();
            return default(T); // unreachable: Throw() never returns
        }
    }
/// <summary>
/// Context for a single intercepted call, filled in by the aspect and passed along the wrapper chain.
/// Setters are internal, so only code in the declaring assembly can populate it.
/// </summary>
public class AspectEventArgs : EventArgs
    {
        /// <summary>Instance the call was made on, as supplied by the advice (presumably null for static methods — confirm).</summary>
        public object Instance { get; internal set; }

        /// <summary>Type value supplied by the advice for the intercepted call.</summary>
        public Type Type { get; internal set; }

        /// <summary>Metadata of the intercepted method.</summary>
        public MethodBase Method { get; internal set; }

        /// <summary>Name reported for the call; an aspect may substitute a different name.</summary>
        public string Name { get; internal set; }

        /// <summary>Call arguments; may be null when the aspect passes no argument array.</summary>
        public IReadOnlyList<object> Args { get; internal set; }

        /// <summary>Declared return type of the intercepted method.</summary>
        public Type ReturnType { get; internal set; }

        /// <summary>Trigger attributes that caused the advice to run.</summary>
        public Attribute[] Triggers { get; internal set; }
    }

自定义切面类

用于限制参数长度：传入的参数不符合条件时，方法的返回值将被修改为错误信息。

 /// <summary>
 /// Enforces <see cref="CommandLengthLimitAttribute"/>. When the first argument is a List&lt;string&gt;
 /// whose count is not one of the allowed lengths, the real method is not called and an error
 /// result is returned instead.
 /// </summary>
 [Aspect(Scope.Global)]
    public class CommandLengthLimitAspect : BaseUniversalWrapperAspect
    {
        private const string ErrorMessage = "错啦~";

        /// <summary>Around advice: validates the command list, then runs either the real target or the error substitute.</summary>
        [Advice(Kind.Around, Targets = Target.Method)]
        public object Handle(
            [Argument(Source.Instance)] object instance,
            [Argument(Source.Type)] Type type,
            [Argument(Source.Metadata)] MethodBase method,
            [Argument(Source.Target)] Func<object[], object> target,
            [Argument(Source.Name)] string name,
            [Argument(Source.Arguments)] object[] args,
            [Argument(Source.ReturnType)] Type returnType,
            [Argument(Source.Triggers)] Attribute[] triggers)
        {
            var trigger = triggers.OfType<CommandLengthLimitAttribute>().First();

            // Only the first argument is checked. Methods with no parameters, or whose first argument
            // is not a List<string>, pass through unchanged instead of throwing IndexOutOfRangeException.
            if (args is { Length: > 0 } && args[0] is List<string> parameters
                && !trigger.LimitLengths.Contains(parameters.Count))
            {
                // Swap in a target that produces the error result without calling the real method.
                Func<object[], object> func = _ => GetError(returnType);
                return BaseHandle(instance, type, method, func, nameof(GetError), null, returnType, triggers);
            }

            return BaseHandle(instance, type, method, target, name, args, returnType, triggers);
        }

        /// <summary>
        /// Builds the substitute result. For <see cref="string"/>-returning methods this is the message
        /// itself; otherwise it is a completed Task&lt;string&gt;, which matches Task&lt;string&gt; methods.
        /// </summary>
        private static object GetError(Type returnType)
        {
            if (returnType == typeof(string))
            {
                return ErrorMessage;
            }

            return Task.FromResult(ErrorMessage);
        }
    }

自定义特性

/// <summary>
/// Marks a method whose first argument (a List&lt;string&gt;) must contain one of the given numbers
/// of elements. The check is performed by <see cref="CommandLengthLimitAspect"/>.
/// </summary>
[Injection(typeof(CommandLengthLimitAspect), Inherited = true)]
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
    public class CommandLengthLimitAttribute : BaseMethodPointsAspectAttribute
    {
        /// <summary>Allowed element counts, exactly as passed to the constructor.</summary>
        public int[] LimitLengths { get; }

        /// <summary>Creates the attribute.</summary>
        /// <param name="limitLengths">Allowed element counts for the command list.</param>
        public CommandLengthLimitAttribute(params int[] limitLengths)
        {
            LimitLengths = limitLengths;
        }
    }

使用示例

当 commands 的长度不为 1 时，Todo 方法返回的 string 结果将被篡改为错误信息。

/// <summary>
/// Usage example: when <paramref name="commands"/> does not contain exactly one element,
/// the length-limit aspect replaces the returned string with an error message.
/// </summary>
[CommandLengthLimit(1)]
public async Task<string> Todo(List<string> commands)
{
    // todo..
    // Placeholder so the example compiles: an async Task<string> method must return or throw on every path.
    throw new NotImplementedException();
}