【工具类】可重用的CountDownLatch

发布时间 2023-05-09 16:04:01 作者: 无所事事O_o
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

/**
 * A reusable variant of {@code CountDownLatch}.
 *
 * <ul>
 *   <li>{@link #reset()} restores the initial count once it has reached zero, so the
 *       same instance can be used for several rounds.</li>
 *   <li>A version number is kept per latch; callers may drive it explicitly (see
 *       {@link #reset(long)} and {@link #await(long)}) to wait for a specific round.</li>
 * </ul>
 */
public class ReusableCountDownLatch {
    private final Sync sync;
    /**
     * Version observed by the thread that is currently inside one of the {@code await}
     * methods. It is only meaningful for the duration of a single {@code await} call and
     * is cleared when that call returns.
     */
    private final ThreadLocal<Long> threadVersion = new ThreadLocal<>();
    /**
     * Latest version of this latch; advanced by {@link #reset()} and overwritten by
     * {@link #reset(long)}.
     */
    private final AtomicLong latchVersion = new AtomicLong(0);

    /**
     * Synchronizer backing the latch. The AQS state holds the remaining count.
     */
    private final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        /**
         * Initial count, remembered so that the state can be restored on reset.
         */
        private final int count;
        /**
         * Whether the count is restored automatically when it reaches zero.
         */
        private final boolean autoReset;

        Sync(int count) {
            this(count, false);
        }

        Sync(int count, boolean autoReset) {
            this.count = count;
            this.autoReset = autoReset;
            setState(count);
        }

        /** Advances the latch version by one and restores the initial count. */
        protected void reset() {
            latchVersion.getAndIncrement();
            setState(count);
        }

        /** Sets the latch version to {@code version} and restores the initial count. */
        protected void reset(long version) {
            latchVersion.set(version);
            setState(count);
        }

        int getCount() {
            return getState();
        }

        /**
         * Decides whether the calling thread may pass the latch.
         *
         * <p>If the thread carries a version: a newer latch version lets it pass
         * (the round it waited for is over), an older latch version keeps it waiting.
         * Otherwise the thread passes only when the remaining count is zero; if it must
         * wait, the current latch version is recorded for the thread so that a later
         * reset can be detected.
         *
         * @param acquires unused
         * @return {@code 1} if the thread may proceed, {@code -1} if it has to wait
         */
        @Override
        protected int tryAcquireShared(int acquires) {
            Long tVersion = threadVersion.get();
            long lVersion = latchVersion.get();
            if (tVersion != null) {
                if (lVersion > tVersion) {
                    threadVersion.remove();
                    return 1;
                }
                if (lVersion < tVersion) {
                    return -1;
                }
            }
            if (getState() == 0) {
                return 1;
            }
            // Remember in which round this thread started waiting.
            threadVersion.set(lVersion);
            return -1;
        }

        /**
         * Decrements the count and reports whether it reached zero.
         *
         * @param releases unused
         * @return {@code true} if this call moved the count to zero (waiters are then
         *         signalled by AQS), {@code false} otherwise
         */
        @Override
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (; ; ) {
                int c = getState();
                if (c == 0) {
                    return false;
                }
                int nextc = c - 1;
                if (compareAndSetState(c, nextc)) {
                    boolean reachedZero = nextc == 0;
                    if (reachedZero && autoReset) {
                        // Reset first; waiters are woken only after this method returns.
                        // NOTE(review): between the CAS above and this reset the state is
                        // briefly 0, so a concurrent countDown() in that window returns
                        // without effect.
                        reset();
                    }
                    return reachedZero;
                }
            }
        }
    }

    /**
     * Creates a latch that has to be reset manually.
     *
     * @param count number of {@link #countDown()} calls required per round
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public ReusableCountDownLatch(int count) {
        this(count, false);
    }

    /**
     * Creates a latch.
     *
     * @param count     number of {@link #countDown()} calls required per round
     * @param autoReset if {@code true}, the count is restored automatically as soon as
     *                  it reaches zero
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public ReusableCountDownLatch(int count, boolean autoReset) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.sync = new Sync(count, autoReset);
    }

    /**
     * Waits until the count reaches zero or the latch is reset while waiting.
     *
     * @throws InterruptedException if the calling thread is interrupted
     */
    public void await() throws InterruptedException {
        try {
            sync.acquireSharedInterruptibly(1);
        } finally {
            // Never leak the per-call version into the next await on this thread.
            threadVersion.remove();
        }
    }

    /**
     * Waits relative to an explicit version: returns once the latch version exceeds
     * {@code version}, or once the latch version equals {@code version} and the count
     * is zero.
     *
     * @param version the version the caller is waiting on
     * @throws InterruptedException if the calling thread is interrupted
     */
    public void await(long version) throws InterruptedException {
        threadVersion.set(version);
        try {
            sync.acquireSharedInterruptibly(1);
        } finally {
            threadVersion.remove();
        }
    }

    /**
     * Timed variant of {@link #await()}.
     *
     * @param timeout maximum time to wait
     * @param unit    unit of {@code timeout}
     * @return {@code true} if the latch was passed, {@code false} if the timeout elapsed
     * @throws InterruptedException if the calling thread is interrupted
     */
    public boolean await(long timeout, TimeUnit unit)
            throws InterruptedException {
        try {
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        } finally {
            // A timed-out wait must not leave a stale version behind; otherwise the
            // next await after a reset would return immediately.
            threadVersion.remove();
        }
    }

    /**
     * Timed variant of {@link #await(long)}.
     *
     * @param timeout maximum time to wait
     * @param unit    unit of {@code timeout}
     * @param version the version the caller is waiting on
     * @return {@code true} if the latch was passed, {@code false} if the timeout elapsed
     * @throws InterruptedException if the calling thread is interrupted
     */
    public boolean await(long timeout, TimeUnit unit, long version)
            throws InterruptedException {
        threadVersion.set(version);
        try {
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        } finally {
            threadVersion.remove();
        }
    }

    /** Decrements the count; when it reaches zero, waiting threads are released. */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * @return the remaining count of the current round
     */
    public long getCount() {
        return sync.getCount();
    }

    /** Restores the initial count and advances the latch version by one. */
    public void reset() {
        sync.reset();
    }

    /**
     * Restores the initial count and sets the latch version explicitly.
     *
     * @param version the new latch version
     */
    public void reset(long version) {
        sync.reset(version);
    }

    @Override
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }

    /**
     * Small demo (originally written with the help of ChatGPT): runs two rounds of
     * three worker threads against an auto-resetting latch.
     */
    public static void main(String[] args) throws InterruptedException {
        System.out.println("start");
        // With automatic reset
        ReusableCountDownLatch latch = new ReusableCountDownLatch(3, true);

        for (int i = 0; i < 3; i++) {
            new Thread(() -> {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                latch.countDown();
                System.out.println("Thread finished");
            }).start();
        }
        System.out.println("All threads await");
        latch.await();
        System.out.println("All threads finished");

        // Without automatic reset the latch would have to be reset manually here:
        //latch.reset();

        for (int i = 0; i < 3; i++) {
            new Thread(() -> {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                latch.countDown();
                System.out.println("Thread finished");
            }).start();
        }

        latch.await();
        System.out.println("All threads finished again");
    }
}