基于 Striped64 实现 DoubleMinUpdater

0

最近写了一个metric统计的sdk,功能类似dropwizard/metrics

在Histograms场景,需要统计最大最小值,为了数据尽量精准,这里没有采用sample的办法,考虑到QPS会很高,所以也不能存储统计周期内的所有数据,因为内存占用不可控。

因此对于百分位数,这里采用了桶压缩的办法,计算近似值;对于最大最小值,这里使用两个变量精确计算。

如果并发统计最大最小值也是蛮有趣的一个问题。

使用 AtomicDouble

最先想到的可能就是AtomicDouble

private AtomicDouble max = new AtomicDouble(Double.NEGATIVE_INFINITY);

public void updateMax(double v) {
    double cur;
    while ((cur = max.get()) < v) {
        if (max.compareAndSet(cur, v)) {
            break;
        }
    }
}

通常来说这种方式可以满足大部分的场景,但是对于QPS的特别高的场景,可能会影响到业务的执行。

基于 Striped64

如何处理高并发的场景,首先联想到了LongAdder,Java 8 提供的累加器实现,随着并发线程的增加,性能越优于AtomicLong,具体的性能差别见Java 8 Performance Improvements: LongAdder vs AtomicLong

LongAdder提高并发性能的秘诀在于将并发更新的单元分为了多个段,从而减小整体的并发压力,主要的处理逻辑封装在了父类——Striped64中。

同样的思路可以应用到统计最大、最小值上,实际上在jsr166e中,已经存在基于Striped64的并发统计最大值的实现DoubleMaxUpdater.java

在看这些功能类之前,首先分析一下Striped64的实现,这里选择jdk8u最新的代码

package java.util.concurrent.atomic;
import java.util.function.LongBinaryOperator;
import java.util.function.DoubleBinaryOperator;
import java.util.concurrent.ThreadLocalRandom;

@SuppressWarnings("serial")
abstract class Striped64 extends Number {
    
