线程间数据传递之ThreadLocal、InheritableThreadLocal、TransmittableThreadLocal

发布时间 2023-04-03 22:23:24作者: loveletters

前言

在JAVA中线程之间传输数据的方式有多种,而本文旨在探讨ThreadLocal及其衍生类的使用场景。

使用场景

  • 业务系统的参数传递:在我们的业务系统中可能会用到许多公共参数,可能是用户的token信息,在我们链路中可能某一个方法需要用到它,那么我们又不想一层层的传递它。
  • 分布式系统要打通各个系统之间的调用链路,他们就需要在每一次系统全链路调用中进行染色,将每一次全链路调用的数据关联在展示平台展示,便于业务工程师后续分析系统调用,链路调用之间的染色、标识都离不开线程间数据传递。
  • spring中基于ThreadLocal来实现事务

ThreadLocal介绍

ThreadLocal又叫做线程本地变量是为每一个Thread创建的一个变量的副本,每个线程都可以在内部访问到这个副本。通过这种方式我们在一个线程的生命周期以内安全的访问这个变量,不用担心被其他线程所污染。这是它相对于全局变量所带来的优势。

ThreadLocal为我们提供了三个public方法

public void set(T value) { ... }
public T get() { ... }
public void remove() { ... }

使用方式如下:

public class ThreadLocalDemo {
    private ThreadLocal<Long> threadLocal = new ThreadLocal<>();

    private void fun1() {
        threadLocal.set(System.nanoTime());
        System.out.println("fun1:" + threadLocal.get());
        fun2();
    }

    private void fun2() {
        System.out.println("fun2:" + threadLocal.get());
        fun3();
    }

    private void fun3() {
        System.out.println("fun3:" + threadLocal.get());
        threadLocal.remove();
    }

    public static void main(String[] args) {
        ThreadLocalDemo demo = new ThreadLocalDemo();
        demo.fun1();
    }
}

结果如下:

我们在先创建了一个本地变量threadLocal 在fun1中set值,在其余的方法中我们并没有通过方法传递的方式显示的将值传递给其他方法,仅仅是通过threadLocal变量get的方式就可以获取到我们所需要的变量,从而实现了变量的跨方法的传递。可能你觉得这样写没什么好处,我不用threadLocal直接用个全局变量照样可以实现数据的传递。好玩的还在后面,我们可以改造一下fun1方法。

 private void fun1() throws InterruptedException {
        threadLocal.set(System.nanoTime());
        System.out.println("fun1:" + threadLocal.get());
        final Thread t1 = new Thread(() -> {
            System.out.println("t1:" + threadLocal.get());
        }, "t1");
        t1.start();
        t1.join();
        fun2();
    }

按照常理来想那么我们t1线程内也应该能获取到数据,但是结果大相径庭。

为什么会这样呢?

查看ThreadLocal的源码

首先它的set方法会首先去获取当前线程,然后再调用getMap方法去获取当前线程的ThradLocalMap,这个线程中的一个本地变量默认为null。

而这个ThreadLocalMap又是ThreadLocal中的一个静态内部类。回到上面的set方法,因为第一次此时的线程中的这个变量还未赋值,所以为null,于是调用createMap

查看构造方法,实际上跟hashmap内部结构类似,也是一个Entry对象真正持有我们存入的value

而这个Entry又是ThreadLocalMap中的一个静态内部类

我们再查看一下get方法

我们可以看到同样是先获取当前线程对象,然后再获取它所持有的ThreadLocalMap,然后根据threadLocal对象为key找到实际上持有数据的Entry

所以说我们可以看到实际上threadLocal对象只是作为了一个key而真正存储数据的是每个线程自身的thread内持有的一个ThreadLocalMap的对象,而我们的t1线程自然就不能获取到数据。

自此我们简单的介绍了ThreadLocal的用法及其get set的原理,但是还没完,ThreadLocal对象提供的一个remove方法是做什么的。为什么我们在结束的时候需要手动去调用remove方法呢?

结论先行:如果使用后不remove可能会有内存泄漏的风险!
我们再回头看看一下ThreadLocal在内存中整体的结构

我们可以看到在线程的栈中ThreadLocal对象引用着堆中的ThreadLocal对象,当前线程对象引用着堆内的线程对象,而我们从上面的源码得知在Thread对象里面持有着一个ThreadLocalMap对象而在这个Map对象里面又持有Entry对象,Entry的里面的key为ThreadLocal对象一个弱引用,而他的value就是我们需要存放的对象的强引用。因为key为弱引用所以当我们的threadlocal对象为null的时候,在下一次gc的时候那么堆内的threadlocal对象就会被回收。如果此时线程还在运行那么就会导致这个value还是被Entry强引用着无法被回收掉导致内存泄露。
Thread Ref -> Thread -> ThreaLocalMap -> Entry -> value
在直接new Thread 这种情况下随着线程的销毁那么这个value对象也还是会被回收,所以存在的风险不大。但是我们平时一般使用线程池,那么就会导致线程不会被销毁,那么就可能会存在内存泄漏的风险。所以我们在平时使用的时候用完了记得及时remove掉以防内存泄漏。

