手写简化版阿里TTL:解决线程池父子线程上下文传递难题

大家好!今天咱们要动手实现一个简化版的阿里TTL(Transmittable Thread Local),核心目标是解决线程池中父子线程的上下文传递问题。在开始编码前,咱们先梳理下Java原生的线程上下文工具,搞清楚TTL要解决的核心痛点。

一、先复习:Java原生的ThreadLocal与InheritableThreadLocal

在讲TTL之前,必须先搞懂两个基础工具——ThreadLocalInheritableThreadLocal,这是TTL的技术基石。

1. ThreadLocal:线程内的“全局变量”

ThreadLocal是Java原生的线程上下文存储工具,作用是让变量在单个线程内全局可见,不同线程间互不干扰。比如在主线程中set一个值,在主线程的其他方法里能直接get到,代码示例很直观:

// 初始化ThreadLocal
private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

public static void main(String[] args) {
    // 主线程set值
    threadLocal.set("主线程上下文");
    // 主线程内其他方法get值
    printContext(); // 输出:主线程上下文
}

private static void printContext() {
    System.out.println(threadLocal.get());
}

但它有个致命问题:父子线程间无法传递。如果新建一个子线程,子线程里根本拿不到主线程ThreadLocal中的值。

2. InheritableThreadLocal:父子线程的“传递桥梁”

为了解决ThreadLocal的父子传递问题,Java提供了InheritableThreadLocal。它的核心逻辑是:在新建子线程的初始化阶段,会自动把父线程InheritableThreadLocal中的所有上下文,拷贝到子线程中。同样用代码示例看效果:

// 初始化InheritableThreadLocal
private static final InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

public static void main(String[] args) {
    // 主线程set值
    inheritableThreadLocal.set("主线程上下文");
    
    // 新建子线程
    new Thread(() -> {
        // 子线程get值
        System.out.println(inheritableThreadLocal.get()); // 输出:主线程上下文
    }).start();
}

看起来很完美?但它有个更致命的局限——线程池场景下完全失效

3. 核心痛点:InheritableThreadLocal在线程池中失效

为什么线程池里会失效?关键在于InheritableThreadLocal的传递时机:只在Thread对象初始化时传递一次

而线程池的核心是“线程复用”,线程池中的线程只在初始化时会拷贝一次父线程上下文,但后续复用这个线程执行新任务时,不会再重新拷贝上下文。比如:

// 线程池(核心线程数1,复用线程)
private static final ExecutorService executor = Executors.newFixedThreadPool(1);
private static final InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

public static void main(String[] args) {
    // 第一次提交任务:主线程set值
    inheritableThreadLocal.set("第一次任务上下文");
    executor.submit(() -> {
        System.out.println(inheritableThreadLocal.get()); // 输出:第一次任务上下文(线程初始化时拷贝)
    }).get();
    
    // 第二次提交任务:主线程更新值
    inheritableThreadLocal.set("第二次任务上下文");
    executor.submit(() -> {
        System.out.println(inheritableThreadLocal.get()); // 输出:第一次任务上下文(线程复用,未重新拷贝)
    }).get();
}

第二次任务明明主线程更新了上下文,但子线程拿到的还是第一次的旧值——这就是线程池场景下 InheritableThreadLocal 的核心问题,也是咱们要实现TTL的原因。

4. 补充:ThreadLocal的底层原理

在动手前,先搞懂 TL 家族的底层逻辑:每个Thread对象内部维护了两个Map:

  • threadLocals:存储ThreadLocal的键值对;
  • inheritableThreadLocals:存储InheritableThreadLocal的键值对。

这两个 Map 的 Key 是 ThreadLocal(或子类)对象,Value是我们存储的上下文,且Key采用弱引用(WeakReference),避免内存泄漏——当外部不再引用ThreadLocal时,GC会自动回收对应的Value。

这两个 Map 在 Thread 对象的内部,所以能在线程对象的生命周期内生效。

