前段时间上班的时候,隔壁同事遇到了一个奇怪的场景,有点复杂和诡异,所以拿出来说一下。

场景

需要向前端提供一个批量处理的接口,这个接口还需要尽量做成同步的,因为这个接口的并发性特别小,QPS约等于零,但是吞吐量有点吓人,每一次会传入一个200条数据的Excel文档,对于这些数据处理有几个要求:

  • 这些数据需要 RPC 调用一个 AI 服务的计算向量的接口,这个算法端的接口每次只能支持传一条数据,每次请求到响应的耗时大概会有 200-400ms 不等。
  • 处理完之后,需要将数据持久化到数据库,数据库层面不是传统的 MySQL 类型的,而是 AI 服务特用的向量数据库,只能支持批量一次性存入大概 25 条数据,耗时 300-400ms 不等(这个也是抽象的要死,贼慢)
  • 这些数据要以事务的形式存入数据库,事务是声明式事务,支持报错时全量回滚。(不知道为什么对事务这么执着,就当是因为数据库有洁癖不允许有脏数据,不然我们直接存乐观版本号好就行了,学GFS至少一次的思想。)

如果要尽量做成同步接口,有没有一个比较好的解决方案?

分析

首先想到一个比较简单的方案:线程池+CountLantch,等 200 多个 RPC 调用并发完成,计数器结束,大概要 10 s(24核48线程恐怖如斯),事务就被唤醒然后批量插入,这个时候也就是批量插入 200 多个数据,需要大概同步插入8次,耗时 2-3s,时间还是有点久的。

那这个时候去分析并发时刻轴就会发现,计数器带来的最大问题就是长尾效应,对于 200 多个 RPC 调用来说,有那么一两个 RPC 调用在网路里面迷路了来晚了就不是什么小概率事件了,整个 RPC 调用的时间反而取决于最慢的几个,这套方案的并发性还是有点不行。

有必要叠个甲,学习交流,这种方案连基建都没几个支持的还是希望大家少遇到,最好的优化就是改技术方案,我们现在只是考虑一下如何带脚镣起舞。

顺着这个思路,我们不难发现,其实当 RPC 调用完成 25 次之后就可以开始进行一次批量插入了,这是一种每一批进行一次集中处理的线程间同步机制,很理所应当地就想到了 CyclicBarrier 这个可循环,带技术,有边界处理手段的同步器。

其实想到这里,基本方案就已经完成了,甚至可以说这已经满足线程同步的要求了,本来可以开开心心回家的时候,突然发现,有第三个要求,声明式事务。

这个事情就显得很抽象了,直接用 CyclicBarrier 可是行不通的哦,因为 CyclicBarrier 的边界执行并不是指定线程的,而是看哪一个调用 dowait() 方法的线程刚好满足了 Parties 的条件,就直接在这个线程上执行的。

显然让这么多线程属于同一个声明式事务是不合理的,因为Spring的声明式事务是绝对线程隔离的。

那么很显然,我们需要对 Barrier 做一些魔改,让边界处理从一个处理者,变成一个通知者,去通知特定的事务线程可以在事务内进行一次批量处理了。这种通知方式,很容易就想到了令牌桶策略,每到达一次 Barrier 就发放一次令牌,事务线程在没有令牌的时候挂起等待,有令牌的时候处理。按照这个思路,很容易就可以设计出来一个合适的批量任务同步器了。

虽然现在说的比较简单,但是写明白这玩意花了笔者一整天,真的是坑又深又多......

看一个已经写完可以实现的代码

public class BatchTaskSynchronizer<E> {
    // 基本信息
    private final int batchSize;
    private final int totalTaskNum;
    private volatile boolean finished = false;
    private final AtomicInteger remainTaskNum;
    private final ArrayBlockingQueue<E> finishedQueue;
    // 令牌桶
    private final AtomicInteger finishedBatchNum = new AtomicInteger(0);
    // 阻塞等待锁,避免自旋
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition waitForBatch = lock.newCondition();
    // 循环倒数同步器
    private final CyclicBarrier barrier;