因此我们可以结合jdk7中的try with resource 在自动关闭中remove掉,防止我们忘记remove。

例如:这样就在UserContext中完全封装了ThreadLocal,外部代码在try (resource) {...}内部可以随时调用UserContext.currentUser()获取当前线程绑定的用户名

public class UserContext implements AutoCloseable {

    static final ThreadLocal<String> ctx = new ThreadLocal<>();

    public UserContext(String user) {
        ctx.set(user);
    }

    public static String currentUser() {
        return ctx.get();
    }

    @Override
    public void close() {
        ctx.remove();
    }
}


try (var ctx = new UserContext("Bob")) {
    // 可任意调用UserContext.currentUser():
    String currentUser = UserContext.currentUser();
} // 在此自动调用UserContext.close()方法释放ThreadLocal关联对象

InheritableThreadLocal介绍

在上文中我们介绍了ThreadLocal,它可以在同一个线程中传递数据,但是却无法通过父线程向子线程传递数据,所以我们在最开始的demo中t1线程获取不到数据。但是如果我们有这种诉求,希望父线程能够向子线程传递数据呢,那我们便可以用到InheritableThreadLocal。


public class ThreadLocalDemo {
   
    private ThreadLocal<Long> inheritablethreadlocal = new InheritableThreadLocal<>();

    private void fun1() throws InterruptedException {
        inheritablethreadlocal.set(System.nanoTime());
        System.out.println("fun1:" + inheritablethreadlocal.get());
        final Thread t1 = new Thread(() -> {
            System.out.println("t1:" + inheritablethreadlocal.get());
        }, "t1");
        t1.start();
        t1.join();
        fun2();
    }

    private void fun2() {
        System.out.println("fun2:" + inheritablethreadlocal.get());
        fun3();
    }

    private void fun3() {
        System.out.println("fun3:" + inheritablethreadlocal.get());
        inheritablethreadlocal.remove();
    }

    public static void main(String[] args) throws InterruptedException {
        ThreadLocalDemo demo = new ThreadLocalDemo();
        demo.fun1();

    }
}

查看运行结果:

我们可以看到在t1线程中能够正确的获取到结果。

查看源码,我们可以看到它继承于ThreadLocal,重写了其中的getMap createMap,并且提供了一个childValue方法。

回到demo,它的用法就如图所示,可以为我们在子父线程中传递数据。那么它是如何实现的呢!

我们同样从它的set方法开始

实际上它是基础于ThreadLocal所以它的set方法跟跟ThreadLocal的一样,但是它重写了getMap方法
所以这里getMap就跟之前的有所不同,这里返回给我们的是另一个变量

这个变量的默认值为null

所以会调用createMap方法

然后再调用get方法便可以获取到值

所以对于父线程来说它的存取值跟ThreadLocal一样的只是Thread内部持有的变量不同,那么对于子线程来说又是如何获取到正确的值的呢。

我们来对获取值的这段代码打上断点

通过debug可以知道在我们子线程getMap的时候就返回了threadLocalMap这个对象

那么这个对象是在什么时候被赋值的呢,那只可能在是new Thread的时候通过构造方法赋值的了。我们继续去看Thread的构造方法,我们发现在构造方法中调用了init方法,继续查看init方法可以找到

继续往下追

private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        if (name == null) {
            throw new NullPointerException("name cannot be null");
        }

        this.name = name;

        Thread parent = currentThread();
        SecurityManager security = System.getSecurityManager();
        if (g == null) {
            /* Determine if it's an applet or not */

            /* If there is a security manager, ask the security manager
               what to do. */
            if (security != null) {
                g = security.getThreadGroup();
            }

            /* If the security doesn't have a strong opinion of the matter
               use the parent thread group. */
            if (g == null) {
                g = parent.getThreadGroup();
            }
        }

        /* checkAccess regardless of whether or not threadgroup is
           explicitly passed in. */
        g.checkAccess();

        /*
         * Do we have the required permissions?
         */
        if (security != null) {
            if (isCCLOverridden(getClass())) {
                security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
            }
        }

        g.addUnstarted();

        this.group = g;
        this.daemon = parent.isDaemon();
        this.priority = parent.getPriority();
        if (security == null || isCCLOverridden(parent.getClass()))
            this.contextClassLoader = parent.getContextClassLoader();
        else
            this.contextClassLoader = parent.contextClassLoader;
        this.inheritedAccessControlContext =
                acc != null ? acc : AccessController.getContext();
        this.target = target;
        setPriority(priority);
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        /* Stash the specified stack size in case the VM cares */
        this.stackSize = stackSize;

        /* Set thread ID */
        tid = nextThreadID();
    }

