Java-并发编程-04常见工具类及使用

发布时间 2023-12-18 23:58:42作者: VayneBeSelf

四、常用并发工具类

在JDK的并发包里提供了几个非常有用的并发容器和并发工具类。供我们在多线程开发中进行使用。

ConcurrentHashMap

在集合类中HashMap是比较常用的集合对象,但是HashMap是线程不安全的(多线程环境下可能会存在问题)。为了保证数据的安全性我们可以使用Hashtable,但是Hashtable的效率低下。

基于以上两个原因我们可以使用JDK1.5以后所提供的ConcurrentHashMap。

点击查看代码
package com.vayne.concurrent;

import java.util.Date;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author vayne
 * @date 2023-11-03
 */


public class ConcurrentHM {
    public static void main(String[] args) throws InterruptedException {
//        HashMap<String, String> map = new HashMap<>();
//        Hashtable<String, String> map = new Hashtable<>();
        ConcurrentHashMap<String, String> map = new ConcurrentHashMap<>();
        new Thread(() -> {
            
            for (int i = 0; i < 500; i++) {
                map.put(String.valueOf(i),String.valueOf(i));
            }

        }).start();
        new Thread(() -> {
            for (int i = 500; i < 1000; i++) {
                map.put(String.valueOf(i),String.valueOf(i));
            }
        }).start();

        Thread.sleep(2000);

        for(int x = 1 ; x < 1000 ; x++) {
            // HashMap中的键就是当前循环变量的x这个数据的字符串表现形式 , 根据键找到值,然后在进行判断
            if( !String.valueOf(x).equals( map.get(String.valueOf(x)) ) ) {
                System.out.println(x + ":" + map.get(String.valueOf(x)));
            }
        }

    }
}

通过控制台的输出结果,我们可以看到在多线程操作HashMap的时候,可能会出现线程安全问题。

Hashtable保证线程安全的原理:

部分源码:

点击查看代码
public class Hashtable<K,V> extends Dictionary<K,V> implements Map<K,V>, Cloneable, java.io.Serializable {
    
    // Entry数组,一个Entry就相当于一个元素
    private transient Entry<?,?>[] table;
    
    // Entry类的定义
    private static class Entry<K,V> implements Map.Entry<K,V> {
        final int hash;		// 当前key的hash码值
        final K key;		// 键
        V value;			// 值
        Entry<K,V> next;	// 下一个节点
    }
    
    // 存储数据
      public synchronized V put(K key, V value){...}
    
    // 获取数据
    public synchronized V get(Object key){...}
    
    // 获取长度
    public synchronized int size(){...}
    
    ...
    
}

Hashtable保证线程安全性的是使用方法全局锁进行实现的。在线程竞争激烈的情况下HashTable的效率非常低下。因为当一个线程访问HashTable的同步方法,其他线程也访问HashTable的同步方法时,会进入阻塞状态。如线程1使用put进行元素添加,线程2不但不能使用put方法添加元素,也不能使用get方法来获取元素,所以竞争越激烈效率越低。

ConcurrentHash保证线程安全的原理:

jdk1.7版本:

点击查看ConcurrentHash的重要成员变量
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V> implements ConcurrentMap<K, V>, Serializable {
    
    /**
     * Segment翻译中文为"段" , 段数组对象
     */
    final Segment<K,V>[] segments;
    
    // Segment是一种可重入锁(ReentrantLock),在ConcurrentHashMap里扮演锁的角色,将一个大的table分割成多个小的table进行加锁。
    static final class Segment<K,V> extends ReentrantLock implements Serializable {
        
        transient volatile int count;    			// Segment中元素的数量,由volatile修饰,支持内存可见性;
        transient int modCount;			 			// 对table的大小造成影响的操作的数量(比如put或者remove操作);
        transient int threshold;		 			// 扩容阈值;
        transient volatile HashEntry<K,V>[] table;  // 链表数组,数组中的每一个元素代表了一个链表的头部;
        final float loadFactor;			 			// 负载因子 
        
    }
    
    // Segment中的元素是以HashEntry的形式存放在数组中的,其结构与普通HashMap的HashEntry基本一致,不同的是Segment的HashEntry,其value由		     // volatile修饰,以支持内存可见性,即写操作对其他读线程即时可见。
    static final class HashEntry<K,V> {
        final int hash;					// 当前节点key对应的哈希码值
        final K key;					// 存储键
        volatile V value;				// 存储值
        volatile HashEntry<K,V> next;	// 下一个节点
    }
    
}

