ThreadLocal的设计精要

内容纲要

1 ThreadLocal 是什么?

ThreadLocal是指线程的本地变量,我们可以通过ThreadLocal去设计只有线程内部才可以访问的变量,该变量是与其他线程所隔离的。

2 ThreadLocal 可以干什么?

ThreadLocal最大的特点就是线程隔离,其他线程无法获取,当前线程ThreadLocal存放的数据,然后就是在ThreadLocal中存放的信息,无论其程序的调用链路有多深,只要是同一个线程,无需参数传递,可以直接获取。

3 ThreadLocal 典型应用

一个比较典型的应用场景就是,MyBatis分页插件中的应用,需要分页的接口,需要先将分页信息放入ThreadLocal中,需要用的时候直接通过ThreadLocal获取,这样就不需要每个接口都传入分页信息,然后再传给分页插件。

file

Spring的声明式事物中,也应用了ThreadLocal来存储事物信息,因为我们只有使用同一个数据库连接,设置事物才会生效,Spring的事物管理器在获取到数据库连接后就会将其与ThreadLocal绑定,事物完成后解除绑定。

4 ThreadLocal 源码解析

ThreadLocal 的源码是相对容易看懂的,我们以ThreadLocal的set方法为入口,开始看其原理

首先set方法会先获取当前线程 t
通过当前线程 t 获取当前线程的ThreadLocalMap对象
判断ThreadLocalMap是否为空

不为空时 为map设置值,key为当前的ThreadLocal对象,value为传入的参数
为空时 调用ThreadLocalMap构造函数初始化对象

public class ThreadLocal<T> {
    ......

    // ThreadLocal设置值
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    // 获取当前线程的ThreadLocalMap对象
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    // 当前线程ThreadLocalMap对象为空时,调用ThreadLocalMap构造函数初始化对象
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
    // ThreadLocalMap为ThreadLocal的静态内部类,为Thread的属性
    static class ThreadLocalMap {
        ......

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
    }

}

通过这一段代码,我们就可以很清楚的分析到ThreadLocal为什么能做到线程间的隔离,因为ThreadLocalMap是Thread的属性,也就是说这些数据是存储在Thread上的,这也是对其名字最好的理解。
然后我们再来看一下这个静态内部类ThreadLocalMap的数据结构,我们可以看到这个Map比HashMap简单许多,它是一个Entry数组,Entry是继承自WeakReference对象的(弱引用,垃圾回收时会清理),并且该Map并没有采用拉链的方式来解决hash冲突,而是拿当前下标,判断寻找下一个符合条件的位置存放。
Entry的类继承关系如下图,其中key是由弱引用,保存在Reference对象中的,通过get()方法获取。

file

public class ThreadLocal<T> {
    ......

    static class ThreadLocalMap {

        ......

       static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

        private void set(ThreadLocal<?> key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

        /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
    }
}   

ThreadLocal子类InheritableThreadLocal的设计,它主要是从写其getMap,createMap,childValue方法,使其调用Thread中inheritableThreadLocals属性,inheritableThreadLocals在Thread会继承父线程的inheritableThreadLocals属性,从而实现子父线程间变量的传递。

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

5 ThreadLocal 精妙设计

file

  1. 线程隔离
    ThreadLocal线程本地变量如其名字一样,并没有将数据放在ThreadLocal对象中,而将ThreadLocalMap放在Thread对象上,这样ThreadLocal只是提供操作数据的入口,并不具备实际存储能力,这样就可以做到线程隔离。

  2. 弱引用解决key的内存释放
    ThreadLocal通过弱引用的方式,使得ThreadLocal在外部强引用消失时,可以自动被垃圾回收收集,而不需要去释放ThreadLocalMap中key的引用。

重点:key可以通过弱引用被释放,那么value如何处理呢?

尤其是在使用线程池,线程会被复用,如果不能有效的回收ThreadLocalMap,那么很容易出现脏数据,甚至内存溢出,当我们在使用完毕ThreadLocal后,一定要去调用其remove方法。

    public static void main(String[] args) {
        new Thread(()-> {
            try {
                threadLocal.set("hello");
                threadLocal.get();
            } finally {
                threadLocal.remove();
            }
        }).start();
    }
  1. 父子线程传递ThreadLocal
    虽然说ThreadLocal是线程隔离的,但是理论上子线程想要获取父线程中设置的值是可以的,这时我们就可以通过InheritableThreadLocal来实现

InheritableThreadLocal是ThreadLocal的子类,其重写了创建createMap、getMap、childValue等方法,在Thread类中有两个成员变量一个是threadLocals,一个是inheritableThreadLocals,在线程创建的时候回判断父线程是否具有inheritableThreadLocals,有的化会继承下来,从而实现父线程向子线程ThreadLocalMap的传递。

6 ThreadLocal 测试案例

下面是ThreadLocal最简单的使用方法,大家可以自行跟踪源码进行验证。

package com.zhj.interview;

public class Test18 {

    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();
    private static ThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        inheritableThreadLocal.set("hello main");
        new Thread(()-> {
            try {
                threadLocal.set("hello");
                System.out.println( Thread.currentThread().getName() + threadLocal.get());
                System.out.println( Thread.currentThread().getName() + inheritableThreadLocal.get());
                inheritableThreadLocal.set("hello thread1");
                System.out.println( Thread.currentThread().getName() + inheritableThreadLocal.get());
            } finally {
                threadLocal.remove();
            }
        },"thread1 - ").start();

        new Thread(()-> {
            try {
                Thread.sleep(10);
                System.out.println( Thread.currentThread().getName() + inheritableThreadLocal.get());
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                threadLocal.remove();
            }
        }, "thread2 - ").start();
    }
}
赞(1) 打赏
未经允许不得转载:魔豆IT » ThreadLocal的设计精要
分享到: 更多 (0)

评论 抢沙发

评论前必须登录!