这个parent 为当前线程

inheritThreadLocals 这个变量init方法传入的true,并且parent.inheritableThreadLocals在我们之前put的时候通过createMap的方式已经为它赋值过了所以也不可能为null,至此我们可以知道了在我们创建子线程的时候便复制了父线程的ThreadLocalMap,而这个Map我们之前已经说过了就是真正存储存放数据的地方,所以说我们在子线程中便能够拿到父线程所存放的数据。

自此我们已经知道了InheritableThreadLocal是怎么实现父线程数据传递给子线程的了,那么问题来了,它是线程安全的吗?

先看一个例子:

public class ThreadLocalDemo {

    private ThreadLocal<Person> inheritablethreadlocal = new InheritableThreadLocal<>();

    private void fun1() throws InterruptedException {
        final Person person = new Person();
        person.setName("张三");
        inheritablethreadlocal.set(person);
        System.out.println("fun1:" + inheritablethreadlocal.get());
        final Thread t1 = new Thread(() -> {
            Person p = (Person) inheritablethreadlocal.get();
            p.setName("李四");
            System.out.println("t1:" + inheritablethreadlocal.get());
        }, "t1");
        t1.start();
        t1.join();
        fun2();
    }

    private void fun2() {
        System.out.println("fun2:" + inheritablethreadlocal.get());
        fun3();
    }

    private void fun3() {
        System.out.println("fun3:" + inheritablethreadlocal.get());
        inheritablethreadlocal.remove();
    }

    public static void main(String[] args) throws InterruptedException {
        ThreadLocalDemo demo = new ThreadLocalDemo();
        demo.fun1();

    }
}

class Person{
    String name;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    @Override
    public String toString() {
        return "Person{" +
                "name='" + name + '\'' +
                '}';
    }

如果它是线程安全的那么,在t1线程中的修改应该不能够影响到其他线程中的值,我们看结果

显而易见它并不是线程安全的,为什么会这样呢?我们回到源码继续查看new Thread的时候子线程复制父线程的ThreadLocalMap的时候

我们进入这个createInheritedMap方法,里面是直接调用ThreadLocalMap的构造方法,继续查看构造方法。

这一步复制value的时候直接调用childValue方法

在默认方法里面直接返回了父线程的里面持有的对象,如果是值类型不会存在问题,如果是引用类型,因为是同一个内存地址,所以会导致线程不安全的问题。

知道了原因那么也好解决,那就是我们重写childValue方法,在里面做一个深拷贝就可以了。

 private ThreadLocal<Person> inheritablethreadlocal = new InheritableThreadLocal(){
        @Override
        protected Object childValue(Object parentValue) {
            Person p = (Person) parentValue;
            final Person child = new Person();
            child.setName(p.getName());
            return child;
        }
    };

看起来似乎是解决了问题,结果也很美好。那么进一步思考下,我们这是在父线程给子线程传值,如果是两个不相关的线程呢,比如说在线程池中会怎么样呢?

public class InheritableThreadLocalDemo {
    static ExecutorService executorService = new ThreadPoolExecutor(2, 2, 60, TimeUnit.SECONDS,new LinkedBlockingQueue<>(100));
    static ThreadLocal<Integer> inheritableThreadLocal = new InheritableThreadLocal<>();


    public static void main(String[] args) {
        System.out.println("主线程开始");

        for (int i = 0; i < 100; i++) {
            inheritableThreadLocal.set(i);
            System.out.println("主线程获取值:"+inheritableThreadLocal.get());
            executorService.execute(new RunnableDemo());
            inheritableThreadLocal.remove();

        }
    }

    private static class RunnableDemo implements Runnable{

