Java ThreadLocal

发布时间 2023-03-25 23:28:55作者: Dazzling!

ThreadLocal的功能在Java多线程并发环境中非常实用,其作用是提供线程本地变量,例如用户ID、会话ID等与当前线程密切关联的信息。那么它在实际业务场景中可以怎么使用呢?让我们一起来看看下边这个案例。

有一台 Web 服务器,需要设计一个组件,用于记录每次请求完整执行的耗时时长,整体流程如下所示:

img

在这种场景中,假设我们通过构建一张哈希表,然后表的 key 是线程的 id,valud 是对应要存储的内容,那么这种问题就会很好解决了。其实 JDK 里面就正好有一个组件是这么实现的,下边我们来看看如何通过线程本地变量技术来实现这一效果。

下边这个案例中,我写了一个基于 SpringBoot 的 Web 应用,并且在一个自定义的拦截器中定义了一个 ThreadLocal 变量,每次请求的开始阶段,就会将当前时间戳放入到这个 ThreadLocal 变量中。当执行结束之后,就会从 ThreadLocal 变量中提取之前存入的时间数据,然后根据当前时间戳来统计实际请求的耗时。核心的拦截器部分的代码如下:

package 并发编程08.线程本地变量.config;

import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
*  @Author idea
*  @Date  created in 3:22 下午 2022/6/26
*/
@Configuration
public class TimeCountInterceptor implements HandlerInterceptor {

    static class CommonThreadLocal<Long> extends ThreadLocal{
        @Override
        protected Object initialValue() {
            return -1;
        }
    }
    private static ThreadLocal<Long> timeCount = new CommonThreadLocal<>();

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
        System.out.println("提前赋值的获取:"+timeCount.get());
        //中间写逻辑代码,比如判断是否登录成功,失败则返回false
        timeCount.set(System.currentTimeMillis());
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) {
        long currentTime = System.currentTimeMillis();
        long startTime = timeCount.get();
        long timeUse = currentTime - startTime;
        System.out.println(Thread.currentThread().getName() + "耗时:" + timeUse + "ms");
        timeCount.remove();
    }


}

这部分的测试,我们可以通过新建一个SpringBoot应用,然后创建一个接口,并且通过请求借口来查看控制台的打印结果进行验证测试。 例如加入一个Controller和对应的接口:

@RestController
@RequestMapping(value = "/test")
public class TestController {

    @GetMapping(value = "/do-test")
    public String doTest() throws InterruptedException {
        System.out.println("start do test");
        Thread.sleep(2000);
        System.out.println("end do test");
        return "success";
    }
}

然后通过curl请求这个接口之后,查看控制台的内容输出,请求指令如下:

curl 'localhost:8080/test/do-test'

通过控制台的输出,我们可以发现,每个不同的请求的耗时都被记录了下来,具体如下图所示:

image.png

相信好奇的读者会发现,为什么每次请求看起来都是在往同一个 ThreadLocal 中存储数据,但是每次取出计算的结果都是各自独立且不影响的呢?要想知道这一原因,就需要我们深入去了解 ThreadLocal 背后的技术原理了。

ThreadLocal 的理解

ThreadLocal 对象中提供了线程局部变量,它使得每个线程之间互不干扰,一般可以通过重写它的 initialValue 方法机械能赋值。当线程第一次访问它的时候,就会直接触发 initialValue 方法。

在 ThreadLocal 类中,我们最常用的函数应该是 set 和 get,以及 remove ,如果希望深入了解 ThreadLocal 的底层原理的话,建议大家从这几个函数作为切入点去分析。那么下边让我们一起来深入探讨这几个函数的实现原理。

set函数分析

public void set(T value) {
    //获取当前请求的线程    
    Thread t = Thread.currentThread();
    //取出Thread类内部的threadLocals变量,这个变量是一个哈希表结构
    ThreadLocalMap map = getMap(t);
    if (map != null)
        //将需要存储的值放入到这个哈希表中
        map.set(this, value);
    else
        createMap(t, value);
}

这个案例的 createMap,是一个将本地线程变量放入到一个 map 集合中的函数,具体实现部分如下:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

而 ThreadLocalMap 是一个自定义的哈希表,在每个 Thread 类的内部,都会存有一份叫做 ThreadLocalMap 的变量实例,而 key 就是直接对应的 threadLocal 对象,value 是各个线程往这个 threadLocal 中存放的数值,其内部结构如下图所示:

img

threadLocalMap 正是存在于线程的 tcb 当中,而各个 tcb 之间的内存是独立的模块,因此说 threadLocal 在多线程环境使用的情况下才不会有线程安全的问题。

这里我们说到的 threadLocalMap 其实本质是一个哈希表,但是它在设计原理上和我们熟知的 HashMap 其实是有一定出入的。下边我通过一张图来解释下它们两者在存储结构上的差别。

img

上图中展示了两个不同的 map 在存储上的差异,可以很明显地看出,ThreadLocalMap 存储的时候,每个数组下标至多只会存储一个 Entry 元素,每个 Entry 在插入的时候如果出现了 Hash 冲突,则不会采用链地址法去进行规避,而是会采用扩容的思路来解决冲突问题,这一点和 HashMap 是有所出入的。

