ThreadLocal源码分析以及线程池下的使用问题

前言

如果我想要在多线程下, 同一个变量在不同线程下使用不同的值, 如何去做?
我可以声明一个Map, 线程作为Key(实际上并不是这样设计的), Map作为value, 存储当前线程下的键值对.
下面是一个简单的例子(生产不要这样用)

1
2
3
4
5
6
7
8
9
public class ThreadMap {
public static Map<Thread, Map<String, Object>> map = new HashMap<>();
public static void set(String key, Object value) {
map.get(Thread.currentThread()).put(key, value);
}
public static Object get(String key) {
return map.get(Thread.currentThread()).get(key);
}
}

是不是很简单? 但是我们不需要去重复造轮子.
JDK早就给我们实现了一套轮子, 本质和我写的这个Demo是一样的.

ThreadLocal 实现线程隔离

测试用例

我们来看个简单的单元测试例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public class ThreadLocalTest {
@Test
public void threadLocal1() throws Exception {
ThreadLocal<String> context = new ThreadLocal<>();

Assertions.assertNull(context.get());
String expectP = "hello " + DateHelper.getNow(DateHelper.yyyyMMdd_hhmmssSSS);
context.set(expectP);

Runnable runnable = () -> {
Assertions.assertNull(context.get());
String expect = "hello " + DateHelper.getNow(DateHelper.yyyyMMdd_hhmmssSSS);
context.set(expect);
String value = context.get();
System.out.println(Thread.currentThread().getName() + " get " + value);
Assertions.assertEquals(expect, value);
};

Thread thread1 = new Thread(runnable);
thread1.start();
thread1.join();

Thread thread2 = new Thread(runnable);
thread2.start();
thread2.join();

String value = context.get();
System.out.println(Thread.currentThread().getName() + " get " + value);
Assertions.assertEquals(expectP, value);
}
}

执行完毕输出

1
2
3
Thread-1 get hello 2020-04-19 00:54:44:249
Thread-2 get hello 2020-04-19 00:54:44:254
main get hello 2020-04-19 00:54:44:230

可以看到, 明明是同一个变量, 在不同线程下, 获取到的值却不一样.

get 和 set 源码分析

关键点就在于setget方法. 我们现在先来看set方法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// java.lang.ThreadLocal
public class ThreadLocal<T> {
public void set(T value) {
// 1. 获取当前线程内的 ThreadLocalMap
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 2. 往 Map 塞值, 以 ThreadLocal为key
if (map != null) map.set(this, value);
else createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
}

可以看到和我一开始写的一样, 就是往一个Map里塞数据.
再看看get方法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// java.lang.ThreadLocal
public class ThreadLocal<T> {
public T get() {
// 1. 获取当前线程内的 ThreadLocalMap
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 2. 根据 key 获取 value
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
return (T) e.value;
}
}
// 4. 找不到就返回一个默认值, 默认为 null
return setInitialValue();
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
}

也一样, 就是从一个Map里, 把ThreadLocal作为key, 获取数据.

那么不同点来了, 就是这个Map的构造, 可以看到不是我用的HashMap.
而是当前线程里的ThreadLocalMap实例变量threadLocals.

ThreadLocalMap 源码分析

ThreadLocalMapThreadLocal的静态内部类, 而每一个线程Thread都有各自的ThreadLocalMap成员变量.

1
2
3
4
// java.lang.Thread
public class Thread implements Runnable {
ThreadLocal.ThreadLocalMap threadLocals = null;
}

那我们继续来看ThreadLocalMap. 和普通的HashMap一样, 内部是一个节点数组.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// java.lang.ThreadLocal
public class ThreadLocal<T> {
static class ThreadLocalMap {
// 1. 节点数组
private Entry[] table;
// 2. 弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
}
}

可以看到ThreadLocalMap是以ThreadLocal对象的弱引用为Key, ObjectValueMap集合.
Entry继承了弱引用WeakReference, 所以当每次GC发生时, 扫描到这个对象时, 如果没有强引用持有这个对象, 就会回收ThreadLocal. 但这里会导致一些内存问题, 这里后面讲.