image

简单来讲,就是ConcurrentHashMap比HashMap多了一次hash过程,第1次hash定位到Segment,第2次hash定位到HashEntry,然后链表搜索找到指定节点。

在进行写操作时,只需锁住写元素所在的Segment即可(这种锁被称为分段锁),其他Segment无需加锁,从而产生锁竞争的概率大大减小,提高了并发读写的效率。

该种实现方式的缺点是hash过程比普通的HashMap要长(因为需要进行两次hash操作)。

点击put方法实现代码
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V> implements ConcurrentMap<K, V>, Serializable { 
    
    public V put(K key, V value) {
        
        // 定义一个Segment对象
        Segment<K,V> s;
        
        // 如果value的值为空,那么抛出异常
        if (value == null) throw new NullPointerException();
        
        // hash函数获取key的hashCode,然后做了一些处理
        int hash = hash(key);
        
        // 通过key的hashCode定位segment
        int j = (hash >>> segmentShift) & segmentMask;
        
        // 对定位的Segment进行判断,如果Segment为空,调用ensureSegment进行初始化操作(第一次hash定位)
        if ((s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) == null) 
            s = ensureSegment(j);
        
        // 调用Segment对象的put方法添加元素
        return s.put(key, hash, value, false);
    }
    
    // Segment是一种可ReentrantLock,在ConcurrentHashMap里扮演锁的角色,将一个大的table分割成多个小的table进行加锁。
    static final class Segment<K,V> extends ReentrantLock implements Serializable {
        
        // 添加元素
        final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            
            // 尝试对该段进行加锁,如果加锁失败,则调用scanAndLockForPut方法;在该方法中就要进行再次尝试或者进行自旋等待
            HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                
                // 获取HashEntry数组对象
                HashEntry<K,V>[] tab = table;
                
                // 根据key的hashCode值计算索引(第二次hash定位)
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> first = entryAt(tab, index);
                for (HashEntry<K,V> e = first;;) 
                    
                    // 若不为null
                    if (e != null) {
                        K k;
                        
                        // 判读当前节点的key是否和链表头节点的key相同(依赖于hashCode方法和equals方法) 
                        // 如果相同,值进行更新
                        if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        
                        e = e.next;
                    } else {  // 若头结点为null
                        
                        // 将新节点添加到链表中
                        if (node != null) 
                            node.setNext(first);
                        else
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        
                        // 如果超过阈值,则进行rehash操作
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            rehash(node);
                        else
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                unlock();
            }
        
            return oldValue;
        } 	
        
    }
    
}

** jdk1.8版本:**

在JDK1.8中为了进一步优化ConcurrentHashMap的性能,去掉了Segment分段锁的设计。在数据结构方面,则是跟HashMap一样,使用一个哈希表table数组。(数组 + 链表 + 红黑树) 。

而线程安全方面是结合CAS机制 + 局部锁实现的,减低锁的粒度,提高性能。同时在HashMap的基础上,对哈希表table数组和链表节点的value,next指针等使用volatile来修饰,从而实现线程可见性。

ConcurrentHashMap中的重要成员变量:

点击重要成员变量代码

public class ConcurrentHashMap<K,V> extends AbstractMap<K,V> implements ConcurrentMap<K,V>, Serializable {

    // Node数组
    transient volatile Node<K,V>[] table;

    // Node类的定义
    static class Node<K,V> implements Map.Entry<K,V> { 

        final int hash;				// 当前key的hashCode值
        final K key;				// 键
        volatile V val;				// 值
        volatile Node<K,V> next;	// 下一个节点

    }

    // TreeNode类的定义
    static final class TreeNode<K,V> extends Node<K,V> {
        TreeNode<K,V> parent;  // 父节点
        TreeNode<K,V> left;	   // 左子节点
        TreeNode<K,V> right;   // 右子节点
        TreeNode<K,V> prev;    // needed to unlink next upon deletion
        boolean red;		   // 节点的颜色状态
    }

}

image

点击put方法代码
public class ConcurrentHashMap<K,V> extends AbstractMap<K,V> implements ConcurrentMap<K,V>, Serializable {
    
    // 添加元素
    public V put(K key, V value) {
    	return putVal(key, value, false);
	}
    