    //@sun.misc.Contended 通过注解添加 cache line 的 padding
    //更通用的写法如下:
    //volatile long p0, p1, p2, p3, p4, p5, p6;
    //volatile long value;
    //volatile long q0, q1, q2, q3, q4, q5, q6;
    //
    @sun.misc.Contended static final class Cell {
        volatile long value;
        Cell(long x) { value = x; }
        final boolean cas(long cmp, long val) {
            return UNSAFE.compareAndSwapLong(this, valueOffset, cmp, val);
        }

        // Unsafe mechanics
        private static final sun.misc.Unsafe UNSAFE;
        private static final long valueOffset;
        static {
            try {
                UNSAFE = sun.misc.Unsafe.getUnsafe();
                Class<?> ak = Cell.class;
                valueOffset = UNSAFE.objectFieldOffset
                    (ak.getDeclaredField("value"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

    // CPU数量,限制cells的长度
    static final int NCPU = Runtime.getRuntime().availableProcessors();

    transient volatile Cell[] cells;

    // 如果没有 contention,就通过 base 计算数值
    transient volatile long base;

    // 操作 cells 的自旋锁
    transient volatile int cellsBusy;

    // 包可见
    Striped64() {
    }

    // 对 base 的 cas
    final boolean casBase(long cmp, long val) {
        return UNSAFE.compareAndSwapLong(this, BASE, cmp, val);
    }

    // 获取 cellBusy 锁
    final boolean casCellsBusy() {
        return UNSAFE.compareAndSwapInt(this, CELLSBUSY, 0, 1);
    }

    // 获取当前线程的 threadLocalRandomProbe 值,也就是线程的 hashcode
    static final int getProbe() {
        return UNSAFE.getInt(Thread.currentThread(), PROBE);
    }

    // Pseudo-randomly advances, 设置当前线程的 threadLocalRandomProbe 值
    static final int advanceProbe(int probe) {
        probe ^= probe << 13;   // xorshift
        probe ^= probe >>> 17;
        probe ^= probe << 5;
        UNSAFE.putInt(Thread.currentThread(), PROBE, probe);
        return probe;
    }

    /**
     * Handles cases of updates involving initialization, resizing,
     * creating new Cells, and/or contention. See above for
     * explanation. This method suffers the usual non-modularity
     * problems of optimistic retry code, relying on rechecked sets of
     * reads.
     *
     * @param x the value
     * @param fn the update function, or null for add (this convention
     * avoids the need for an extra field or function in LongAdder).
     * @param wasUncontended false if CAS failed before call
     */
    final void longAccumulate(long x, LongBinaryOperator fn,
                              boolean wasUncontended) {
        // 设置线程 probe, 此时没有 contention
        int h;
        if ((h = getProbe()) == 0) {
            ThreadLocalRandom.current(); // force initialization
            h = getProbe();
            wasUncontended = true;
        }
        
        // 最后的 cell 不为空为 true
        boolean collide = false;                // True if last slot nonempty
        for (;;) {
            Cell[] as; Cell a; int n; long v;
            
            // 如果 cells 已经初始化
            if ((as = cells) != null && (n = as.length) > 0) {
            
                // 如果线程 hashcode 对应的 cell 为空
                if ((a = as[(n - 1) & h]) == null) {
                
                    // 如果可以获取 cellBusy 锁
                    if (cellsBusy == 0) {       // Try to attach new Cell
                        Cell r = new Cell(x);   // Optimistically create
                        
                        // 如果获取锁成功
                        if (cellsBusy == 0 && casCellsBusy()) {
                            boolean created = false;
                            try {               // Recheck under lock
                                Cell[] rs; int m, j;
                                if ((rs = cells) != null &&
                                    (m = rs.length) > 0 &&
                                    rs[j = (m - 1) & h] == null) {
                                    rs[j] = r;
                                    created = true;
                                }
                            } finally {
                                // 释放锁
                                cellsBusy = 0;
                            }
                            
                            // x 值放在了对应的 cell 中,执行完毕,退出
                            if (created)
                                break;
                            
                            // 如果对应的 cell 被别的线程修改了,那么从头开始
                            continue;           // Slot is now non-empty
                        }
                    }
                    
                    // 锁被占用
                    collide = false;
                }
                // 如果已经知道 cas 失败,那么重新散列
                else if (!wasUncontended)       // CAS already known to fail
                    wasUncontended = true;      // Continue after rehash
                
                // 如果在对应的 cell 上操作成功
                // 默认累加,负责使用传入的LongBinaryOperator
                else if (a.cas(v = a.value, ((fn == null) ? v + x :
                                             fn.applyAsLong(v, x))))
                    break;
                
                // cells 长度达到上限,或者 cells 扩容了
                else if (n >= NCPU || cells != as)
                    collide = false;            // At max size or stale
                
                // 如果不存在冲突, 设置为存在冲突
                else if (!collide)
                    collide = true;
                    
                // 如果获取自旋锁成功
                else if (cellsBusy == 0 && casCellsBusy()) {
                    try {
                        // 扩容 *2
                        if (cells == as) {      // Expand table unless stale
                            Cell[] rs = new Cell[n << 1];
                            for (int i = 0; i < n; ++i)
                                rs[i] = as[i];
                            cells = rs;
                        }
                    } finally {
                        cellsBusy = 0;
                    }
                    collide = false;
                    continue;                   // Retry with expanded table
                }
                
                // 如果获取不到错,重新散列
                h = advanceProbe(h);
            }
            
            // 如果 cells 没有初始化,那么重试获取锁
            else if (cellsBusy == 0 && cells == as && casCellsBusy()) {
                boolean init = false;
                try {                           // Initialize table
                    if (cells == as) {
                        Cell[] rs = new Cell[2];
                        rs[h & 1] = new Cell(x);
                        cells = rs;
                        init = true;
                    }
                } finally {
                    cellsBusy = 0;
                }
                if (init)
                    break;
            }
            
            // cells 正被别的线程初始化,那么尝试更新 base
            else if (casBase(v = base, ((fn == null) ? v + x :
                                        fn.applyAsLong(v, x))))
                break;                          // Fall back on using base
        }
    }

    // 逻辑与 longAccumulate 相同,只是加入了 long/double 转换
    final void doubleAccumulate(double x, DoubleBinaryOperator fn,
                                boolean wasUncontended) {
        int h;
        if ((h = getProbe()) == 0) {
            ThreadLocalRandom.current(); // force initialization
            h = getProbe();
            wasUncontended = true;
        }
        boolean collide = false;                // True if last slot nonempty
        for (;;) {
            Cell[] as; Cell a; int n; long v;
            if ((as = cells) != null && (n = as.length) > 0) {
                if ((a = as[(n - 1) & h]) == null) {
                    if (cellsBusy == 0) {       // Try to attach new Cell
                        Cell r = new Cell(Double.doubleToRawLongBits(x));
                        if (cellsBusy == 0 && casCellsBusy()) {
                            boolean created = false;
                            try {               // Recheck under lock
                                Cell[] rs; int m, j;
                                if ((rs = cells) != null &&
                                    (m = rs.length) > 0 &&
                                    rs[j = (m - 1) & h] == null) {
                                    rs[j] = r;
                                    created = true;
                                }
                            } finally {
                                cellsBusy = 0;
                            }
                            if (created)
                                break;
                            continue;           // Slot is now non-empty
                        }
                    }
                    collide = false;
                }
                else if (!wasUncontended)       // CAS already known to fail
                    wasUncontended = true;      // Continue after rehash
                else if (a.cas(v = a.value,
                               ((fn == null) ?
                                Double.doubleToRawLongBits
                                (Double.longBitsToDouble(v) + x) :
                                Double.doubleToRawLongBits
                                (fn.applyAsDouble
                                 (Double.longBitsToDouble(v), x)))))
                    break;
                else if (n >= NCPU || cells != as)
                    collide = false;            // At max size or stale
                else if (!collide)
                    collide = true;
                else if (cellsBusy == 0 && casCellsBusy()) {
                    try {
                        if (cells == as) {      // Expand table unless stale
                            Cell[] rs = new Cell[n << 1];
                            for (int i = 0; i < n; ++i)
                                rs[i] = as[i];
                            cells = rs;
                        }
                    } finally {
                        cellsBusy = 0;
                    }
                    collide = false;
                    continue;                   // Retry with expanded table
                }
                h = advanceProbe(h);
            }
            else if (cellsBusy == 0 && cells == as && casCellsBusy()) {
                boolean init = false;
                try {                           // Initialize table
                    if (cells == as) {
                        Cell[] rs = new Cell[2];
                        rs[h & 1] = new Cell(Double.doubleToRawLongBits(x));
                        cells = rs;
                        init = true;
                    }
                } finally {
                    cellsBusy = 0;
                }
                if (init)
                    break;
            }
            else if (casBase(v = base,
                             ((fn == null) ?
                              Double.doubleToRawLongBits
                              (Double.longBitsToDouble(v) + x) :
                              Double.doubleToRawLongBits
                              (fn.applyAsDouble
                               (Double.longBitsToDouble(v), x)))))
                break;                          // Fall back on using base
        }
    }

    // Unsafe mechanics
    private static final sun.misc.Unsafe UNSAFE;
    private static final long BASE;
    private static final long CELLSBUSY;
    private static final long PROBE;
    static {
        try {
            UNSAFE = sun.misc.Unsafe.getUnsafe();
            Class<?> sk = Striped64.class;
            BASE = UNSAFE.objectFieldOffset
                (sk.getDeclaredField("base"));
            CELLSBUSY = UNSAFE.objectFieldOffset
                (sk.getDeclaredField("cellsBusy"));
            Class<?> tk = Thread.class;
            PROBE = UNSAFE.objectFieldOffset
                (tk.getDeclaredField("threadLocalRandomProbe"));
        } catch (Exception e) {
            throw new Error(e);
        }
    }

}

逻辑上还是很清晰的,想要写出实现上没有bug的代码还是不容易的,不得不膜拜一下 Doug Lea。

jsr166e 中的 Striped64 与上述实现略有不同,主要是为了兼容 jdk6,因为 DoubleBinaryOperator 这种函数接口之后在 jdk8 之后才可以使用,所以就就变成了抽象方法fn,另外线程的 hashcode 采用了ThreadLocal变量保存。

笔者为了兼容性采用了 jsr166e 中的 Striped64 实现

实现 DoubleMinUpdater

jsr166e 已经实现了 DoubleMaxUpdater,参考一下很容易实现DoubleMinUpdater

public class DoubleMinUpdater extends Striped64 {

    /**
     * long 表示的 double 正无穷
     */
    private static final long MAX_AS_LONG = 0x7ff0000000000000L;

    /**
     * fn 函数即为 min(v, x)
     */
    final long fn(long v, long x) {
        return Double.longBitsToDouble(v) < Double.longBitsToDouble(x) ? v : x;
    }

    public DoubleMinUpdater() {
        base = MAX_AS_LONG;
    }

    /**
     * 更新最小值
     */
    public void update(double x) {
        long lx = Double.doubleToRawLongBits(x);
        Cell[] as; long b, v; int[] hc; Cell a; int n;
        // 如果 cells 不为空
        // 或者 cells 为空,但是更新 base 是吧
        if ((as = cells) != null ||
                (Double.longBitsToDouble(b = base) > x && !casBase(b, lx))) {
            boolean uncontended = true;
            // 如果线程 hashcode 为空
            // 或者 cells 为空
            // 或者线程对应的 cell 为空
            // 或者更新对应的 cell 失败
            if ((hc = threadHashCode.get()) == null ||
                    as == null || (n = as.length) < 1 ||
                    (a = as[(n - 1) & hc[0]]) == null ||
                    (Double.longBitsToDouble(v = a.value) > x &&
                            !(uncontended = a.cas(v, lx))))
                retryUpdate(lx, hc, uncontended);
        }
    }

    /**
     * 返回最小值
     * 注意:调用过程中的并发update可能会影响结果
     */
    public double min() {
        Cell[] as = cells;
        double min = Double.longBitsToDouble(base);
        if (as != null) {
            int n = as.length;
            double v;
            for (int i = 0; i < n; ++i) {
                Cell a = as[i];
                if (a != null && (v = Double.longBitsToDouble(a.value)) < min)
                    min = v;
            }
        }
        return min;
    }

    /**
     * 重置 base 和 cells
     * 注意:只能在没有并发操作时调用
     */
    public void reset() {
        internalReset(MAX_AS_LONG);
    }

    /**
     * 计算最小值,并且重置
     * 注意:调用过程中的并发update可能会影响结果
     *
     * @return the minimum
     */
    public double minThenReset() {
        Cell[] as = cells;
        double min = Double.longBitsToDouble(base);
        base = MAX_AS_LONG;
        if (as != null) {
            int n = as.length;
            for (int i = 0; i < n; ++i) {
                Cell a = as[i];
                if (a != null) {
                    double v = Double.longBitsToDouble(a.value);
                    a.value = MAX_AS_LONG;
                    if (v < min)
                        min = v;
                }
            }
        }
        return min;
    }

    /**
     * Returns the String representation of the {@link #min}.
     * @return the String representation of the {@link #min}
     */
    public String toString() {
        return Double.toString(min());
    }

    /**
     * Equivalent to {@link #min}.
     *
     * @return the min
     */
    public double doubleValue() {
        return min();
    }

}