TransmittableThreadLocal原理分析

发布时间 2023-06-08 15:46:22作者: JaxYoun

TransmittableThreadLocal原理分析

原文:https://www.cnblogs.com/sglx/p/16018266.html

一、简介

  TransmittableThreadLocal是由阿里开发的线程间变量传递工具包,解决了JDK中InheritableThreadLocal只能在【new Thread】这种有显式父子关系的线程间传递“线程本地变量”,而无法应用到向线程池提交任务的场景。可以应用来做进程内的调用链路追踪、异步线程间传递变量等用途,我们来了解一下原理。

二、InheritableThreadLocal原理

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    // set的时候调用的
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

InheritableThreadLocal重写了父类的三个方法,其中createMap最为关键,他是我们调用ThreadLocal的set方法时调用的。这里是new了一个ThreadLocalMap赋值给了Thread的inheritableThreadLocals变量,那么我就来看一下Thread的属性及方法。

class Thread implements Runnable {

    // 父线程使用
    ThreadLocal.ThreadLocalMap threadLocals = null;

    // 子线程使用
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
    
    // 构造器 创建线程
    public Thread(Runnable target) {
        init(null, target, "Thread-" + nextThreadNum(), 0);
    }

}

当我们使用ThreadLocal的时候,是赋值给threadLocals属性,使用InheritableThreadLocal就是把值又赋给了线程的inheritableThreadLocals属性,那么,可以猜测就是在我们new Thread()的时候触发了线程的值传递,下面通过源码验证猜想

private void init(ThreadGroup g, Runnable target, String name, long stackSize, AccessControlContext acc, boolean inheritThreadLocals) {
  ···

    // 获取父线程
    Thread parent = currentThread();

  ···

    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
      // 当inheritThreadLocals为true,并且父线程的inheritableThreadLocals不为null
      // 将父线程的inheritableThreadLocals赋值给子线程的inheritableThreadLocals
      this.inheritableThreadLocals =
      ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
  ···
}

这样的方式解决了新建线程时的 ThreadLocal 传值问题,但实践中不可能一直新建线程,通用做法是线程复用,比如线程池。但是,递交异步任务时相应的ThreadLocal的值就无法传递过去了。我们希望的是,异步线程执行任务所使用的 ThreadLocal 值,是从提交者线程持有的那份,即从任务提交者传递到任务执行者。

想想,

三、RunnableWrapper & CallableWrapper

如果我们在创建异步任务时,在任务执行代码之外,获取当前线程的变量值并临时保存,再传递给执行线程,在真正的任务执行前保存到当前线程即可。对,确实可以,但是麻烦,每个创建异步任务的地方都要写。那就把这块逻辑封装到递交任务的方法中。

假设按照服务上下文的场景举例,目前项目中的执行异步操作的方案是定义一个 AsyncExecutor ,并声明执行 Supplier 返回 CompletableFuture 的方法。既然这样就可以对方法做一些改造,保证上下文的传递。

private static ThreadLocal<String> contextHolder = new ThreadLocal<>();

public static <T> CompletableFuture<T> invokeToCompletableFuture(Supplier<T> supplier, String errorMessage) {
    // 第一步
    String context = contextHolder.get();
    Supplier<T> newSupplier = () -> {
         // 第二步
        String origin = contextHolder.get();
        try {
            contextHolder.set(context);
            // 第三步
            return supplier.get();
        } finally {
            // 第四步
            contextHolder.set(origin);
            log.info(origin);
        }
    };
    return CompletableFuture.supplyAsync(newSupplier).exceptionally(e -> {
        throw new ServerErrorException(errorMessage, e);
    });
}

// test code
public static void main(String[] args) throws ExecutionException, InterruptedException {
    contextHolder.set("main");
    log.info(contextHolder.get());
    CompletableFuture<String> context = invokeToCompletableFuture(() -> test.contextHolder.get(), "error");
    log.info(context.get());
}

总得来说,就是在将异步任务派发给线程池时,对其做一下上下文传递的处理。

第一步:主线程获取上下文,传递给任务暂存。之后的操作都将是异步执行线程操作的。

第二步:异步执行线程将原有上下文取出,暂时保存。并将主线程传递过来的上下文设置。

第三步:执行异步任务。

第四步:将原有上下文设置回去。

可以看到一般并不会在异步线程执行完任务之后直接进行remove 。而是一开始取出原上下文(可能为NULL,也可能是线程创建时 InheritableThreadLocal继承过来的值,当然最终会被清除的),并在任务执行结束重新放回。这样的方式可以说是异步ThreadLocal传递的标准范式。

