Preface
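Suppose I want the same variable to hold a different value in each thread. How would I do it?
I could declare a `Map` with the thread as the key (this is not how the JDK actually does it) and another `Map` as the value, holding the key-value pairs of that thread.
Here is a simple example (don't use it in production):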
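```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ThreadMap {
    // Thread -> that thread's own key/value pairs; entries are never removed (one reason not to use this)
    public static Map<Thread, Map<String, Object>> map = new ConcurrentHashMap<>();

    public static void set(String key, Object value) {
        // Create the per-thread map on first use instead of throwing NullPointerException
        map.computeIfAbsent(Thread.currentThread(), t -> new HashMap<>()).put(key, value);
    }

    public static Object get(String key) {
        Map<String, Object> values = map.get(Thread.currentThread());
        return values == null ? null : values.get(key);
    }
}
```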
Simple, right? But there is no need to reinvent the wheel: the JDK has long shipped one, and in essence it works just like my demo.
ThreadLocal: thread isolation
Test case
Let's look at a simple unit test.
```java
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class ThreadLocalTest {
    // Stand-in for the author's own DateHelper.getNow(DateHelper.yyyyMMdd_hhmmssSSS)
    private static String now() {
        return LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss:SSS"));
    }

    @Test
    public void threadLocal1() throws Exception {
        ThreadLocal<String> context = new ThreadLocal<>();
        Assertions.assertNull(context.get());
        String expectP = "hello " + now();
        context.set(expectP);

        Runnable runnable = () -> {
            // A new thread starts with no value, even though main has already set one.
            // Note: an assertion failing here only ends this child thread; JUnit does not see it.
            Assertions.assertNull(context.get());
            String expect = "hello " + now();
            context.set(expect);
            String value = context.get();
            System.out.println(Thread.currentThread().getName() + " get " + value);
            Assertions.assertEquals(expect, value);
        };

        Thread thread1 = new Thread(runnable);
        thread1.start();
        thread1.join();

        Thread thread2 = new Thread(runnable);
        thread2.start();
        thread2.join();

        // The child threads' set() calls did not touch main's value
        String value = context.get();
        System.out.println(Thread.currentThread().getName() + " get " + value);
        Assertions.assertEquals(expectP, value);
    }
}
```
Output:
```
Thread-1 get hello 2020-04-19 00:54:44:249
Thread-2 get hello 2020-04-19 00:54:44:254
main get hello 2020-04-19 00:54:44:230
```
As you can see, it is clearly the same variable, yet each thread reads a different value.
Source analysis of get and set
The key lies in the `set` and `get` methods. Let's look at `set` first.
```java
public class ThreadLocal<T> {
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
}
```
This is exactly what I wrote at the beginning: it puts data into a `Map`.
Now let's look at `get`.
```java
public class ThreadLocal<T> {
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                return (T) e.value;
            }
        }
        return setInitialValue();
    }

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
}
```
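Same story: it reads data from a `Map`, using the `ThreadLocal` itself as the key. If the current thread has no entry yet, it falls back to `setInitialValue()`, which stores and returns `initialValue()` (`null` unless overridden).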
Here is where it differs from my demo: the `Map` is not a `HashMap`, but the `ThreadLocalMap` held in the current thread's instance field `threadLocals`.
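To see that fallback in action, here is a small sketch using the standard `ThreadLocal.withInitial(Supplier)` factory, whose `initialValue()` simply calls the supplier. `InitialValueDemo` and its counter are illustrative names: the supplier runs once per thread, on that thread's first `get()`, and later calls find the stored entry.

```java
import java.util.concurrent.atomic.AtomicInteger;

class InitialValueDemo {
    // Counts how often the supplier ran, i.e. how often setInitialValue() was triggered
    static final AtomicInteger INIT_CALLS = new AtomicInteger();

    // withInitial(...) returns a ThreadLocal whose initialValue() delegates to this supplier
    static final ThreadLocal<StringBuilder> BUFFER = ThreadLocal.withInitial(() -> {
        INIT_CALLS.incrementAndGet();
        return new StringBuilder();
    });

    /** Calls get() twice on a fresh thread; true if both calls returned the same instance. */
    static boolean sameInstanceWithinThread() throws InterruptedException {
        boolean[] same = new boolean[1];
        Thread t = new Thread(() -> {
            StringBuilder first = BUFFER.get();   // no entry yet -> setInitialValue() -> supplier runs
            StringBuilder second = BUFFER.get();  // entry exists now -> returned directly
            same[0] = (first == second);
        });
        t.start();
        t.join();                                 // join() makes same[0] visible to this thread
        return same[0];
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(sameInstanceWithinThread()); // true
        System.out.println(INIT_CALLS.get());           // 1
    }
}
```

Each new thread triggers the supplier again, because each thread has its own `ThreadLocalMap`.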
ThreadLocalMap source analysis
`ThreadLocalMap` is a static nested class of `ThreadLocal`, and every `Thread` has its own `ThreadLocalMap` member field.
```java
public class Thread implements Runnable {
    ThreadLocal.ThreadLocalMap threadLocals = null;
}
```
Let's continue with `ThreadLocalMap`. Like an ordinary `HashMap`, it is backed by an array of entries.
```java
public class ThreadLocal<T> {
    static class ThreadLocalMap {
        private Entry[] table;

        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
    }
}
```
So `ThreadLocalMap` is a map whose key is a weak reference to a `ThreadLocal` object and whose value is an `Object`.
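Because `Entry` extends `WeakReference`, once nothing holds a strong reference to the `ThreadLocal` any more, the next GC that encounters it clears the reference and collects the `ThreadLocal`. This leads to some memory issues, which are covered later.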
Next come `ThreadLocalMap`'s `set` and `getEntry` methods. They are much like an ordinary `HashMap`: index calculation, resizing, rehash and so on.
```java
public class ThreadLocal<T> {
    static class ThreadLocalMap {
        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
    }
}
```
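Note that `ThreadLocalMap` resolves hash collisions with open addressing (linear probing via `nextIndex`), not with `HashMap`'s linked-list and red-black-tree buckets.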
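The loop above (start at `key.threadLocalHashCode & (len-1)`, step with `nextIndex` on collision) is plain linear probing. A minimal sketch of the idea, assuming a fixed-size table with no resizing, no removal and no stale-entry cleanup; `ProbingTable` is a hypothetical class. It borrows the increment `0x61c88647` that `ThreadLocal` adds to produce successive `threadLocalHashCode` values: since the increment is odd, the first 16 codes land on 16 distinct slots of a 16-slot table.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical fixed-size open-addressing table (no resize, no removal). */
class ProbingTable {
    // Same increment ThreadLocal uses to generate successive threadLocalHashCode values
    static final int HASH_INCREMENT = 0x61c88647;

    private final int[] keys;      // stored hash codes
    private final Object[] values;
    private final boolean[] used;  // 0 is a legal hash, so occupancy is tracked separately

    ProbingTable(int capacityPowerOfTwo) {
        keys = new int[capacityPowerOfTwo];
        values = new Object[capacityPowerOfTwo];
        used = new boolean[capacityPowerOfTwo];
    }

    private int nextIndex(int i) {            // wrap around, like ThreadLocalMap.nextIndex
        return (i + 1 < keys.length) ? i + 1 : 0;
    }

    void put(int hash, Object value) {
        int i = hash & (keys.length - 1);     // home slot
        while (used[i] && keys[i] != hash) {  // probe until a free slot or the same key
            i = nextIndex(i);                 // (would loop forever on a full table; real maps resize first)
        }
        keys[i] = hash;
        values[i] = value;
        used[i] = true;
    }

    Object get(int hash) {
        int i = hash & (keys.length - 1);
        while (used[i]) {
            if (keys[i] == hash) return values[i];
            i = nextIndex(i);
        }
        return null;                          // hit an empty slot: key absent
    }

    /** Number of distinct home slots taken by the first n generated hash codes. */
    static int distinctHomeSlots(int n, int capacity) {
        Set<Integer> slots = new HashSet<>();
        int h = 0;
        for (int k = 0; k < n; k++) {
            slots.add(h & (capacity - 1));
            h += HASH_INCREMENT;              // int overflow wraps around, as in AtomicInteger.getAndAdd
        }
        return slots.size();
    }
}
```

For example, in a 16-slot table the hashes `1` and `17` share home slot 1, so the second one is stored in slot 2 and found there by probing.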
Why must ThreadLocal be static?
You often hear that a `ThreadLocal` must be declared as a `static` variable, otherwise it wastes memory.
Let's start with a unit test.
```java
public class ThreadLocalTest {
    public static final ThreadLocal<String> staticContext = new ThreadLocal<>();

    @Test
    public void staticVerify() throws Exception {
        Assertions.assertNull(staticContext.get());
        staticContext.set(Thread.currentThread().getName());

        Runnable runnable = () -> {
            ThreadLocal<String> context = new ThreadLocal<>();
            Thread thread = Thread.currentThread();

            staticContext.set(thread.getName());
            context.set(thread.getName());

            System.out.println(thread.getName() + " static get " + staticContext.get());
            System.out.println(thread.getName() + " get " + context.get());
        };

        Thread thread1 = new Thread(runnable);
        thread1.start();
        thread1.join();

        Thread thread2 = new Thread(runnable);
        thread2.start();
        thread2.join();

        String value = staticContext.get();
        Thread thread = Thread.currentThread();
        System.out.println(thread.getName() + " static get " + staticContext.get());
        Assertions.assertEquals(thread.getName(), value);
    }
}
```
Set a breakpoint and inspect the keys in the thread's `ThreadLocalMap`, and the problem is easy to spot.
While the task runs, the thread's `ThreadLocalMap` holds two `Entry` objects: one for the static `ThreadLocal` and one for the `ThreadLocal` created with `new`.
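The static `ThreadLocal` always has the same hash code (`threadLocalHashCode`) in `ThreadLocalMap`, whereas the other one has a different hash code every time, which shows that a new `ThreadLocal` is created on every run.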
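This does not break normal usage, but declaring the `ThreadLocal` `static` avoids creating a new object on every run (on a long-lived thread, each run also leaves another entry in that thread's `ThreadLocalMap` until it is cleaned up) and lowers GC pressure. Unless you have a specific reason not to, `static` is the best choice.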
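A typical use of a `static final` `ThreadLocal` is giving each thread its own instance of a non-thread-safe object. A minimal sketch, assuming one `SimpleDateFormat` per thread; `DateFormats` is a hypothetical helper, and the fixed UTC time zone is only there to make the output predictable.

```java
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.TimeZone;
import java.util.concurrent.atomic.AtomicReference;

class DateFormats {
    // One ThreadLocal for the whole class; each thread lazily gets its own SimpleDateFormat,
    // which is not thread-safe and therefore must not be shared between threads.
    private static final ThreadLocal<SimpleDateFormat> FORMAT = ThreadLocal.withInitial(() -> {
        SimpleDateFormat f = new SimpleDateFormat("yyyy-MM-dd");
        f.setTimeZone(TimeZone.getTimeZone("UTC")); // fixed zone so the output is deterministic
        return f;
    });

    static String format(Date date) {
        return FORMAT.get().format(date);
    }

    /** The formatter instance belonging to the calling thread. */
    static SimpleDateFormat current() {
        return FORMAT.get();
    }

    /** The formatter instance a freshly started thread gets. */
    static SimpleDateFormat seenByNewThread() throws InterruptedException {
        AtomicReference<SimpleDateFormat> ref = new AtomicReference<>();
        Thread t = new Thread(() -> ref.set(FORMAT.get()));
        t.start();
        t.join();
        return ref.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(format(new Date(0L)));           // 1970-01-01
        System.out.println(current() == current());         // true: same thread, same instance
        System.out.println(seenByNewThread() == current()); // false: each thread has its own
    }
}
```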
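One more point: a `static` `ThreadLocal` is strongly referenced by its class, so its weak key is effectively never cleared. This comes up again in the memory leak discussion below.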
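ThreadLocal memory leaks
The key of `ThreadLocalMap` is a weak reference to the `ThreadLocal`.
Whether or not the `ThreadLocal` has been collected by the GC, the `Entry` still holds a strong reference to its `value`.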
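The weak-key/strong-value split can be reproduced without waiting for a real GC. The sketch below uses a hypothetical `Entry` class shaped like `ThreadLocalMap.Entry` (it is not the JDK class) and calls `WeakReference.clear()` as a deterministic stand-in for the collector clearing the key.

```java
import java.lang.ref.WeakReference;

/** Hypothetical copy of the ThreadLocalMap.Entry shape: weak key, strong value. */
class StaleEntryDemo {
    static final class Entry extends WeakReference<Object> {
        Object value;                    // strong reference, like ThreadLocalMap.Entry.value

        Entry(Object key, Object value) {
            super(key);                  // the key is only weakly referenced
            this.value = value;
        }
    }

    /** Builds an entry and clears its key the way the GC would once the key is unreachable. */
    static Entry staleEntry() {
        Entry e = new Entry(new Object(), new byte[1024 * 1024]);
        e.clear();                       // deterministic stand-in for a GC cycle
        return e;
    }

    /** Roughly what ThreadLocalMap's stale-entry cleanup does for such a slot. */
    static void expunge(Entry e) {
        e.value = null;
    }

    public static void main(String[] args) {
        Entry e = staleEntry();
        System.out.println(e.get() == null);  // true: the key is gone
        System.out.println(e.value != null);  // true: the 1 MB value is still reachable
        expunge(e);
        System.out.println(e.value == null);  // true: now the value can be collected
    }
}
```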
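There are only two ways to release that `value`:
- end the current thread;
- explicitly clear the entry, usually by calling `remove()` (on pooled threads, which never end, this is the only option).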
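`ThreadLocal`'s `set`, `get` and `remove` methods also clear entries whose key is `null`, but only the stale slots they happen to scan, so this cleanup is opportunistic rather than guaranteed.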
So why is the `Entry` key a weak reference instead of a strong one?
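Let's reason backwards. If the key were a strong reference, then after the last external strong reference to the `ThreadLocal` is dropped, the `Entry` inside `ThreadLocalMap` would still hold the `ThreadLocal` strongly.
The `ThreadLocal` could then never be collected, and only an explicit `remove` would free it.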
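Conversely, with a weak reference, once the external strong reference is dropped, the `ThreadLocal` can be collected at the next GC, because the `Entry` only references it weakly.
The `value` left behind under the now-`null` key is then cleared by a later `set`, `get` or `remove` call, provided one reaches that slot.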
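Summary
A `Thread` has a `ThreadLocalMap` member field: a map whose keys are `ThreadLocal` objects and whose values are the stored values.
`ThreadLocal`'s `set` and `get` simply fetch the current thread's `ThreadLocalMap` and operate on it, which is how thread isolation is achieved.
Here is a simple best practice.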
```java
public class Main {
    public static final ThreadLocal<String> local = new ThreadLocal<>();

    public void test() {
        local.set("hello");
        try {
            // business logic that reads local.get()
        } finally {
            // always clean up, especially on pooled threads that are reused
            local.remove();
        }
    }
}
```
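Why the `finally { remove(); }` matters is easiest to see on a pooled thread: the next task runs on the same `Thread`, and therefore sees the same `ThreadLocalMap`. A minimal sketch with a single-thread pool; `PoolLeakDemo` and `USER` are illustrative names.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class PoolLeakDemo {
    static final ThreadLocal<String> USER = new ThreadLocal<>();

    /** Runs two tasks on the same pooled thread and returns what the second task sees. */
    static String secondTaskSees(boolean removeInFinally) throws Exception {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            pool.submit(() -> {
                USER.set("alice");
                try {
                    // business logic would go here
                } finally {
                    if (removeInFinally) USER.remove();
                }
            }).get();
            // The single worker thread is reused for this second task
            return pool.submit(() -> USER.get()).get();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(secondTaskSees(false)); // alice  (leaked from the previous task)
        System.out.println(secondTaskSees(true));  // null
    }
}
```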
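InheritableThreadLocal: sharing variables between parent and child threads
What if I want a newly created child thread to inherit the parent thread's variables?
You may have noticed that in the unit test above I used the assertion `Assertions.assertNull(context.get());`.
In other words, with a plain `ThreadLocal`, a child thread cannot read the parent thread's values.
That is easy to understand: they are different threads, so the maps inside the two `Thread` objects are different too.
Implementing this ourselves would also be easy: when creating the thread, copy the parent thread's `ThreadLocalMap` into the child.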
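```java
// Pseudocode: this does not compile outside java.lang, because Thread.threadLocals is
// package-private and ThreadLocalMap has no public clone() method.
public static class ThreadFactory {
    public static Thread createThread() {
        Thread parentThread = Thread.currentThread();
        Thread childThread = new Thread();

        // Copy the parent's map into the child at creation time
        childThread.threadLocals = parentThread.threadLocals.clone();
        return childThread;
    }
}
```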
The JDK already implements this; we simply use `InheritableThreadLocal`.
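A quick check of the difference using only the JDK: the parent sets both a plain `ThreadLocal` and an `InheritableThreadLocal`, then starts a child thread. `InheritDemo` is an illustrative name.

```java
import java.util.concurrent.atomic.AtomicReference;

class InheritDemo {
    static final ThreadLocal<String> PLAIN = new ThreadLocal<>();
    static final InheritableThreadLocal<String> INHERITED = new InheritableThreadLocal<>();

    /** Returns {plain, inherited} as seen by a child thread created after the parent set both. */
    static String[] childSees() throws InterruptedException {
        PLAIN.set("parent-plain");
        INHERITED.set("parent-inherited");
        AtomicReference<String[]> seen = new AtomicReference<>();
        // The copy happens here, when the child Thread object is constructed
        Thread child = new Thread(() -> seen.set(new String[] {PLAIN.get(), INHERITED.get()}));
        child.start();
        child.join();
        PLAIN.remove();
        INHERITED.remove();
        return seen.get();
    }

    public static void main(String[] args) throws InterruptedException {
        String[] r = childSees();
        System.out.println(r[0]); // null
        System.out.println(r[1]); // parent-inherited
    }
}
```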
InheritableThreadLocal source analysis
`InheritableThreadLocal` extends `ThreadLocal` and overrides its map-related methods (`getMap`, `createMap`) as well as `childValue`.
```java
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

public class ThreadLocal<T> {
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
}
```
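As you can see, `InheritableThreadLocal` just operates on a different field, `inheritableThreadLocals`, so it does not clash with the field that `ThreadLocal` uses.
It is only a few lines, and none of them perform the **copy the `ThreadLocalMap`** step we talked about.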
So let's look at how a `Thread` is created.
Thread creation
Let's follow the source directly; non-essential code is omitted.
```java
public class Thread implements Runnable {
    public Thread() {
        init(null, null, "Thread-" + nextThreadNum(), 0);
    }

    private void init(ThreadGroup g, Runnable target, String name, long stackSize) {
        init(g, target, name, stackSize, null, true);
    }

    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        Thread parent = currentThread();
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals = new ThreadLocalMap(parent.inheritableThreadLocals);
    }
}
```
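Clear at a glance: during `new Thread()`, a copy of the parent thread's `inheritableThreadLocals` map is made for the child. The copy constructor passes each value through `childValue`, which by default returns the same object, so the copy is shallow.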
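Because the copy goes through `childValue`, the default behaviour shares mutable objects between parent and child. A sketch (the class name `ChildValueDemo` is hypothetical) comparing the default with an override that copies the list when the child thread is created:

```java
import java.util.ArrayList;
import java.util.List;

class ChildValueDemo {
    // Default childValue() returns the parent's object itself: parent and child share one list
    static final InheritableThreadLocal<List<String>> SHARED = new InheritableThreadLocal<>();

    // Overriding childValue() gives the child its own copy at thread-creation time
    static final InheritableThreadLocal<List<String>> COPIED = new InheritableThreadLocal<List<String>>() {
        @Override
        protected List<String> childValue(List<String> parentValue) {
            return new ArrayList<>(parentValue);
        }
    };

    /** Child appends to both lists; returns {sharedSizeInParent, copiedSizeInParent}. */
    static int[] parentSizesAfterChildWrites() throws InterruptedException {
        SHARED.set(new ArrayList<>());
        COPIED.set(new ArrayList<>());
        Thread child = new Thread(() -> {
            SHARED.get().add("from-child");   // mutates the parent's list
            COPIED.get().add("from-child");   // mutates the child's private copy
        });
        child.start();
        child.join();
        int[] sizes = {SHARED.get().size(), COPIED.get().size()};
        SHARED.remove();
        COPIED.remove();
        return sizes;
    }

    public static void main(String[] args) throws InterruptedException {
        int[] s = parentSizesAfterChildWrites();
        System.out.println(s[0] + " " + s[1]); // 1 0
    }
}
```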
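TransmittableThreadLocal: sharing parent-thread variables with thread pool tasks
`InheritableThreadLocal` solves variable sharing between parent and child threads, but in production we usually manage threads with a thread pool.
That brings in the problem of thread reuse.
`InheritableThreadLocal` copies the parent's variables only when a thread is created, but pool threads are reused.
After a single `new Thread()`, a thread may run many different tasks, and those tasks no longer inherit the variables of the thread that submitted them.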
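The reuse problem can be reproduced with the JDK alone. In this sketch (`InheritInPoolDemo` and `TRACE_ID` are illustrative names) the single pool thread is created, lazily on the first `submit`, while the parent holds `"A"`; after the parent switches to `"B"`, the reused worker still sees `"A"`.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class InheritInPoolDemo {
    static final InheritableThreadLocal<String> TRACE_ID = new InheritableThreadLocal<>();

    /** Returns what a pooled task sees after the parent changed the value. */
    static String taskSeesAfterParentChange() throws Exception {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            TRACE_ID.set("A");
            // The worker thread is created on this first submit, by this thread, so it copies "A"
            pool.submit(() -> { }).get();
            TRACE_ID.set("B");
            // Same worker is reused: no new Thread() is constructed, so nothing is copied again
            return pool.submit(() -> TRACE_ID.get()).get();
        } finally {
            pool.shutdown();
            TRACE_ID.remove();
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(taskSeesAfterParentChange()); // A, not B
    }
}
```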
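Same trick as before: since threads are reused, could we store a copy of the parent's data when the task is created, and hand it to the worker thread right before `run()` executes?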
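```java
import java.util.HashMap;
import java.util.Map;

public class RunnableWrapper implements Runnable {
    public static ThreadLocal<Map<String, Object>> holder = ThreadLocal.withInitial(HashMap::new);

    private Map<String, Object> value;
    private Runnable runnable;

    public RunnableWrapper(Runnable runnable) {
        this.runnable = runnable;
        // Snapshot taken on the submitting thread; copied so later changes in the parent are not seen
        this.value = new HashMap<>(holder.get());
    }

    @Override
    public void run() {
        Map<String, Object> backup = holder.get();
        holder.set(value);          // replay the snapshot on the worker thread
        try {
            runnable.run();
        } finally {
            holder.set(backup);     // restore, so the next task on this pooled thread is unaffected
        }
    }
}
```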
The JDK does not provide a wheel for this one, but Alibaba provides `transmittable-thread-local`.
Test case
Let's look at a unit test.
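```java
import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class TtlTest {
    @Test
    public void wrapper() throws Exception {
        ExecutorService executorService = Executors.newSingleThreadExecutor();
        // Start the worker thread first, so the task below really runs on a reused thread;
        // otherwise the worker would be created lazily and could inherit the value via InheritableThreadLocal
        executorService.submit(() -> { }).get();
        TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();

        String expect = "hello " + LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss:SSS"));
        context.set(expect);
        System.out.println("[parent thread] set " + context.get());
        Assertions.assertEquals(expect, context.get());

        Runnable task = () -> {
            System.out.println("[child thread] get " + context.get() + " in Runnable");
            Assertions.assertEquals(expect, context.get());
        };
        TtlRunnable ttlRunnable = TtlRunnable.get(task);
        // get() rethrows an AssertionError from the task wrapped in an ExecutionException
        executorService.submit(ttlRunnable).get();

        executorService.shutdown();
    }
}
```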
Output:
```
[parent thread] set hello 2020-04-19 00:54:44:230
[child thread] get hello 2020-04-19 00:54:44:230 in Runnable
```
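As you can see, even though the pool reuses its thread, the task still gets the parent thread's value.
If one `Runnable` doesn't convince you, try creating a few hundred.
Only two classes are new here: `TtlRunnable` and `TransmittableThreadLocal`.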
Source analysis of TransmittableThreadLocal's set and get
Let's start with `set`.
```java
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
    public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && null == value) {
            remove();
        } else {
            super.set(value);
            addThisToHolder();
        }
    }

    private void addThisToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((TransmittableThreadLocal<Object>) this, null);
        }
    }

    private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(
                        WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };
}
```
`holder` stores all `TransmittableThreadLocal` objects used by the current thread, deduplicated like a `Set`.
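You may wonder why there is no `Set` in sight. The JDK has no dedicated weak-reference `Set` class, so a `WeakHashMap` is used instead: its keys act as the set and every value is `null`.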
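As an aside, a weak-key `Set` view can be assembled from JDK parts: `Collections.newSetFromMap` wraps an empty `WeakHashMap` as a `Set`. A minimal sketch (`WeakSetDemo` is a hypothetical name); TTL itself, as shown above, keeps the raw `WeakHashMap` with `null` values.

```java
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;

class WeakSetDemo {
    /** Set view backed by a WeakHashMap: elements are held weakly, duplicates are ignored. */
    static Set<Object> newWeakSet() {
        return Collections.newSetFromMap(new WeakHashMap<>());
    }

    public static void main(String[] args) {
        Set<Object> set = newWeakSet();
        Object key = new Object();   // kept strongly reachable here, so it cannot vanish mid-demo
        set.add(key);
        set.add(key);                // deduplicated, like holder's containsKey check
        System.out.println(set.size()); // 1
    }
}
```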
Now `get`: it likewise only adds a call to `addThisToHolder`.
```java
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
    public final T get() {
        T value = super.get();
        if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
        return value;
    }
}
```
TtlRunnable source analysis: saving the snapshot
With the variables registered, the next step is to read them back out.
```java
public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    private final AtomicReference<Object> capturedRef;
    private final Runnable runnable;
    private final boolean releaseTtlValueReferenceAfterRun;

    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<Object>(TransmittableThreadLocal.Transmitter.capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

    @Override
    public void run() {
        Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }

        Object backup = TransmittableThreadLocal.Transmitter.replay(captured);
        try {
            runnable.run();
        } finally {
            TransmittableThreadLocal.Transmitter.restore(backup);
        }
    }
}
```
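The guard at the top of `run()` is a one-shot pattern: with `releaseTtlValueReferenceAfterRun` set, the first call swaps the captured snapshot for `null` via `compareAndSet`, so any later (or concurrent) run fails fast. A stripped-down sketch of just that check; `OneShotRunnable` is a hypothetical class, not part of TTL, and it always behaves as if the flag were `true`.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical one-shot wrapper mirroring TtlRunnable's release check (flag assumed true). */
final class OneShotRunnable implements Runnable {
    private final AtomicReference<Object> capturedRef;
    private final Runnable task;

    OneShotRunnable(Object captured, Runnable task) {
        this.capturedRef = new AtomicReference<>(captured);
        this.task = task;
    }

    @Override
    public void run() {
        Object captured = capturedRef.get();
        // Only the caller that successfully swaps captured -> null may proceed
        if (captured == null || !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("captured value already released");
        }
        task.run();
    }

    public static void main(String[] args) {
        Runnable r = new OneShotRunnable("snapshot", () -> System.out.println("ran"));
        r.run();                      // prints "ran"
        try {
            r.run();                  // second run: snapshot already released
        } catch (IllegalStateException e) {
            System.out.println("second run rejected");
        }
    }
}
```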
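Following the idea from earlier: when the `Runnable` is created, take a snapshot of the parent thread's variables; when `run` executes, copy them into the worker thread; afterwards, `restore` puts the worker's previous values back.
Step one, taking the snapshot, happens in `TransmittableThreadLocal.Transmitter.capture()`.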
```java
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
    public static class Transmitter {
        public static Object capture() {
            return new Snapshot(captureTtlValues(), captureThreadLocalValues());
        }

        private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value =
                    new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
            for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        }

        private static class Snapshot {
            final WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
            final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value;

            private Snapshot(WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value,
                             WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
                this.ttl2Value = ttl2Value;
                this.threadLocal2Value = threadLocal2Value;
            }
        }
    }
}
```
Step two: when `run` executes, the snapshot is taken out and written into the current thread.
```java
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
    public static class Transmitter {
        @NonNull
        public static Object replay(@NonNull Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;
            return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value),
                    replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }

        @NonNull
        private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(
                @NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> backup =
                    new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
            // (Abridged excerpt: the full method first records the worker thread's current
            //  TTL values into `backup`, which restore() later uses to undo this replay.)
            for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : captured.entrySet()) {
                TransmittableThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal.set(entry.getValue());
            }
            return backup;
        }
    }
}
```
Honestly, this code amazed me when I read it. That `for` loop is brilliant.
This way, the `ThreadLocal` snapshot saved earlier is reproduced on the thread that actually executes this `Runnable`.
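To tie capture, replay and restore together, here is a JDK-only sketch of the same flow for several variables at once. `MiniTransmitter` is hypothetical and simplified: its registry is one global set of strong references rather than TTL's per-thread weak `holder`, and it treats a `null` value as "no value".

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;

/** Hypothetical mini transmitter, not TTL's API: registered ThreadLocals travel together. */
final class MiniTransmitter {
    private static final Set<ThreadLocal<Object>> REGISTRY = new CopyOnWriteArraySet<>();

    private MiniTransmitter() { }

    @SuppressWarnings("unchecked")
    static <T> ThreadLocal<T> register(ThreadLocal<T> local) {
        REGISTRY.add((ThreadLocal<Object>) (ThreadLocal<?>) local);
        return local;
    }

    /** capture(): values of every registered ThreadLocal as seen by the calling thread. */
    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> snapshot = new HashMap<>();
        for (ThreadLocal<Object> local : REGISTRY) {
            snapshot.put(local, local.get());
        }
        return snapshot;
    }

    /** Writes a snapshot into the calling thread; null means "no value", so remove(). */
    private static void apply(Map<ThreadLocal<Object>, Object> snapshot) {
        for (Map.Entry<ThreadLocal<Object>, Object> e : snapshot.entrySet()) {
            if (e.getValue() == null) {
                e.getKey().remove();
            } else {
                e.getKey().set(e.getValue());
            }
        }
    }

    /** replay(): on the worker, back up its own values, then install the captured ones. */
    static Map<ThreadLocal<Object>, Object> replay(Map<ThreadLocal<Object>, Object> captured) {
        Map<ThreadLocal<Object>, Object> backup = capture();
        apply(captured);
        return backup;
    }

    /** restore(): on the worker, put its previous values back. */
    static void restore(Map<ThreadLocal<Object>, Object> backup) {
        apply(backup);
    }

    /** Capture when wrapping, replay before run, restore afterwards. */
    static Runnable wrap(Runnable task) {
        Map<ThreadLocal<Object>, Object> captured = capture();
        return () -> {
            Map<ThreadLocal<Object>, Object> backup = replay(captured);
            try {
                task.run();
            } finally {
                restore(backup);
            }
        };
    }
}
```

Usage: register each `ThreadLocal` once (e.g. `static final ThreadLocal<String> USER = MiniTransmitter.register(new ThreadLocal<>());`), then submit `MiniTransmitter.wrap(task)` instead of `task`.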
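Summary
`ThreadLocal` is a frequent interview topic, and a single careless use can cause production incidents. Just follow the template:
- declare the `ThreadLocal` as `static`;
- after `set`, call `remove` in a `try {} finally {}` block;
- in thread pools, use `TransmittableThreadLocal` to pass values.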
Netty also has a `FastThreadLocal`, but it is outside this article's topic of passing thread variables, so I won't cover it here.
References