    public SyncTaskNew(int taskNum, int batchSize) {
        this.totalTaskNum = taskNum;
        this.batchSize = batchSize;
        this.remainTaskNum = new AtomicInteger(taskNum);
        this.finishedQueue = new ArrayBlockingQueue<E>(taskNum);
        this.barrier = new CyclicBarrier(batchSize, this::syncBarrier);
    }

    @SneakyThrows
    private void syncBarrier() {
        int nowRemains = finishedBatchNum.getAndIncrement(); // 发放令牌
        if (nowRemains == 0) {
            try{
                lock.lock();
                waitForBatch.signal(); // 唤醒,假如没有线程等待也无所谓
            }finally {
                lock.unlock();
            }
        }
    }

    @SneakyThrows
    public void finishOne(E task) {
        finishedQueue.add(task);
        try {
            barrier.await();
        } catch (BrokenBarrierException e) {
            // do nothing
        }
    }

    @SneakyThrows
    public List<E> getFinishedBatch() {

        if (isAllFinished())
            return new ArrayList<E>();

        if (finishedBatchNum.get() == 0) {
            // 判断是不是最后一批且数量不够barrier的情况
            if(remainTaskNum.get() < batchSize){ 
                int remain = remainTaskNum.get();
                // 轮询等待剩余任务全部完成
                while(finishedQueue.size()!=remain); 
                List<E> oneBatches = getOneBatch();
                // 轮询等待剩余所有任务线程到达 barrier
                while(barrier.getNumberWaiting()!=remain);
                barrier.reset(); // 释放
                return oneBatches;
            }
            // 等令牌
            try {
                lock.lock();
                waitForBatch.await();
            } finally {
                lock.unlock();
            }
        }
        
        finishedBatchNum.decrementAndGet();
        return getOneBatch();
    }

    private List<E> getOneBatch() {
        List<E> batches = new ArrayList<E>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            E poll = finishedQueue.poll();
            if (Objects.isNull(poll)) break;
            batches.add(poll);
            remainTaskNum.decrementAndGet();
        }
        if (remainTaskNum.get() == 0 && finishedQueue.isEmpty()) {
            this.finished = true;
        }
        return batches;
    }

    public int getTotalTaskNum() {
        return this.totalTaskNum;
    }

    public boolean isAllFinished() {
        return finished;
    }

}

一个典型的使用案例:

@Slf4j
public class AsyncService {

    private final int CORE_NUM = 48;

    private final ThreadPoolExecutor executor = new ThreadPoolExecutor(
        CORE_NUM, 
        CORE_NUM, 
        5, TimeUnit.SECONDS,
        new ArrayBlockingQueue<>(500)
    );

    public void batch(List<Task> taskList, int batchSize) {
        // 实例化一个任务同步器
        BatchTaskSynchronizer<TaskResult> taskSync = new BatchTaskSynchronizer<>(taskNum, batch);
        // 注册 batch 个任务完成时处理线程
        executor.submit(() -> doWhenEnoughBatch(syncTask));
        // 注册任务线程
        for (Task task : taskList)
            executor.submit(() -> doInLoop(taskSync, task));
        // 实际上也可以在这里做同步完成操作
        // doWhenEnoughBatch(syncTask);
    }

    private void doInLoop(BatchTaskSynchronizer<TaskResult> taskSync, Task task) {
        // 执行任务
        TaskResult result = task.work();
        try{
            taskSync.finishOne(result);
        } catch(interruptedexception e) {
            log.warn("任务被中断, task:{}",task);
        }
    }
    
    @Transactional 
    private void doWhenBatch(BatchTaskSynchronizer<TaskResult> taskSync) {
        // ... 做一些初始化的工作 ...
        // log.info("事务开启!");
        while (!taskSync.isAllFinished()) {
            List<TaskResult> resultBatch = taskSync.getFinishedBatch();
            // ... 处理一下 batch ...
            // mapper.insertBatch(resultBatch);
        }
        // ... 做一些收尾的工作 ...
        // log.info("事务结束!");
    }

}