JUC之CountDownLatch

java / 344人浏览 / 0人评论
什么是CountDownLatch?
CountDownLatch是一个在JDK版本1.5时引入的同步工具类,它允许一个或多个线程一直等待,直到其他线程执行完之后再执行,跟它一起被引入的并发工具类还有:CyclicBarrier、Semaphore、ConcurrentHashMap和BlockingQueue,都存在于java.util.concurrent包下。

CountDownLatch原理

CountDownLatch通过一个计数器来实现,计数器的初始化值为线程的数量,每当一个线程完成了自己的任务后,计数器的值就相应减1,当计数器等于0时,表示所有线程都已完成,然后在闭锁上等待的线程恢复继续执行任务。(一次性操作,因为计数无法重置,如果需要一个重置版本计数,可以考虑CyclicBarrier)

CountDownLatch内部结构

输入图片说明
CountDownLatch()
初始化计数器,初始化值如果小于0则抛出异常IllegalArgumentException

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

await()
调用await()方法的线程会被挂起,一直等待到count值为0的时候才继续执行

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
    throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

await(long,TimeUnit)
跟await()方法一样,只不过等待一定的时间后count值还没变为0的时候继续执行

public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
    throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    return tryAcquireShared(arg) >= 0 ||
        doAcquireSharedNanos(arg, nanosTimeout);
}

countDown()
递减锁存器的计数,如果计数等于零,则释放所有等待线程

public void countDown() {
    sync.releaseShared(1);
}
/**
  * 共享模式释放对象
  * 在函数中会调用CountDownLatch的tryReleaseShared函数
  * 当计数器返回0时,会调用AQS的doReleaseShared函数
  */
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

Sync
Sync是一个内部静态类,继承自AbstractQueuedSynchronizer,

/**
  * CountDownLatch的同步控制器
  * 使用AQS状态表示计数
  */
private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;
    
    // 构造方法CountDownLatch的构造方法最终调用的是Sync的构造
    Sync(int count) {
        setState(count); // 初始化count
    }
    
    // 返回当前count计数
    int getCount() {
        return getState();
    }
    
    // 试图在共享模式下获取对象状态
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }
    // 试图设置状态来反映共享模式下的一个释放
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

简单示例CountDownLatch

假设我们接到一个需求,需要从数据库获取数据,处理数据,之后再保存至数据库。

public static void main(String[] args){
        //获取数据
        int[] data = query();
        System.out.println("获取数据完毕");
        //处理数据
        ExecutorService executorService = Executors.newCachedThreadPool();
        IntStream.range(0,data.length).forEach(i->{
            executorService.execute(()->{
                System.out.println(Thread.currentThread() + "处理第" + (i + 1) + "条数据");
                int value = data[i];
                if (value % 2 == 0) {
                    data[i] = value * 2;
                } else {
                    data[i] = value * 10;
                }
            });
        });
        // 关闭线程池
        executorService.shutdown();
        //保存数据
        save(data);
    }
    private static int[] query() {
        return new int[]{1,2,3,4,5,6,7,8,9,10};
    }
    private static void save(int[] data) {
        System.out.println("保存数据 = " + Arrays.toString(data));
    }

运行结果:

获取数据完毕
Thread[pool-1-thread-1,5,main]处理第1条数据
Thread[pool-1-thread-2,5,main]处理第2条数据
Thread[pool-1-thread-4,5,main]处理第4条数据
Thread[pool-1-thread-3,5,main]处理第3条数据
Thread[pool-1-thread-5,5,main]处理第5条数据
Thread[pool-1-thread-6,5,main]处理第6条数据
保存数据 = [10, 4, 30, 8, 50, 12, 70, 8, 90, 10]
Thread[pool-1-thread-9,5,main]处理第9条数据
Thread[pool-1-thread-7,5,main]处理第7条数据
Thread[pool-1-thread-8,5,main]处理第8条数据
Thread[pool-1-thread-10,5,main]处理第10条数据

可以看出由于CPU时间片的不确定性,数据还没处理完,就已经执行保存了。
使用CountDownLatch来解决这个问题

public class CountDownLatchTest {

    //创建一个定长线程池,超出的线程会在队列中等待
    private static ExecutorService executorService = Executors.newFixedThreadPool(2);
    //创建一个CountDownLatch,计数器为10
    private static CountDownLatch latch = new CountDownLatch(10);

    public static void main(String[] args)throws InterruptedException{

        //获取数据
        int[] data = query();
        System.out.println("获取数据完毕");
        //处理数据
        IntStream.range(0,data.length).forEach(i->{
            executorService.execute(()->{
                System.out.println(Thread.currentThread() + "处理第" + (i + 1) + "条数据");
                int value = data[i];
                if (value % 2 == 0) {
                    data[i] = value * 2;
                } else {
                    data[i] = value * 10;
                }
                //每个线程执行完毕后调用countDown()方法,计数-1
                latch.countDown();
            });
        });
        //阻塞当前线程,直到计数器减至为零,这样就保证了只有等待所有的数据处理完毕后才执行保存操作
        latch.await();
        //latch.await(1, TimeUnit.MINUTES); // 最多等待1分钟,1分钟没有处理完数据将继续执行保存操作
        //关闭线程池
        executorService.shutdown();
        //保存数据
        save(data);
    }
    private static int[] query() {
        return new int[]{1,2,3,4,5,6,7,8,9,10};
    }
    private static void save(int[] data) {
        System.out.println("保存数据 = " + Arrays.toString(data));
    }
}

运行结果:

获取数据完毕
Thread[pool-1-thread-2,5,main]处理第2条数据
Thread[pool-1-thread-1,5,main]处理第1条数据
Thread[pool-1-thread-1,5,main]处理第4条数据
Thread[pool-1-thread-1,5,main]处理第5条数据
Thread[pool-1-thread-1,5,main]处理第6条数据
Thread[pool-1-thread-1,5,main]处理第7条数据
Thread[pool-1-thread-2,5,main]处理第3条数据
Thread[pool-1-thread-1,5,main]处理第8条数据
Thread[pool-1-thread-2,5,main]处理第9条数据
Thread[pool-1-thread-1,5,main]处理第10条数据
保存数据 = [10, 4, 30, 8, 50, 12, 70, 16, 90, 20]

上述代码中使用CountDownLatch解决等待所有数据处理完之后再进行保存操作。

CountDownLatch与CyclicBarrier的区别

CountDownLatch是一个计数器,线程完成一个记录一个,计数器递减直至为0,只能使用一次
CyclicBarrier的计数器相当于一个阀门,需要所有线程完成,然后继续执行,计数器递减提供reset功能
简单来说,CountDownLatch是一次性使用的,CyclicBarrier可循环使用。

0 条评论

还没有人发表评论

发表评论 取消回复

记住我的信息,方便下次评论
有人回复时邮件通知我