接下来是ThreadLocalMapsetget方法, 和普通的HashMap差不多, 无非就是扩容, rehash那一套.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public class ThreadLocal<T> {
static class ThreadLocalMap {
private void set(ThreadLocal<?> key, Object value) {
// 1. 对 key 进行 hash, 获取 Entry 节点的下标
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

// 2. 如果 hash 冲突, 就 开放定址法 找到下一个下标
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
// 2.1. 清除 null key 的节点
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 3. 往下标位置塞数据, 扩容, rehash
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
private Entry getEntry(ThreadLocal<?> key) {
// 1. 对 key 进行 hash, 获取 Entry 节点的下标
int i = key.threadLocalHashCode & (table.length - 1);
// 2. 从节点数组获取数据
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// 开放定址法获取数据
return getEntryAfterMiss(key, i, e);
}
}
}

值得注意的是, ThreadLocalMap的哈希冲突解决方案是开放定址法, 而不是HashMap的链表红黑树那套.

ThreadLocal 为什么必须设置为 static ?

经常会听到一种言论, ThreadLocal必须设置为static静态变量, 否则会浪费内存的问题.

我们先来看一个单元测试例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public class ThreadLocalTest {
public static final ThreadLocal<String> staticContext = new ThreadLocal<>();
@Test
public void staticVerify() throws Exception {
// 1. 静态ThreadLocal 设置值
Assertions.assertNull(staticContext.get());
staticContext.set(Thread.currentThread().getName());

Runnable runnable = () -> {
// 2. 普通ThreadLocal 设置值
ThreadLocal<String> context = new ThreadLocal<>();
Thread thread = Thread.currentThread();

staticContext.set(thread.getName());
context.set(thread.getName());

System.out.println(thread.getName() + " static get " + staticContext.get());
System.out.println(thread.getName() + " get " + context.get()); // 打断点
};

Thread thread1 = new Thread(runnable);
thread1.start();
thread1.join();

Thread thread2 = new Thread(runnable);
thread2.start();
thread2.join();

String value = staticContext.get();
Thread thread = Thread.currentThread();
System.out.println(thread.getName() + " static get " + staticContext.get());
Assertions.assertEquals(thread.getName(), value);
}
}

我们打个断点看看Thread里的ThreadLocalMapKey引用很容易就发现问题.

当线程执行的时候, 线程的ThreadLocalMap持有两个Entry, 一个是静态的ThreadLocal, 一个是new出来的ThreadLocal.
可以看到,静态的ThreadLocalThreadLocalMap是固定的hashcode, 而另一个ThreadLocal是不同的hashcode,说明每次运行都会创建ThreadLocal
虽然并不影响正常使用,但是用static修饰ThreadLocal可以减少对象的频繁创建,降低GC频率, 如果没有特殊要求, 还是用static最佳.

另外还有一点需要注意的是, ThreadLocal如果设置成static了, 就会被当前类对象强引用, 下面的内存泄漏问题会提到.

ThreadLocal的 内存泄漏 问题

ThreadLocal引用链

ThreadLocalMapKeyThreadLocal的弱引用对象.
不管ThreadLocal对象是否因为弱引用被GC回收, Entry节点都会持有value的强引用.
要回收掉value只有两种方法

  1. 结束当前线程
  2. 手动释放KeynullEntry, 常用的是remove方法

ThreadLocalsetgetremove方法内都有对KeynullEntry做清除引用的操作.

那为什么EntryKey要做弱引用, 不直接强引用呢?
我们反推一下, 如果EntryKey是强引用, 当ThreadLocal的外部强引用取消后, ThreadLocalMap内部的Entry还持有ThreadLocal的强引用.
那么ThreadLocal就一直不能被GC回收, 需要手动remove才能回收.
反之, 如果是弱引用, 那么ThreadLocal的外部强引用取消后, 因为Entry持有的是ThreadLocal的弱引用, 当发生GC时, 能及时回收掉ThreadLocal.
在下次setgetremove方法做清除keynullvalue的操作.

总结

Thread线程有一个ThreadLocalMap成员变量, 这是一个存储了以ThreadLocalKey, 值为ValueMap集合.
ThreadLocalsetget方法, 本质是获取当前线程的ThreadLocalMap, 对这个Map进行操作, 做到线程隔离.

下面是一个简单的最佳实践.

1
2
3
4
5
6
7
8
9
10
11
public class Main {
public static final ThreadLocal<String> local = new ThreadLocal<>();
public void test() {
local.set("hello");
try {
// 业务操作
} finally{
local.remove();
}
}
}

InheritableThreadLocal 实现父子线程变量共享

那如果我想要创建一个线程, 然后子线程继承父线程的变量, 要怎么做呢?
你可以看到在上面的单元测试用例中, 我使用了一个断言Assertions.assertNull(context.get());.

也就是说, 如果单单使用ThreadLocal, 子线程是不能获取到父线程的变量的.
其实这也很好理解, 线程都不一样, 线程对象里面的Map变量当然也不一样.

如果让我们来实现, 也很容易. 只要在创建线程的时候, 将父线程的ThreadLocalMap复制一份到子线程即可.

1
2
3
4
5
6
7
8
9
10
11
public static class ThreadFactory {
public static Thread createThread() {
Thread parentThread = Thread.currentThread();
Thread childThread = new Thread();

// 复制 ThreadLocalMap
childThread.threadLocals = parentThread.threadLocals.clone();

return childThread;
}
}

JDK内部已经实现了, 我们只要直接使用InheritableThreadLocal即可.

InheritableThreadLocal 源码分析

InheritableThreadLocal继承了ThreadLocal, 并重写了ThreadLocalMap的相关方法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// java.lang.InheritableThreadLocal
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
protected T childValue(T parentValue) {
return parentValue;
}

ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}