    // putVal方法定义
    final V putVal(K key, V value, boolean onlyIfAbsent) {
        
        // key为null直接抛出异常
        if (key == null || value == null) throw new NullPointerException();
        
        // 计算key所对应的hashCode值
        int hash = spread(key.hashCode());
        int binCount = 0;
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            
            // 哈希表如果不存在,那么此时初始化哈希表
            if (tab == null || (n = tab.length) == 0)
                tab = initTable();
            
            // 通过hash值计算key在table表中的索引,将其值赋值给变量i,然后根据索引找到对应的Node,如果Node为null,做出处理
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                
                // 新增链表头结点,cas方式添加到哈希表table
                if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value, null))) break;                   
            }
            else if ((fh = f.hash) == MOVED)
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
                
                // f为链表头结点,使用synchronized加锁
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) {
                            binCount = 1;
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                                
                                // 节点已经存在,更新value即可
                                if (e.hash == hash && ((ek = e.key) == key || (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                
                                // 该key对应的节点不存在,则新增节点并添加到该链表的末尾
                                Node<K,V> pred = e;
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key, value, null);
                                    break;
                                }
                                
                            }
                            
                        } else if (f instanceof TreeBin) { // 红黑树节点,则往该红黑树更新或添加该节点即可
                            Node<K,V> p;
                            binCount = 2;
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key, value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                
                // 判断是否需要将链表转为红黑树
                if (binCount != 0) {
                    if (binCount >= TREEIFY_THRESHOLD)
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        addCount(1L, binCount);
        return null;
    }
    
    // CAS算法的核心类
    private static final sun.misc.Unsafe U;
    static {
        try {
            U = sun.misc.Unsafe.getUnsafe();
            ...
        } catch (Exception e) {
            throw new Error(e);
        }
    }
    
    // 原子获取链表节点
    static final <K,V> Node<K,V> tabAt(Node<K,V>[] tab, int i) {
        return (Node<K,V>)U.getObjectVolatile(tab, ((long)i << ASHIFT) + ABASE);
    }
    
    // CAS更新或新增链表节点
    static final <K,V> boolean casTabAt(Node<K,V>[] tab, int i, Node<K,V> c, Node<K,V> v) {
        return U.compareAndSwapObject(tab, ((long)i << ASHIFT) + ABASE, c, v);
    }
    
}
  1. 如果当前需要put的key对应的链表在哈希表table中还不存在,即还没添加过该key的hash值对应的链表,则调用casTabAt方法,基于CAS机制来实现添加该链表头结点到哈希表table中,避免该线程在添加该链表头结的时候,其他线程也在添加的并发问题;如果CAS失败,则进行自旋,通过继续第2步的操作;

  2. 如果需要添加的链表已经存在哈希表table中,则通过tabAt方法,基于volatile机制,获取当前最新的链表头结点f,由于f指向的是ConcurrentHashMap的哈希表table的某条链表的头结点,故虽然f是临时变量,由于是引用共享的该链表头结点,所以可以使用synchronized关键字来同步多个线程对该链表的访问。在synchronized(f)同步块里面则是与HashMap一样遍历该链表,如果该key对应的链表节点已经存在,则更新,否则在链表的末尾新增该key对应的链表节点。

CountDownLatch

CountDownLatch允许一个或多个线程等待其他线程完成操作以后,再执行当前线程;比如我们在主线程需要开启2个其他线程,当其他的线程执行完毕以后我们再去执行主线程,针对这个需求我们就可以使用CountDownLatch来进行实现。CountDownLatch中count down是倒着数数的意思;CountDownLatch是通过一个计数器来实现的,每当一个线程完成了自己的任务后,可以调用countDown()方法让计数器-1,当计数器到达0时,调用CountDownLatch的await()方法的线程阻塞状态解除,继续执行。

CountDownLatch的相关方法

public CountDownLatch(int count)						// 初始化一个指定计数器的CountDownLatch对象
public void await() throws InterruptedException			// 让当前线程等待
public void countDown()									// 计数器进行减1
点击查看代码
package com.vayne.countdownlatch;

import java.util.concurrent.CountDownLatch;

public class CountDownLatchThread01 implements Runnable {

    // CountDownLatch类型成员变量
    private CountDownLatch countDownLatch ;
    public CountDownLatchThread01(CountDownLatch countDownLatch) {      // 构造方法的作用:接收CountDownLatch对象
        this.countDownLatch = countDownLatch ;
    }