二、TTL的核心设计思路:从“线程生命周期”转向“任务生命周期”

TTL的本质是增强InheritableThreadLocal,核心改动是把上下文传递的时机,从“线程初始化”改成“任务初始化”。

具体怎么理解?

  • 原InheritableThreadLocal:传递时机是new Thread()时(线程生命周期);
  • 我们要做的TTL:传递时机是 new Runnable() / new Callable()时(任务生命周期)。

也就是说,每次向线程池提交任务时,都主动把当前父线程的上下文“快照”下来,塞到任务里;等线程执行这个任务时,再把“快照”恢复到当前线程中;任务执行完后,再把线程原来的上下文还原——这样既不影响线程复用,又能保证每个任务拿到最新的父线程上下文。

三、动手实现简化版TTL:核心步骤拆解

接下来咱们一步步写代码,实现一个能支持线程池的TTL。整体分为3个核心部分:

1. 第一步:定义TTL的核心存储容器——TransmittableThreadLocal

我们需要一个自定义的ThreadLocal子类,命名为TransmittableThreadLocal(简称TTL)。它的核心作用是:

  • 继承InheritableThreadLocal,保留“新建线程时传递”的能力;
  • 内部维护一个“TTL专属Map”,存储所有TTL的键值对(避免和原生InheritableThreadLocal冲突)。

首先,先定义这个“TTL专属Map”的容器——用一个静态的 InheritableThreadLocal,值是WeakHashMap<TransmittableThreadLocal<?>, Object>(弱引用Map,仿照原生设计):

public class TTL<T> extends InheritableThreadLocal<T>{

    private static final InheritableThreadLocal<Map<TTL<Object>,Boolean>> ttlMap = new InheritableThreadLocal<Map<TTL<Object>,Boolean>>(){
        @Override
        protected Map<TTL<Object>, Boolean> initialValue() {
            return new WeakHashMap<>();
        }
    };
    
    @SuppressWarnings("unchecked")
    private void addToMap(){
        ttlMap.get().put((TTL<Object>) this,Boolean.TRUE);
    }
    
    private void removeFromMap(){
        ttlMap.get().remove(this);
    }

    // 重写 ThreadLocal 的 get,set,remove 操作,因为本质也是 ThreadLocal,所以也需要在这些方法被调用时进行父类方法的调用。
    
    @Override
    public T get() {
        addToMap(); // 根据 ThreadLocal 的逻辑,只是 get 也需要添加进 ThreadLocalMap 里。
        return super.get();
    }

    @Override
    public void set(T value) {
        addToMap();
        super.set(value);
    }

    @Override
    public void remove() {
        removeFromMap();
        super.remove();
    }
}

用 InheritableThreadLocal 实现 TTL 的 ThreadLocalMap 有两个好处:

  1. InheritableThreadLocal 本身有原生实现的 ThreadLocalMap,我们相当于在 InheritableThreadLocal 的 ThreadLocalMap 中嵌套了一个我们自己 Map,避免修改 JDK 源码。
  2. 在 JDK 源码中,InheritableThreadLocal 的生命周期是耦合进 Thread 的生命周期的(在 Thread.java)中,且由于 ttlMap 是被 static final 修饰的,首次初始化在 JVM 主线程中,保证了我们的 ThreadLocalMap(ttlMap) 一定会在所有线程中初始化。

2. 第二步:实现上下文的“捕获-回放-恢复”工具类

接下来需要一个工具类,负责3个核心操作:

  • 捕获(capture):把当前线程的TTL上下文“快照”下来;
  • 回放+备份(replayAndBackup):把“快照”的上下文恢复到当前线程,并备份线程原来的上下文;
  • 恢复(restore):任务执行完后,把线程原来的上下文还原回去。

代码实现如下:

package priv.dawn.lab.ttl;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.WeakHashMap;

/**
 * Created with IntelliJ IDEA.
 * Description:
 *
 * @author Dawn Yang
 * @since 2025/08/02/17:33
 */
