谈谈对TransmittableThreadLocal的理解

发布时间 2023-08-14 09:36:31作者: 甜菜波波

前言

最近遇到一个问题,公司内部有一个公共的SSO包,用来获取HTTP请求中的登录态,代码中会直接用这个包的方法获取用户登录信息,在代码任意位置直接用SSOUtil.getUser()获取用户信息,在我们一个下载的业务代码中,用到了线程池开启子任务处理请求,结果发现子任务中拿到的用户信息和HTTP请求主线程中的不一致,导致了一些业务问题。

出问题的交互流程如下:

最后在SSO包的源码中排查到了问题所在,SSO包用到了一个TransmittableThreadLocal(本文统一简称TTL)来存储用户信息到当前线程。本文的目的是探究一下TTL的原理,在这之前会先回顾一下ThreadLocal和InheritableThreadLocal的实现原理。

ThreadLocal

ThreadLocal是Java的一个类,顾名思义,“线程本地”变量,用于保存线程私有的变量。ThreadLocal有一个内部类叫ThreadLocalMap,ThreadLocalMap底层是数组,数组中存放多个Entry对象,这个对象的Key是一个指向ThreadLocal变量的WeakReference,Value是当前线程该ThreadLocal变量的值。对于每一个Java线程,在JVM中对应一个Thread对象,每个Thread对象里面持有一个ThreadLocalMap,它们之间的关系用下图表示更加清楚。

ThreadLocal的核心类ThreadLocalMap部分源码如下

static class ThreadLocalMap {

/**
* The entries in this hash map extend WeakReference, using
* its main ref field as the key (which is always a
* ThreadLocal object). Note that null keys (i.e. entry.get()
* == null) mean that the key is no longer referenced, so the
* entry can be expunged from table. Such entries are referred to
* as "stale entries" in the code that follows.
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

private static final int INITIAL_CAPACITY = 16;

/**
* The table, resized as necessary.
* table.length MUST always be a power of two.
*/
private Entry[] table;

/**
* The number of entries in the table.
*/
private int size = 0;

/**
* The next size value at which to resize.
*/
private int threshold; // Default to 0


}

InheritableThreadLocal

InheritableThreadLocal也是Java的一个类,它是ThreadLocal的子类,它的出现是为了解决ThreadLocal在线程间传递过程中丢失的问题,上面我们知道Thread中维护了一个Map,用来存放当前线程本地变量的值,但是开启新线程,这个变量就失效了。Thread类还有另一个ThreadLocalMap即inheritableThreadLocals,下面是Thread.init()方法中的一段代码,当线程创建时,会从主线程复制inheritableThreadLocals到子线程,完成父子线程本地变量的传递。

if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
/* Stash the specified stack size in case the VM cares */
this.stackSize = stackSize;

ITL的实现方法也非常巧妙,它继承了ThreadLocal,重写了getMap方法,在ThreadLocal设置值的时候,会通过getMap来拿到ThreadLocalMap,通过重写拿到了Thread类中的inheritableThreadLocals,从而实现了把ITL的值都通过另一个Map来存放。

//InheritableThreadLocal代码
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
protected T childValue(T parentValue) {
return parentValue;
}
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}

//ThreadLocal的set方法
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

TransmittableThreadLocal

ITL类可以完成父线程到子线程的值传递。但对于使用线程池等会池化复用线程的执行组件的情况,线程由线程池创建好,并且线程是池化起来反复使用的;这时父子线程关系的ThreadLocal值传递已经没有意义,应用需要的实际上是把任务提交给线程池时的ThreadLocal值传递到任务执行时。

TransmittableThreadLocal(TTL)的出现是解决这个问题的,TTL是Alibaba开源的一个类,TTL继承了ITL,使用方式也类似。相比InheritableThreadLocal,添加了copy方法用于定制任务提交给线程池时的ThreadLocal值传递到任务执行时的拷贝行为,缺省传递的是引用。注意:如果跨线程传递了对象引用因为不再有线程封闭,与InheritableThreadLocal.childValue一样,使用者/业务逻辑要注意传递对象的线程安全。