void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}
// java.lang.ThreadLocal
public class ThreadLocal<T> {
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
}

我们可以看到, InheritableThreadLocal就是换了个变量操作, 它操作的是inheritableThreadLocals变量. 避免和ThreadLocal操作的变量冲突.
代码就短短几行, 没有涉及到我们说的**复制ThreadLocalMap**的操作.
那我们继续看下Thread的创建过程.

Thread 创建

我们直接追源码, 这里省略部分非核心代码.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// java.lang.Thread
public class Thread implements Runnable {
public Thread() {
init(null, null, "Thread-" + nextThreadNum(), 0);
}
private void init(ThreadGroup g, Runnable target, String name, long stackSize) {
init(g, target, name, stackSize, null, true);
}
private void init(ThreadGroup g, Runnable target, String name, long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) { // 注意这个 inheritThreadLocals 变量, 默认为 true
// 1. 获取创建Thread变量的线程, 作为父线程
Thread parent = currentThread();

// 2. 如果父线程有 inheritableThreadLocals 的值, 那么就复制一份到子线程
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals = new ThreadLocalMap(parent.inheritableThreadLocals);
}
}

一目了然, 就是在new Thread()的时候, 从父线程中复制一份ThreadLocalMap到子线程.

TransmittableThreadLocal 线程池子任务共享父线程的变量

InheritableThreadLocal解决了父子线程变量共享的问题, 但生产环境下, 我们一般都是用线程池来管理线程的.
这里就涉及到一个线程复用的问题.
InheritableThreadLocal只有在创建线程的时候才会将父线程的变量复制给子线程, 但是线程池的线程是复用的.
new Thread()之后, 可能会执行多个不同的任务. 这个时候就不能继承父线程的变量了.

还是老套路, 既然线程是复用的, 那我能不能在创建任务的时候, 存一份父线程数据, 在run()执行之前丢到子线程里去呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public class RunnableWrapper implements Runnable {
// 1. 创建一个 ThreadLocal, 用于暂存创建Runnable时父线程的数据, 在外部直接使用 RunnableWrapper.holder.set();
public static ThreadLocal<Map<String, Object>> holder = ThreadLocal.withInitial(HashMap::new);

private Map<String, Object> value;
private Runnable runnable;
public RunnableWrapper(Runnable runnable) {
this.runnable = runnable;
// 2. 创建任务的时候, 暂存父线程的数据
this.value = holder.get();
}

@Override
public void run() {
// 3. 在子线程执行任务前, 将父线程的数据存入子线程
holder.set(value);
runnable.run();
}
}

这里JDK可没有提供轮子了, 但是Alibaba提供了一个transmittable-thread-local的轮子.

测试用例

我们来看个单元测试例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class TtlTest {
@Test
public void wrapper() throws Exception {
// 1. 创建 线程池 和 TransmittableThreadLocal
ExecutorService executorService = Executors.newSingleThreadExecutor();
TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();

String expect = "hello " + DateHelper.getNow(DateHelper.yyyyMMdd_hhmmssSSS);
context.set(expect);
System.out.println("[parent thread] set " + context.get());
Assertions.assertEquals(expect, context.get());

// 2. 装饰 Runnable
Runnable task = () -> {
System.out.println("[child thread] get " + context.get() + " in Runnable");
Assertions.assertEquals(expect, context.get());
};
TtlRunnable ttlRunnable = TtlRunnable.get(task);
executorService.submit(ttlRunnable).get();

// 3. 关闭线程池
executorService.shutdown();
}
}

执行完毕输出

1
2
[parent thread] set hello 2020-04-19 00:54:44:230
[child thread] get hello 2020-04-19 00:54:44:230 in Runnable