这样子既起到了显式清除主线程带来的上下文,也避免了如果线程池的拒绝策略为CallerRunsPolicy ,后续处理时上下文丢失的问题。

Supplier不算是典型例子,更为典型的应该是Runnable和Callable。

四、TransmittableThreadLocal

  TransmittableThreadLocal继承自InheritableThreadLocal,因此它可以在创建线程的时候将值传递给子线程,那么怎么确保使用线程池的时候也有效呢?我们来看一下源码

  1. 构造方法
// 构造器
public TransmittableThreadLocal() {
  this(false);
}

public TransmittableThreadLocal(boolean disableIgnoreNullValueSemantics) {
  // 是否忽略null值set,默认false
  this.disableIgnoreNullValueSemantics = disableIgnoreNullValueSemantics;
}
  1. set方法
public final void set(T value) {
  if (!disableIgnoreNullValueSemantics && null == value) {
    // 不忽略null写入,则移除本地线程变量
    remove();
  } else {
    // 调用父类InheritableThreadLocal的set方法
    super.set(value);
    // 将自己添加到静态线程变量holder中
    addThisToHolder();
  }
}

先看addThisToHolder方法

private void addThisToHolder() {
  // 判断holder是否存在此TransmittableThreadLocal对象
  if (!holder.get().containsKey(this)) {
    // 不存则添加进holder
    holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
  }
}

属性holder又是什么呢?

private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
  new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
    @Override
    protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
      return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    }

    @Override
    protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
      return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
    }
  
};

1.final static修饰的变量,只会存在一份
2.使用了WeakHashMap,弱引用,方便垃圾回收
3.key就是TransmittableThreadLocal对象

remove方法

public final void remove() {
  // 从holder中移除
  removeThisFromHolder();
  // 调用父类的移除方法,移除值
  super.remove();
}
  1. get方法
public final T get() {
  // 调用父类的get
  T value = super.get();
  // 如果允许忽略null,或者value不为null,再次添加到holder
  if (disableIgnoreNullValueSemantics || null != value) {
    addThisToHolder();
  }
  return value;
}
  1. 当我们使用线程池时,需要使用TtlRunnable.get(runnable)对runnable进行包装,或者使用TtlExecutors.getTtlExecutor(executor)对执行器进行包装,才能使线程池的变量传递起效果,那么我们就接着看一下源码的执行流程
    TtlExecutors.getTtlExecutor(executor)
public static Executor getTtlExecutor(@Nullable Executor executor) {
  if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
    return executor;
  }
  // 包装执行器
  return new ExecutorTtlWrapper(executor, true);
}

ExecutorTtlWrapper(@NonNull Executor executor, boolean idempotent) { //idempotent是否需要保证幂等
  this.executor = executor;
  this.idempotent = idempotent;
}

public void execute(@NonNull Runnable command) {
  // 实际上也是通过TtlRunnable对原runnable进行包装
  executor.execute(TtlRunnable.get(command, false, idempotent));
}

可以看到,两种方式原理一样,我们直接看TtlRunnable.get()

public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
  if (null == runnable) {
    return null;
  }

  if (runnable instanceof TtlEnhanced) {
    if (idempotent) {
      return (TtlRunnable) runnable;
    } else {
      throw new IllegalStateException("Already TtlRunnable!"); 
    }
  }
  
  // 返回TtlRunnable
  return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

构建TtlRunnable

private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
  // 原子引用
  this.capturedRef = new AtomicReference<Object>(capture());
  this.runnable = runnable;
  this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

capture捕获父线程的ttl

// 存放父线程的值
public static Object capture() {
  return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
  HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
  // 遍历了所有holder
  for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
    // copyValue实际上调用了TransmittableThreadLocal的get方法获取线程存储的变量值
    ttl2Value.put(threadLocal, threadLocal.copyValue());
  }
  return ttl2Value;
}

private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
  final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<ThreadLocal<Object>, Object>();
  // 
  for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
    final ThreadLocal<Object> threadLocal = entry.getKey();
    final TtlCopier<Object> copier = entry.getValue();
    //
    threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
  }
  return threadLocal2Value;
}

再看TtlRunnable的run方法

public void run() {
  // 获取Snapshot对象,里面存储了父线程的值
  final Object captured = capturedRef.get();
  if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
    throw new IllegalStateException("TTL value reference is released after run!");
  }
  
  // 传入capture方法捕获的ttl,然后在子线程重放,也就是调用ttl的set方法,
  // 这样就会把值设置到当前的线程中去,最后会把子线程之前存在的ttl返回
  final Object backup = replay(captured);
  try {
    // 调用原runnable的run
    runnable.run();
  } finally {
    // 
    restore(backup);
  }
  
}