它使用TtlRunnable和TtlCallable来修饰传入线程池的Runnable和Callable。

示例代码:

TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();

// =====================================================

// 在父线程中设置
context.set("value-set-in-parent");

Runnable task = new RunnableTask();
// 额外的处理,生成修饰了的对象ttlRunnable
Runnable ttlRunnable = TtlRunnable.get(task);
executorService.submit(ttlRunnable);

// =====================================================

// Task中可以读取,值是"value-set-in-parent"
String value = context.get();

注意:即使是同一个Runnable任务多次提交到线程池时,每次提交时都需要通过修饰操作(即TtlRunnable.get(task))以抓取这次提交时的TransmittableThreadLocal上下文的值;即如果同一个任务下一次提交时不执行修饰而仍然使用上一次的TtlRunnable,则提交的任务运行时会是之前修饰操作所抓取的上下文。示例代码如下:

// 第一次提交
Runnable task = new RunnableTask();
executorService.submit(TtlRunnable.get(task));

// ...业务逻辑代码,
// 并且修改了 TransmittableThreadLocal上下文 ...
// context.set("value-modified-in-parent");

// 再次提交
// 重新执行修饰,以传递修改了的 TransmittableThreadLocal上下文
executorService.submit(TtlRunnable.get(task));

上述用法对Callable也是类似的。

下面我们通过源码看下TTL是怎么实现池化线程间传递的。看之前可以思考一下,在ITL中,其实是做到了新起子线程时,复制ITL。池化的线程做不到,是因为复用线程场景没有这个触发的时机了,那么TTL一样需要这样的一个触发时机,只不过不是ITL中的Thread.init(),通过上面的用法示例,我们知道这个触发时机实际上就是TtlRunnable.get(),我们可以直接看下get()做了哪些事

//Step1:一层简单包装->调用构造方法创建TtlRunnable
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
if (null == runnable) return null;

if (runnable instanceof TtlEnhanced) {
// avoid redundant decoration, and ensure idempotency
if (idempotent) return (TtlRunnable) runnable;
else throw new IllegalStateException("Already TtlRunnable!");
}
return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

//Step2:创建对象时capture()方法复制了父线程TTL的值,这个值通过holder来维护(set时会调用addThisToHolder()将TTL值设置进holder)
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
this.capturedRef = new AtomicReference<Object>(capture());
this.runnable = runnable;
this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

//Step3:执行Run方法,先对当前线程中的TTL进行备份,然后通过replay方法将父线程的TTL添加进当前线程,最后在finnaly代码中对之前的TTL进行恢复
@Override
public void run() {
final Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after run!");
}

final Object backup = replay(captured);
try {
runnable.run();
} finally {
restore(backup);
}
}

通过上面的源码可以清楚的了解TTL实现的原理。首先是获得父线程的TTL,然后将子线程的TTLMap进行备份,接着将父线程的TTL循环复制进子线程,最后在子线程执行完runnable.run()以后,将子线程的TTLMap还原。官方推荐的用法中,可以直接包装线程池,原理是类似的,在新线程池的run之前,执行TTLRunnable.get(),这样的用法对业务代码侵入更小,比较推荐。

ExecutorService executorService = ...
// 额外的处理,生成修饰了的对象executorService
executorService = TtlExecutors.getTtlExecutorService(executorService);

回到前言中的问题,SSO的包里面已经引入了TTL,但是我们的只用到了get()、set()方法,其实作用和ITL还是一样的。因此在池化的线程中间,之前设置过用户信息存放在TTL,线程退出后也无法清除,后续其他用户用到了这个线程,获取的还是之前用户的信息,导致了业务异常。

我们可以用TTLRunnable来包装一下业务的Runnable,问题就解决了。

Reference

https://github.com/alibaba/transmittable-thread-local