        @Override
        public void run() {
            System.out.println("子线程获取值:"+inheritableThreadLocal.get());
        }
    }
}

按照我们上文的分析,那也结果也应该是子线程跟主线程都会输出从0~99,但是结果却大相径庭。

主线程的值确实是从0~99,但是子线程的取值却始终为0跟1.从上文的分析可以知道直接new Thread方式会拷贝父线程的ThreadLocalMap,但是我们这里是一个Runnable对象,它又是怎么获取到的ThreadLocalMap呢?
那回答这个问题我们首先得回顾一下线程池的工作原理。

我们新创建了一个Runnable对象放入线程池中通过execute方法执行,会首先判断线程池的核心线程数有没有达到最大,如果还没达到最大那么新启动一个work线程,如果达到最大,那么接着判断我们给的队列是否满了,如果还未满就入队,如果已经满了继续判断最大线程数是否达到最大,如果还未达到最大则继续新启动一个work线程,已经达到最大就执行拒绝策略。

回到代码中来,我们查看execute方法

先判断工作线程是否小于核心线程,如果小于就创建一个work线程

继续往下查看

查看Worker的构造方法

可以看到Worker中持有一个Runable对象以及一个Thread对象,我们传入的Runable对象复制给Worker中的一个成员变量。而下面的那个newThread方法中传入的this值得就是这个worker对象,通过这个方法把Runnable对象变成了一个Thread对象.

因为我们这里没有值得ThreadFactory,所以会使用到DefaultThreadFactory,newThread内会调用到Thread的构造方法,于是我们可以知道原来是new Worker对象的时候会去复制主线程的ThreadLocalMap.

这里会执行t.start()方法

也就是回去执行里面的worker对象的run方法

接着调用runWorker方法

这个task就是我们在worker中的firstTask也就是我们业务中传入的Runnable对象,这里有个while循环如果task不等于null直接执行,并且在执行完成后将task置为null。然后这个getTask方法为从队列中取出一个Runnable,而这个worker一直在这里存活着,所以在它内部持有的Thread对象的值还是最开始new Worker时候创建的,所以它内部的ThreadLocalMap也为最开始复制父线程的值,因为我们创建的线程池的核心线程数为2个,所以会创建2个Worker对象,即我们也只是复制了2次值,那么自然我们取值的时候也只有最开始复制的那2个值。

如果我们想在线程池等复用线程的组建中,使用ThreadLocal值的传递功能,来解决异步执行时上下文传递,那么应该如何处理呢?

TransmittableThreadLocal介绍

JDK的InheritableThreadLocal类可以完成父线程到子线程的值传递。但对于使用线程池等会池化复用线程的执行组件的情况,线程由线程池创建好,并且线程是池化起来反复使用的;这时父子线程关系的ThreadLocal值传递已经没有意义,应用需要的实际上是把 任务提交给线程池时的ThreadLocal值传递到 任务执行时。TransmittableThreadLocal类继承并加强InheritableThreadLocal类,解决上述的问题。
InheritableThreadLocal为阿里开源的一个组件,所以我们在使用的时候需要添加如下依赖:

<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>transmittable-thread-local</artifactId>
    <version>2.12.6</version>
</dependency>

示例如下:

public class TransmittableThreadLocalDemo {
    static ExecutorService executorService = new ThreadPoolExecutor(2, 2, 60, TimeUnit.SECONDS,new LinkedBlockingQueue<>(100));
    static ThreadLocal<Integer> inheritableThreadLocal = new TransmittableThreadLocal<>();


    public static void main(String[] args) {
        System.out.println("主线程开始");

        for (int i = 0; i < 100; i++) {
            inheritableThreadLocal.set(i);
            System.out.println("主线程获取值:"+inheritableThreadLocal.get());
            executorService.execute(TtlRunnable.get(new RunnableDemo()));
            inheritableThreadLocal.remove();

        }
    }

    private static class RunnableDemo implements Runnable{

        @Override
        public void run() {
            System.out.println("子线程获取值:"+inheritableThreadLocal.get());
        }
    }
}

结果如下:

这样我们就可以在子线程中正确的获取到想要的结果了。那么它究竟是如何实现的呢?我们回到源码。

我们可以看到TransmittableThreadLocal是继承于InheritableThreadLocal的,并且其中有一个全局的静态变量holder,用于存储使用 TransmittableThreadLocal set 的上下文。这个hodler本身是一个InheritableThreadLocal,其中存储的是一个WeakHashMap而且这个map的key为我们所存储的值,它的value始终为null。

    private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };
  • initialValue 方法会在 InheritableThreadLocal 创建时被调用,默认创建一个 WeakHashMap。
  • childValue 方法会在创建子线程时,Thread 调用 init 方法,会调用 - -- ThreadLocal.createInheritedMap(parent.inheritableThreadLocals),createInheritedMap 中会创建 ThreadLocalMap,ThreadLocalMap 的构造方法中会调用 childValue 方法

我们从set方法开始看看ttl是怎么存储数据的:

super.set方法

我们继续查看addThisToHolder方法

我们可以看到我们往这个对象的静态成员变量中存放了一个自己的引用,即在WeakHshMap中put了一个key为当前实例对象的引用,值为null。

而调用get方法则是

super.get

通过当前thread获取获取holder中的keySet,通过遍历keySet获取ttl,通过ttl委托ThreadLocal维护变量的值