    @Override
    public void run() {

        try {
            Thread.sleep(10000);
            System.out.println("10秒以后执行了CountDownLatchThread01......");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        // 调用CountDownLatch对象的countDown方法对计数器进行-1操作
        countDownLatch.countDown();

    }

}
----------------------------------------------------------------------
package com.vayne.countdownlatch;

import java.util.concurrent.CountDownLatch;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class CountDownLatchThread02 implements Runnable{

    private CountDownLatch countDownLatch ;
    public CountDownLatchThread02(CountDownLatch countDownLatch){
        this.countDownLatch = countDownLatch;
    }
    @Override
    public void run() {
        try {
            Thread.sleep(3000);
            System.out.println("3秒后执行了CountDownLatchThread02");
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        countDownLatch.countDown();
    }
}
------------------------------------------------------------------
package com.vayne.countdownlatch;

import java.util.concurrent.CountDownLatch;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class Test {
    public static void main(String[] args) throws InterruptedException {
        CountDownLatch countDownLatch = new CountDownLatch(2);
        CountDownLatchThread01 countDownLatchThread01 = new CountDownLatchThread01(countDownLatch);
        CountDownLatchThread02 countDownLatchThread02 = new CountDownLatchThread02(countDownLatch);

        Thread thread = new Thread(countDownLatchThread01);
        Thread thread1 = new Thread(countDownLatchThread02);

        thread1.start();
        thread.start();

        countDownLatch.await();
        System.out.println("main线程执行了");
    }
}

CountDownLatchThread02线程先执行完毕,此时计数器-1;CountDownLatchThread01线程执行完毕,此时计数器-1;当计数器的值为0的时候,主线程阻塞状态接触,主线程向下执行。

CyclicBarrier

CyclicBarrier的字面意思是可循环使用(Cyclic)的屏障(Barrier)。它要做的事情是,让一组线程到达一个屏障(也可以叫同步点)时被阻塞,直到最后一个线程到达屏障时,屏障才会开门,所有被屏障拦截的线程才会继续运行。

例如:公司召集5名员工开会,等5名员工都到了,会议开始。我们创建5个员工线程,1个开会线程,几乎同时启动,使用CyclicBarrier保证5名员工线程全部执行后,再执行开会线程。

CyclicBarrier的相关方法:

public CyclicBarrier(int parties, Runnable barrierAction)   // 用于在线程到达屏障时,优先执行barrierAction,方便处理更复杂的业务场景
public int await()											// 每个线程调用await方法告诉CyclicBarrier我已经到达了屏障,然后当前线程被阻塞

案例演示:模拟员工开会

实现步骤:

  1. 创建一个员工线程类(EmployeeThread),该线程类中需要定义一个CyclicBarrier类型的形式参数
  2. 创建一个开会线程类(MettingThread)
  3. 测试类
    1. 创建CyclicBarrier对象
    2. 创建5个EmployeeThread线程对象,把第一步创建的CyclicBarrier对象作为构造方法参数传递过来
    3. 启动5个员工线程
点击案例代码
package com.vayne.cyclicbarrier;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class EmployeeThread implements Runnable{
    private CyclicBarrier cyclicBarrier;
    private int time;
    public EmployeeThread(CyclicBarrier cyclicBarrier,int time) {
        this.cyclicBarrier = cyclicBarrier;
        this.time = time;
    }
    @Override
    public void run() {
        try {
            Thread.sleep(time*1000);
            System.out.println(Thread.currentThread().getName() + "在第" + time + "秒到达了会议室等待开会");
            cyclicBarrier.await();
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }
}
----------------------------------------------------------------
package com.vayne.cyclicbarrier;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class MettingThread implements Runnable{
    @Override
    public void run() {
        System.out.println("所有人都已经到了,我们开始开会吧!!");
    }
}
--------------------------------------------------------------
package com.vayne.cyclicbarrier;

import java.util.concurrent.CyclicBarrier;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class Test {
    public static void main(String[] args) {

        CyclicBarrier cyclicBarrier = new CyclicBarrier(5,new MettingThread());
        for (int i = 0; i < 5; i++) {
            EmployeeThread employeeThread = new EmployeeThread(cyclicBarrier, i);
            new Thread(employeeThread).start();
        }
    }
}

Semaphore

Semaphore字面意思是信号量的意思,它的作用是控制访问特定资源的线程数目。

举例:现在有一个十字路口,有多辆汽车需要进经过这个十字路口,但是我们规定同时只能有两辆汽车经过。其他汽车处于等待状态,只要某一个汽车经过了这个十字路口,其他的汽车才可以经过,但是同时只能有两个汽车经过。如何限定经过这个十字路口车辆数目呢? 我们就可以使用Semaphore。

Semaphore的常用方法

public Semaphore(int permits)						permits 表示许可线程的数量
public void acquire() throws InterruptedException	表示获取许可
public void release()								表示释放许可

案例演示:模拟汽车通过十字路口

实现步骤:

  1. 创建一个汽车的线程任务类(CarThreadRunnable),在该类中定义一个Semaphore类型的成员变量
  2. 创建测试类
    1. 创建线程任务类对象
    2. 创建5个线程对象,并启动。(5个线程对象,相当于5辆汽车)
点击查看代码
package com.vayne.semaphore;

import java.util.concurrent.Semaphore;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class CarThreadRunnable implements Runnable{
    private Semaphore semaphore;

    public CarThreadRunnable(Semaphore semaphore){
        this.semaphore = semaphore;
    }
    @Override
    public void run() {
        try {
            semaphore.acquire();
            System.out.println(Thread.currentThread().getName()+"进入了十字路口");
            Thread.sleep(3000);
            System.out.println(Thread.currentThread().getName()+"驶离了十字路口");
            semaphore.release();

        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }
}
-----------------------------------------------------------------
package com.vayne.semaphore;

import java.util.concurrent.Semaphore;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class Test {
    public static void main(String[] args) {
        Semaphore semaphore = new Semaphore(2);
        for (int i = 0; i < 5; i++) {
            CarThreadRunnable carThreadRunnable = new CarThreadRunnable(semaphore);
            new Thread(carThreadRunnable).start();
        }
    }
}

Exchanger

Exchanger(交换者)是一个用于线程间协作的工具类。Exchanger用于进行线程间的数据交换。

举例:比如男女双方结婚的时候,需要进行交换结婚戒指。

Exchanger常用方法

public Exchanger()							// 构造方法
public V exchange(V x)						// 进行交换数据的方法,参数x表示本方数据 ,返回值v表示对方数据

这两个线程通过exchange方法交换数据,如果第一个线程先执行exchange()方法,它会一直等待第二个线程也执行exchange方法,当两个线程都到达同步点时,这两个线程就可以交换数据,

将本线程生产出来的数据传递给对方。

案例演示:模拟交互结婚戒指

实现步骤:

  1. 创建一个男方的线程类(ManThread),定义一个Exchanger类型的成员变量
  2. 创建一个女方的线程类(WomanThread),定义一个Exchanger类型的成员变量
  3. 测试类
    1. 创建一个Exchanger对象
    2. 创建一个ManThread对象,把第一步所创建的Exchanger作为构造方法参数传递过来
    3. 创建一个WomanThread对象,把第一步所创建的Exchanger作为构造方法参数传递过来
    4. 启动两个线程
点击查看代码
package com.vayne.exchange;

import java.util.concurrent.Exchanger;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class ManThread extends Thread {
    private Exchanger exchanger;

    public ManThread(Exchanger exchanger) {
        this.exchanger = exchanger;
    }

    @Override
    public void run() {
        System.out.println("男人掏出了他的戒指");

        try {
            Object exchange = exchanger.exchange("女生戒指");
            System.out.println("男生拿到了"+exchange);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }
}
---------------------------------------------------------------
package com.vayne.exchange;

import java.util.concurrent.Exchanger;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class WomanThread extends Thread{
    private Exchanger exchanger;

    public WomanThread(Exchanger exchanger) {
        this.exchanger = exchanger;
    }
    @Override
    public void run() {
        System.out.println("女人掏出了她的戒指");

        try {
            Object result = exchanger.exchange("男生戒指");
            System.out.println("女生拿到了"  + result);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
    }
}
----------------------------------------------------------
package com.vayne.exchange;

import java.util.concurrent.Exchanger;

/**
 * @author vayne
 * @date 2023-11-05
 */
public class Test{
    public static void main(String[] args) {
        Exchanger<String> exchanger = new Exchanger<>();
        ManThread manThread = new ManThread(exchanger);
        WomanThread womanThread = new WomanThread(exchanger);
        manThread.start();
        womanThread.start();
    }
}

ThreadLocal

线程安全(是指广义上的共享资源访问安全性,因为线程隔离是通过副本保证本线程访问资源安全性,它不保证线程之间还存在共享关系的狭义上的安全性)的解决思路:

  1. 互斥同步: synchronized 和 ReentrantLock
  2. 非阻塞同步: CAS, AtomicXXXX
  3. 无同步方案: 栈封闭,本地存储(Thread Local),可重入代码

这个章节将详细的讲讲 本地存储(Thread Local)。官网的解释是这样的:

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its {@code get} or {@code set} method) has its own, independently initialized copy of the variable. {@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID) 该类提供了线程局部 (thread-local) 变量。这些变量不同于它们的普通对应物,因为访问某个变量(通过其 get 或 set 方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。

总结而言:ThreadLocal是一个将在多线程中为每一个线程创建单独的变量副本的类; 当使用ThreadLocal来维护变量时, ThreadLocal会为每个线程创建单独的变量副本, 避免因多线程操作共享变量而导致的数据不一致的情况。

提到ThreadLocal被提到应用最多的是session管理和数据库链接管理,这里以数据访问为例帮助你理解ThreadLocal:

点击查看代码
class ConnectionManager {
    private static Connection connect = null;

    public static Connection openConnection() {
        if (connect == null) {
            connect = DriverManager.getConnection();
        }
        return connect;
    }

    public static void closeConnection() {
        if (connect != null)
            connect.close();
    }
}

很显然,在多线程中使用会存在线程安全问题:

第一,这里面的2个方法都没有进行同步,很可能在openConnection方法中会多次创建connect;

第二,由于connect是共享变量,那么必然在调用connect的地方需要使用到同步来保障线程安全,因为很可能一个线程在使用connect进行数据库操作,而另外一个线程调用closeConnection关闭链接。

为了解决上述线程安全的问题,第一考虑:互斥同步你可能会说,将这段代码的两个方法进行同步处理,并且在调用connect的地方需要进行同步处理,比如用Synchronized或者ReentrantLock互斥锁。

这里再抛出一个问题:这地方到底需不需要将connect变量进行共享?事实上,是不需要的。假如每个线程中都有一个connect变量,各个线程之间对connect变量的访问实际上是没有依赖关系的,即一个线程不需要关心其他线程是否对这个connect进行了修改的。即改后的代码可以这样:

点击查看代码
class ConnectionManager {
    private Connection connect = null;

    public Connection openConnection() {
        if (connect == null) {
            connect = DriverManager.getConnection();
        }
        return connect;
    }

    public void closeConnection() {
        if (connect != null)
            connect.close();
    }
}

class Dao {
    public void insert() {
        ConnectionManager connectionManager = new ConnectionManager();
        Connection connection = connectionManager.openConnection();

        // 使用connection进行操作

        connectionManager.closeConnection();
    }
}

这样处理确实也没有任何问题,由于每次都是在方法内部创建的连接,那么线程之间自然不存在线程安全问题。但是这样会有一个致命的影响:导致服务器压力非常大,并且严重影响程序执行性能。

由于在方法中需要频繁地开启和关闭数据库连接,这样不仅严重影响程序执行效率,还可能导致服务器压力巨大。

那么这种情况下使用ThreadLocal是再适合不过的了,因为ThreadLocal在每个线程中对该变量会创建一个副本,即每个线程内部都会有一个该变量,且在线程内部任何地方都可以使用,线程之间互不影响,这样一来就不存在线程安全问题,也不会严重影响程序执行性能。下面就是网上出现最多的例子:

点击查看代码
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

public class ConnectionManager {

    private static final ThreadLocal<Connection> dbConnectionLocal = new ThreadLocal<Connection>() {
        @Override
        protected Connection initialValue() {
            try {
                return DriverManager.getConnection("", "", "");
            } catch (SQLException e) {
                e.printStackTrace();
            }
            return null;
        }
    };

    public Connection getConnection() {
        return dbConnectionLocal.get();
    }
}

注意下ThreadLocal的修饰符

ThreaLocal的JDK文档中说明:ThreadLocal instances are typically private static fields in classes that wish to associate state with a thread。如果我们希望通过某个类将状态(例如用户ID、事务ID)与线程关联起来,那么通常在这个类中定义private static类型的ThreadLocal 实例。

Threadlocal原理

ThreadLocal如何实现的线程隔离呢?

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap threadLocals = getMap(t);
    if (threadLocals != null) {
        ThreadLocalMap.Entry e = threadLocals.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

我们观察ThreadLocal的核心方法set发现:

  1. 首先获取当前线程对象t, 然后从线程t中获取到ThreadLocalMap的成员属性threadLocals

  2. 如果当前线程的threadLocals已经初始化(即不为null) 并且存在以当前ThreadLocal对象为Key的值, 则直接返回当前线程要获取的对象(本例中为Connection);

  3. 如果当前线程的threadLocals已经初始化(即不为null)但是不存在以当前ThreadLocal对象为Key的的对象, 那么重新创建一个Connection对象, 并且添加到当前线程的threadLocals Map中,并返回

  4. 如果当前线程的threadLocals属性还没有被初始化, 则重新创建一个ThreadLocalMap对象, 并且创建一个Connection对象并添加到ThreadLocalMap对象中并返回。

如果存在则直接返回很好理解, 那么对于如何初始化的代码又是怎样的呢?

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

  1. 首先调用我们上面写的重载过后的initialValue方法, 产生一个Connection对象

  2. 继续查看当前线程的threadLocals是不是空的, 如果ThreadLocalMap已被初始化, 那么直接将产生的对象添加到ThreadLocalMap中, 如果没有初始化, 则创建并添加对象到其中;

同时, ThreadLocal还提供了直接操作Thread对象中的threadLocals的方法

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

这样我们也可以不实现initialValue, 将初始化工作放到DBConnectionFactory的getConnection方法中:

public Connection getConnection() {
    Connection connection = dbConnectionLocal.get();
    if (connection == null) {
        try {
            connection = DriverManager.getConnection("", "", "");
            dbConnectionLocal.set(connection);
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
    return connection;
}

那么我们看过代码之后就很清晰的知道了为什么ThreadLocal能够实现变量的多线程隔离了; 其实就是用了Map的数据结构给当前线程缓存了, 要使用的时候就从本线程的threadLocals对象中获取就可以了, key就是当前线程;

当然了在当前线程下获取当前线程里面的Map里面的对象并操作肯定没有线程并发问题了, 当然能做到变量的线程间隔离了;
现在我们知道了ThreadLocal到底是什么了, 又知道了如何使用ThreadLocal以及其基本实现原理了是不是就可以结束了呢? 其实还有一个问题就是ThreadLocalMap是个什么对象, 为什么要用这个对象呢?

ThreadLocalMap

本质上来讲, 它就是一个Map, 但是这个ThreadLocalMap与我们平时见到的Map有点不一样

  1. 它没有实现Map接口;
  2. 它没有public的方法, 最多有一个default的构造方法, 因为这个ThreadLocalMap的方法仅仅3. 在ThreadLocal类中调用, 属于静态内部类
  3. ThreadLocalMap的Entry实现继承了WeakReference<ThreadLocal<?>>
  4. 该方法仅仅用了一个Entry数组来存储Key, Value; Entry并不是链表形式, 而是每个bucket里面仅仅放一个Entry;
private void set(ThreadLocal<?> key, Object value) {
    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.
   

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    // 获取线程本地存储表的引用,并根据key的哈希码和表长度计算槽位位置i

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        // 遍历槽位位置i及其后续位置的条目,直到到达表末尾,然后从表头再开始
        ThreadLocal<?> k = e.get();
        // 获取当前条目的key

        if (k == key) {
            // 如果当前条目的key与指定的key相等
            e.value = value;
            // 更新条目的value
            return;
        }

        if (k == null) {
            // 如果当前条目的key为空
            replaceStaleEntry(key, value, i);
            // 替换为新的条目
            return;
        }
    }

    tab[i] = new Entry(key, value);
    // 在槽位位置i创建新的条目
    int sz = ++size;
    // 更新条目数量
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        // 检查是否需要清理槽位并重新哈希
        rehash();
}

先进行简单的分析, 对该代码表层意思进行解读:

  1. 看下当前threadLocal的在数组中的索引位置 比如: i = 2, 看 i = 2 位置上面的元素(Entry)的Key是否等于threadLocal 这个 Key, 如果等于就很好说了, 直接将该位置上面的Entry的Value替换成最新的就可以了;

  2. 如果当前位置上面的 Entry 的 Key为空, 说明ThreadLocal对象已经被回收了, 那么就调用replaceStaleEntry

  3. 如果清理完无用条目(ThreadLocal被回收的条目)、并且数组中的数据大小 > 阈值的时候对当前的Table进行重新哈希 所以, 该HashMap是处理冲突检测的机制是向后移位, 清除过期条目 最终找到合适的位置;

ThreadLocal造成内存泄露的问题

网上有一个栗子:

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ThreadLocalDemo {
    static class LocalVariable {
        private Long[] a = new Long[1024 * 1024];
    }

    // (1)
    final static ThreadPoolExecutor poolExecutor = new ThreadPoolExecutor(5, 5, 1, TimeUnit.MINUTES,
            new LinkedBlockingQueue<>());
    // (2)
    final static ThreadLocal<LocalVariable> localVariable = new ThreadLocal<LocalVariable>();

    public static void main(String[] args) throws InterruptedException {
        // (3)
        Thread.sleep(5000 * 4);
        for (int i = 0; i < 50; ++i) {
            poolExecutor.execute(new Runnable() {
                public void run() {
                    // (4)
                    localVariable.set(new LocalVariable());
                    // (5)
                    System.out.println("use local varaible" + localVariable.get());
                    localVariable.remove();
                }
            });
        }
        // (6)
        System.out.println("pool execute over");
    }
}

如果用线程池来操作ThreadLocal 对象确实会造成内存泄露, 因为对于线程池里面不会销毁的线程, 里面总会存在着<ThreadLocal, LocalVariable>的强引用, 因为final static 修饰的 ThreadLocal 并不会释放, 而ThreadLocalMap 对于 Key 虽然是弱引用, 但是强引用不会释放, 弱引用当然也会一直有值, 同时创建的LocalVariable对象也不会释放, 就造成了内存泄露; 如果LocalVariable对象不是一个大对象的话, 其实泄露的并不严重, 泄露的内存 = 核心线程数 * LocalVariable对象的大小;所以, 为了避免出现内存泄露的情况, ThreadLocal提供了一个清除线程中对象的方法, 即 remove, 其实内部实现就是调用 ThreadLocalMap 的remove方法

那么如何优雅的进行remove方法,从而避免由于忘记remove而导致的OOM内存泄漏,其实我们可以在企业开发中借助对应的AOP切面编程或者Filter或Interceptor,完成优雅的remove。

ThreadLocal应用场景

除了上述的数据库管理类的例子,我们再看看其它一些应用:

每个线程维护了一个“序列号”:

public class SerialNum {
    // The next serial number to be assigned
    private static int nextSerialNum = 0;

    private static ThreadLocal serialNum = new ThreadLocal() {
        protected synchronized Object initialValue() {
            return new Integer(nextSerialNum++);
        }
    };

    public static int get() {
        return ((Integer) (serialNum.get())).intValue();
    }
}

Session的管理:

private static final ThreadLocal threadSession = new ThreadLocal();  
  
public static Session getSession() throws InfrastructureException {  
    Session s = (Session) threadSession.get();  
    try {  
        if (s == null) {  
            s = getSessionFactory().openSession();  
            threadSession.set(s);  
        }  
    } catch (HibernateException ex) {  
        throw new InfrastructureException(ex);  
    }  
    return s;  
}  

使用场景:在线程内部创建ThreadLocal

还有一种用法是在线程类内部创建ThreadLocal,基本步骤如下:
在多线程的类(如ThreadDemo类)中,创建一个ThreadLocal对象threadXxx,用来保存线程间需要隔离处理的对象xxx。

在ThreadDemo类中,创建一个获取要隔离访问的数据的方法getXxx(),在方法中判断,若ThreadLocal对象为null时候,应该new()一个隔离访问类型的对象,并强制转换为要应用的类型。

在ThreadDemo类的run()方法中,通过调用getXxx()方法获取要操作的数据,这样可以保证每个线程对应一个数据对象,在任何时刻都操作的是这个对象。

public class ThreadLocalTest implements Runnable{
    
    ThreadLocal<Student> StudentThreadLocal = new ThreadLocal<Student>();

    @Override
    public void run() {
        String currentThreadName = Thread.currentThread().getName();
        System.out.println(currentThreadName + " is running...");
        Random random = new Random();
        int age = random.nextInt(100);
        System.out.println(currentThreadName + " is set age: "  + age);
        Student Student = getStudentt(); //通过这个方法,为每个线程都独立的new一个Studentt对象,每个线程的的Studentt对象都可以设置不同的值
        Student.setAge(age);
        System.out.println(currentThreadName + " is first get age: " + Student.getAge());
        try {
            Thread.sleep(500);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println( currentThreadName + " is second get age: " + Student.getAge());
        
    }
    
    private Student getStudentt() {
        Student Student = StudentThreadLocal.get();
        if (null == Student) {
            Student = new Student();
            StudentThreadLocal.set(Student);
        }
        return Student;
    }

    public static void main(String[] args) {
        ThreadLocalTest t = new ThreadLocalTest();
        Thread t1 = new Thread(t,"Thread A");
        Thread t2 = new Thread(t,"Thread B");
        t1.start();
        t2.start();
    }
    
}

class Student{
    int age;
    public int getAge() {
        return age;
    }
    public void setAge(int age) {
        this.age = age;
    }
    
}