public class TTL<T> extends InheritableThreadLocal<T>{

    private static final InheritableThreadLocal<Map<TTL<Object>,Boolean>> ttlMap = new InheritableThreadLocal<Map<TTL<Object>,Boolean>>(){
        @Override
        protected Map<TTL<Object>, Boolean> initialValue() {
            return new WeakHashMap<>();
        }
    };

    @SuppressWarnings("unchecked")
    private void addToMap(){
        ttlMap.get().put((TTL<Object>) this,Boolean.TRUE);
    }

    private void removeFromMap(){
        ttlMap.get().remove(this);
    }

    @Override
    public T get() {
        addToMap();
        return super.get();
    }

    @Override
    public void set(T value) {
        addToMap();
        super.set(value);
    }

    @Override
    public void remove() {
        removeFromMap();
        super.remove();
    }

    // 新增代码在这:
    public static class Transmitter {

        // capture 捕获上下文
        public static Map<TTL<Object>,Object> capture(){
            // 获取当前线程的的所有 TTL
            Set<TTL<Object>> ttlSet = ttlMap.get().keySet();
            Map<TTL<Object>,Object> capture = new HashMap<>();
            for (TTL<Object> ttl : ttlSet) {
                // 将捕获的上下文写入到快照中
                capture.put(ttl,ttl.get());
            }
            // 返回快照
            return capture;
        }

        // replay 重放上下文,将捕获出的 ttlMap 快照写入到当前线程的 ttlMap 中
        public static Map<TTL<Object>,Object> replay(Map<TTL<Object>,Object> capture){
            // 备份当前线程的 ttlMap 以方便恢复,因为会出现线程任务套线程任务的情况。
            Map<TTL<Object>, Object> backup = capture();
            // 清空当前线程的 ttlMap
            ttlMap.get().keySet().forEach(TTL::remove); // 避免内存泄漏
            ttlMap.get();
            for (Map.Entry<TTL<Object>, Object> entry : capture.entrySet()) {
                // 将每个 TTL 都拿出来调用一次 set 方法,由于 TTL 本质是 ThreadLocal,也就写入到了 Thread 的 ThreadLocalMap 中,后面再用 get 方法就能获取到了。
                entry.getKey().set(entry.getValue());
            }
            return backup;
        }

        // restore 恢复当前线程的 ttlMap
        public static void restore(Map<TTL<Object>,Object> backup){
            // 清空当前线程的 ttlMap
            ttlMap.get().keySet().forEach(TTL::remove); // 避免内存泄漏
            ttlMap.get().clear();
            for (Map.Entry<TTL<Object>, Object> entry : backup.entrySet()) {
                // 将备份的值重新 set 回去
                entry.getKey().set(entry.getValue());
            }
        }
    }

}

3. 第三步:包装任务(Runnable/Callable)——核心!

线程池提交的任务是Runnable或Callable,我们需要对这两个接口进行包装,让任务在执行前后自动完成“捕获-回放-恢复”的流程。

以Runnable为例,实现RunnableWrapper

public class RunnableWrapper implements Runnable{
    // 在对象初始化时捕获当前线程的TTL上下文,最可靠的方案
    private final Map<TTL<Object>,Object> capture = TTL.Transmitter.capture();
    private final Runnable runnable;

    public RunnableWrapper(Runnable runnable) {
        this.runnable = runnable;
    }


    @Override
    public void run() {
        // 保存当前线程的原有TTL上下文作为备份
        Map<TTL<Object>, Object> backup = TTL.Transmitter.replay(capture);
         try {
            // 执行被包装的任务
            runnable.run();
        } finally {
            // 确保能恢复线程原有TTL上下文
            TTL.Transmitter.restore(backup);
        }
    }
}

同理,Callable的包装TtlCallable逻辑完全一致,只是返回值不同:

public class CallableWrapper<T> implements Callable<T> {
    // 在对象初始化时捕获当前线程的TTL上下文
    private final Map<TTL<Object>, Object> capture = TTL.Transmitter.capture();
    // 被包装的Callable对象
    private final Callable<T> callable;

    // 构造方法,接收需要包装的Callable
    public CallableWrapper(Callable<T> callable) {
        this.callable = callable;
    }

    @Override
    public T call() throws Exception {
        // 保存当前线程的原有TTL上下文作为备份
        Map<TTL<Object>, Object> backup = TTL.Transmitter.replay(capture);
        try {
            // 执行被包装的Callable的call方法
            return callable.call();
        } finally {
            // 确保恢复线程原有TTL上下文
            TTL.Transmitter.restore(backup);
        }
    }
}

四、测试:验证TTL在_threadPool_中是否生效

代码写完了,咱们用线程池场景测试一下,看看是否解决了InheritableThreadLocal的问题。

测试代码:

public class TtlTest {
    // 1. 初始化我们自己的TTL
    private static final TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();
    // 2. 线程池(核心线程数1,复用线程)
    private static final ExecutorService executor = Executors.newFixedThreadPool(1);

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        // ------------------------------ 第一次提交任务 ------------------------------
        ttl.set("第一次任务:父线程上下文");
        // 包装Runnable任务并提交
        executor.submit(TtlRunnable.wrap(() -> {
            System.out.println("第一次任务结果:" + ttl.get()); // 预期:第一次任务:父线程上下文
        })).get();

        // ------------------------------ 第二次提交任务 ------------------------------
        ttl.set("第二次任务:父线程上下文");
        // 再次提交包装后的任务
        executor.submit(TtlRunnable.wrap(() -> {
            System.out.println("第二次任务结果:" + ttl.get()); // 预期:第二次任务:父线程上下文
        })).get();

        // 关闭线程池
        executor.shutdown();
    }
}

测试结果

第一次任务结果:第一次任务:父线程上下文
第二次任务结果:第二次任务:父线程上下文

完美!即使线程池复用线程,每次提交任务都能拿到最新的父线程上下文,解决了InheritableThreadLocal的失效问题。

五、总结:简化版TTL的核心逻辑

通过上面的实现,我们可以总结出简化版TTL解决线程池上下文传递问题的核心逻辑:

  1. 核心容器设计:自定义TransmittableThreadLocal继承InheritableThreadLocal,保留基础的父子线程传递能力。借用InheritableThreadLocal 维护一个专属的WeakHashMap记录所有的TTL实例,用于记录哪些是需要(快照-回放-恢复)的 ThreadLocal,底层的实现仍然是 ThreadLocal。

  2. 上下文生命周期管理:通过Transmitter工具类实现线程任务生命周期里的三个关键操作:

    • capture():在提交任务时,快照当前线程的所有TTL上下文
    • replay():在任务执行前,将快照的上下文恢复到执行线程,并备份线程原有上下文
    • restore():在任务执行后,将线程原有上下文还原,避免影响后续任务
  3. 任务包装机制:通过包装RunnableCallable,在任务执行的前后自动触发上下文的"回放"和"恢复"操作,实现对业务代码的无侵入性。

通过上面的操纵,我们就可以将 ThreadLocal 的上下文传递的时机从"线程创建时"改为"任务创建时",完美适配线程池的线程复用机制,确保每个任务都能获取到提交时的最新上下文。

在实际项目中,我们也可以包装一个TTL线程池,实现自动包装所有提交的任务:

public class TtlExecutorService extends ThreadPoolExecutor {
    // 构造方法省略...
    
    @Override
    public void execute(Runnable command) {
        super.execute(new RunnableWrapper(command));
    }
    
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return super.submit(new CallableWrapper<>(task));
    }
}

通过这个简化版的实现,我们理解了阿里TTL的核心原理,它本质上是通过精巧的"快照-回放-恢复"机制,解决了线程池场景下上下文传递的难题,是分布式追踪、日志链路等场景的重要基础工具。