可以看到, 尽管是在线程池里复用线程的情况, 依然能获取到父线程的变量.
如果觉得我只用一个Runnable样本不够, 可以自己多创建几百个Runnable.

陌生的类只有两个, TtlRunnableTransmittableThreadLocal.

TransmittableThreadLocal 的 set get 源码分析

我们先来看看set方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// com.alibaba.ttl.TransmittableThreadLocal
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
public final void set(T value) {
if (!disableIgnoreNullValueSemantics && null == value) {
// may set null to remove value
remove();
} else {
// 1. 调用 InheritableThreadLocal 的 set 方法, 保证子线程也能拿到数据
super.set(value);
// 2. 添加到 holder, 保证子任务能拿到数据
addThisToHolder();
}
}
private void addThisToHolder() {
// 判断 holder 没有当前 ThreadLocal 则 put 一份
if (!holder.get().containsKey(this)) {
holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
}
}

// 等价于 Map<线程, Map<ThreadLocal, null>>, 因为没有弱引用的Set, 所以用 WeakHashMap 代替
private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
@Override
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
}

@Override
protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
}
};
}

holder存储了当前线程下的所有TransmittableThreadLocal对象, 并用Set去重.
有人看到这里可能有疑问, 上面没看到Set关键字. 是因为JDK没有弱引用的Set, 所以用WeakHashMap代替.

再来看看get方法, 也只是多了个addThisToHolder方法.

1
2
3
4
5
6
7
8
// com.alibaba.ttl.TransmittableThreadLocal
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
public final T get() {
T value = super.get();
if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
return value;
}
}

TtlRunnable 源码分析 保存快照

既然上面设置了参数, 下面就要获取这些参数了.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// com.alibaba.ttl.TtlRunnable
public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
private final AtomicReference<Object> capturedRef;
private final Runnable runnable;
private final boolean releaseTtlValueReferenceAfterRun;
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
// 1. 保存快照
this.capturedRef = new AtomicReference<Object>(TransmittableThreadLocal.Transmitter.capture());
this.runnable = runnable;
this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}
@Override
public void run() {
// 2. 获取快照
Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after run!");
}

// 3. 保存到当前线程
Object backup = TransmittableThreadLocal.Transmitter.replay(captured);
try {
runnable.run();
} finally {
TransmittableThreadLocal.Transmitter.restore(backup);
}
}
}

顺着开头提到的思路, 在Runnable创建时, 先对父线程的变量做一个快照, 在run执行时, 复制到新线程里.

那么第一步, 做快照, 关键就是这个TransmittableThreadLocal.Transmitter.capture()方法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
public static class Transmitter {
public static Object capture() {
// 1. captureTtlValues() 对当前线程下的所有 TransmittableThreadLocal 做一个快照
// 2. TODO captureThreadLocalValues() 这个操作的是另一个变量, 目前这个场景没用到
return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
ttl2Value.put(threadLocal, threadLocal.copyValue()); // 对当前线程下的所有 TransmittableThreadLocal 做一个快照
}
return ttl2Value;
}
private static class Snapshot {
final WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value;
private Snapshot(WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
this.ttl2Value = ttl2Value;
this.threadLocal2Value = threadLocal2Value;
}
}
}
}

第二步, 执行run方法的时候, 取出快照, 并保存到当前线程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
public static class Transmitter {
@NonNull
public static Object replay(@NonNull Object captured) {
final Snapshot capturedSnapshot = (Snapshot) captured;
return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

@NonNull
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
// 省略部分代码

// 精华! 从快照中获取 ThreadLocal 再重新 set 到当前线程
for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : captured.entrySet()) {
TransmittableThreadLocal<Object> threadLocal = entry.getKey();
threadLocal.set(entry.getValue());
}

// 省略部分代码
return backup;
}
}
}

老实说, 看到这里的代码, 不由让我惊叹. 这个for循环, 太神了.
如此一来, 就将之前保存的ThreadLocal快照, 在执行这个Runnable的线程重新复制了一遍.

总结

ThreadLocal算是高频面试考点, 并且一个用的不小心就会造成线上生产问题. 使用的时候只要注意套模版就可以了.

  1. static修饰ThreadLocal
  2. set之后要用try {} finally {}remove操作
  3. 线程池要用TransmittableThreadLocal来传递变量.

另外, Netty还有一个FastThreadLocal的东西, 不过没有涉及本篇的线程变量传递的主题, 所以就不讲了.

参考资料