通过阅读 Entry 的源代码,我们可以更加深刻的认识到这一点,下边是 ThreadLocalMap 类中定义的静态 Entry 的源代码:

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

通过阅读 Entry 的源代码,我们看到该对象继承了一个 WeakReference,在 JVM 进行 GC 回收的时候,WeakReference 对象是可以被自动注销的,这也就意味着,ThreadLocalMap 中的某些 Entry 对象的 key 可能会是一个 null 值,这一点在上边的结构对比图中也有画出。所以当我们每次使用完 threadLocal 之后,都需要手动调用下它的 remove 接口,将 Entry 中 key 为 null 的对象给清除掉,否则就容易产生内存泄漏的问题。

那么除了手动调用remove操作之外,threadlocal自己有什么机制去清理掉过期数据嘛?

其实,如果你有仔细阅读过 threadlocal 在放入数据 (set) 和提取数据 (get) 的两个函数的实现源代码,就可以看到,在set里面会有对 ThreadlocalMap 中的空 Entry 进行移除(这里的空 Entry 是指 Entry 的 value 值为空,下边一律简称空 Entry)的逻辑,具体细节在下边的源代码中我给指了出来:

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }
        //如果遇到了entry的key为空的情况 ,会将它清空,然后放入当前value  
        if (k == null) {
        //
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    //这里的cleanSomeSlots函数底层也涉及到了回收空entry的逻辑
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        //在迭代过程中,可能会有些过期数据被清空了,所以需要重新调整下各个slot的位置
        //新写入的数据需要被rehash一下。
        rehash();
}

//清理value为空的Entry对象
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            //释放空间
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

除了 set 操作之外,在 get 函数内部也有涉及到空 Entry 的处理逻辑,具体逻辑我们来看看底层的实现原理:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        //getEntry函数里面涉及到了关于空Entry对象的回收逻辑
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}


private Entry getEntry(ThreadLocal<?> key) {
    //首先是通过threadlocal的hashcode和entry数组的长度进行&计算
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    //如果对应位置有值,则直接返还
    if (e != null && e.get() == key)
        return e;
    else
        //深入查看这个函数的底层实现逻辑
        return getEntryAfterMiss(key, i, e);
}


private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
        
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            //释放无用的Entry空间
            expungeStaleEntry(i);
        else
            //渐进式便利每个元素
            i = nextIndex(i, len);
        //对entry数组的各个元素逐个逐个判断    
        e = tab[i];
    }
    return null;
}

在 get 方法中,通过 threadlocal 的 hash 值去按位做与运算,由于threadlocalmap的key是一个没有next指针的entry对象,所以在查询元素的时候,如果计算得到的key不是目标对象,那么就会继续跟着哈希表去做遍历,挨个查询,这一点方面,在解决hash冲突问题上和hashmap有所出入。

那么,如果在挨个查找的过程中有碰到过期数据该怎么处理?

这里就涉及到了 threadlocal 的一种设计策略了,探测式清理过期数据。这种策略会对着 entry 数组挨个遍历,然后判断 map 的 key 是否为 null,如果为 null,就将 entry 的 value 也置为 null,并且回收,同时 table 的长度也会变化。这一个过程在上边的 getEntryAfterMiss 函数中有所体现。

在清理了过期数据之后,如果还有新数据放入 threadlocalmap 的话,就会触发一个叫做 rehash 的函数,将一些无用的 entry 位置给收拢起来,减少内存占用。

通过对 Threadlocal 的原理分析,我们可以看出,JDK 的作者在设计这个类的时候,其实是花费了很多的心思。在 ThreadLocal 内部专门设计了一个叫做 ThreadLocalMap 的数据结构,这个 map 在实现细节中会自动触发一些空数据的回收机制,并且在 hashcode 的分布定位,以及扩容方面都有做相关的调整。

小结

  • ThreadLocal 底层在设计的时候专门设计了一个 ThreadLocalMap 用于管理各个线程的数据存储。

  • ThreadLocalMap 它在内部的每个 Entry 定义的时候都实现了一个弱引用接口,方便于 GC 回收空间。

  • 在每次访问 ThreadLocalMap 的时候,都会采用一种探测式清理过期数据的方式回收内存。

DoubleAdder、LongAdder 两个类。相比于传统的 AtomicLong 原子类来说,这两个类在进行并发处理的时候,其实本质是通过了采用数组的方式去降低并发修改的程度。将单一 value 的更新压力分担到多个 value 中去,降低单个 value 的 “热度”,分段更新!具体原理如下图所示:

image.png

我们以 LongAdder 为例,在它的底层内部存有一个数组,当 LongAddr 初始化的时候,会将初始值放入到 Cells[0] 的位置上,当同时有两个线程尝试修改 LongAddr 变量的时候,它们会对 Cells 最新 null 值的位置进行 CAS 操作,尝试修改该下标位置的状态,如果修改成功了,那就可以将上一个数组位的值作为基础值进行累加,如果说某个线程希望获取当前的 LongAddr 数值,那么只需要将所有数组的值累加起来即可。