Spark源码分析之分区器的作用

[复制链接]
发表于 : 2017-7-31 17:24:30 最新回复:2017-07-31 17:26:22
1617 1
建赟
建赟  专家

最近因为手抖,在Spark中给自己挖了一个数据倾斜的坑。为了解决这个问题,顺便研究了下Spark分区器的原理,趁着周末加班总结一下~

先说说数据倾斜

数据倾斜是指Spark中的RDD在计算的时候,每个RDD内部的分区包含的数据不平均。比如一共有5个分区,其中一个占有了90%的数据,这就导致本来5个分区可以5个人一起并行干活,结果四个人不怎么干活,工作全都压到一个人身上了。遇到这种问题,网上有很多的解决办法。

但是如果是底层数据的问题,无论怎么优化,还是无法解决数据倾斜的。

比如你想要对某个rdd做groupby,然后做join操作,如果分组的key就是分布不均匀的,那么真样都是无法优化的。因为一旦这个key被切分,就无法完整的做join了,如果不对这个key切分,必然会造成对应的分区数据倾斜。

不过,了解数据为什么会倾斜还是很重要的,继续往下看吧!

分区的作用

在PairRDD即(key,value)这种格式的rdd中,很多操作都是基于key的,因此为了独立分割任务,会按照key对数据进行重组。比如groupbykey

http://s2.51cto.com/wyfs02/M01/92/30/wKioL1j9U73SyQHAAAAnlqP5KV0819.jpg

重组肯定是需要一个规则的,最常见的就是基于Hash,Spark还提供了一种稍微复杂点的基于抽样的Range分区方法。

下面我们先看看分区器在Spark计算流程中是怎么使用的:

Paritioner的使用

就拿groupbykey来说:

  1. def groupByKey(): JavaPairRDD[K, JIterable[V]] = 
  2.     fromRDD(groupByResultToJava(rdd.groupByKey())) 

它会调用PairRDDFunction的groupByKey()方法

  1. def groupByKey(): RDD[(K, Iterable[V])] = self.withScope { 
  2.     groupByKey(defaultPartitioner(self)) 
  3.   } 

在这个方法里面创建了默认的分区器。默认的分区器是这样定义的:

  1. def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { 
  2.     val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse 
  3.     for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { 
  4.       return r.partitioner.get 
  5.     } 
  6.     if (rdd.context.conf.contains("spark.default.parallelism")) { 
  7.       new HashPartitioner(rdd.context.defaultParallelism) 
  8.     } else { 
  9.       new HashPartitioner(bySize.head.partitions.size
  10.     } 
  11.   } 

首先获取当前分区的分区个数,如果没有设置spark.default.parallelism参数,则创建一个跟之前分区个数一样的Hash分区器。

当然,用户也可以自定义分区器,或者使用其他提供的分区器。API里面也是支持的:

  1. // 传入分区器对象 
  2. def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = 
  3.     fromRDD(groupByResultToJava(rdd.groupByKey(partitioner))) 
  4. // 传入分区的个数 
  5. def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = 
  6.     fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions))) 

HashPatitioner

Hash分区器,是最简单也是默认提供的分区器,了解它的分区规则,对我们处理数据倾斜或者设计分组的key时,还是很有帮助的。

  1. class HashPartitioner(partitions: Int) extends Partitioner { 
  2.   require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative."
  3.  
  4.   def numPartitions: Int = partitions 
  5.  
  6.   // 通过key计算其HashCode,并根据分区数取模。如果结果小于0,直接加上分区数。 
  7.   def getPartition(keyAny): Int = key match { 
  8.     case null => 0 
  9.     case _ => Utils.nonNegativeMod(key.hashCode, numPartitions) 
  10.   } 
  11.  
  12.   // 对比两个分区器是否相同,直接对比其分区个数就行 
  13.   override def equals(other: Any): Boolean = other match { 
  14.     case h: HashPartitioner => 
  15.       h.numPartitions == numPartitions 
  16.     case _ => 
  17.       false 
  18.   } 
  19.  
  20.   override def hashCode: Int = numPartitions 

这里最重要的是这个Utils.nonNegativeMod(key.hashCode, numPartitions),它决定了数据进入到哪个分区。

  1. def nonNegativeMod(x: Int, mod: Int): Int = { 
  2.     val rawMod = x % mod 
  3.     rawMod + (if (rawMod < 0) mod else 0) 
  4.   } 

说白了,就是基于这个key获取它的hashCode,然后对分区个数取模。由于HashCode可能为负,这里直接判断下,如果小于0,再加上分区个数即可。

因此,基于hash的分区,只要保证你的key是分散的,那么最终数据就不会出现数据倾斜的情况。

RangePartitioner

这个分区器,适合想要把数据打散的场景,但是如果相同的key重复量很大,依然会出现数据倾斜的情况。

每个分区器,最核心的方法,就是getPartition

  1. def getPartition(keyAny): Int = { 
  2.     val k = key.asInstanceOf[K] 
  3.     var partition = 0 
  4.     if (rangeBounds.length <= 128) { 
  5.       // If we have less than 128 partitions naive search 
  6.       while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) { 
  7.         partition += 1 
  8.       } 
  9.     } else { 
  10.       // Determine which binary search method to use only once. 
  11.       partition = binarySearch(rangeBounds, k) 
  12.       // binarySearch either returns the match location or -[insertion point]-1 
  13.       if (partition < 0) { 
  14.         partition = -partition-1 
  15.       } 
  16.       if (partition > rangeBounds.length) { 
  17.         partition = rangeBounds.length 
  18.       } 
  19.     } 
  20.     if (ascending) { 
  21.       partition 
  22.     } else { 
  23.       rangeBounds.length - partition 
  24.     } 
  25.   } 

在range分区中,会存储一个边界的数组,比如[1,100,200,300,400],然后对比传进来的key,返回对应的分区id。

那么这个边界是怎么确定的呢?

这就是Range分区最核心的算法了,大概描述下,就是遍历每个paritiion,对里面的数据进行抽样,把抽样的数据进行排序,并按照对应的权重确定边界。

有几个比较重要的地方:

1 抽样

2 确定边界

关于抽样,有一个很常见的算法题,即在不知道数据规模的情况下,如何以等概率的方式,随机选择一个值。

最笨的办法,就是遍历一次数据,知道数据的规模,然后随机一个数,取其对应的值。其实这样相当于遍历了两次(第二次的取值根据不同的存储介质,可能不同)。

在Spark中,是使用水塘抽样这种算法。即首先取第一个值,然后依次往后遍历;第二个值有二分之一的几率替换选出来的值;第三个值有三分之一的几率替换选出来的值;…;直到遍历到最后一个值。这样,通过依次遍历就取出来随机的数值了。

算法参考源码:

  1. private var rangeBounds: Array[K] = { 
  2.     if (partitions <= 1) { 
  3.       Array.empty 
  4.     } else { 
  5.       // This is the sample size we need to have roughly balanced output partitions, capped at 1M. 
  6.       // 最大采样数量不能超过1M。比如,如果分区是5,采样数为100 
  7.       val sampleSize = math.min(20.0 * partitions, 1e6) 
  8.       // Assume the input partitions are roughly balanced and over-sample a little bit
  9.       // 每个分区的采样数为平均值的三倍,避免数据倾斜造成的数据量过少 
  10.       val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt 
  11.  
  12.       // 真正的采样算法(参数1:rdd的key数组, 采样个数) 
  13.       val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) 
  14.       if (numItems == 0L) { 
  15.         Array.empty 
  16.       } else { 
  17.         // If a partition contains much more than the average number of items, we re-sample from it 
  18.         // to ensure that enough items are collected from that partition. 
  19.         // 如果有的分区包含的数量远超过平均值,那么需要对它重新采样。每个分区的采样数/采样返回的总的记录数 
  20.         val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0) 
  21.         //保存有效的采样数 
  22.         val candidates = ArrayBuffer.empty[(K, Float)] 
  23.         //保存数据倾斜导致的采样数过多的信息 
  24.         val imbalancedPartitions = mutable.Set.empty[Int
  25.  
  26.         sketched.foreach { case (idx, n, sample) => 
  27.           if (fraction * n > sampleSizePerPartition) { 
  28.             imbalancedPartitions += idx 
  29.           } else { 
  30.             // The weight is 1 over the sampling probability. 
  31.             val weight = (n.toDouble / sample.size).toFloat 
  32.             for (key <- sample) { 
  33.               candidates += ((key, weight)) 
  34.             } 
  35.           } 
  36.         } 
  37.         if (imbalancedPartitions.nonEmpty) { 
  38.           // Re-sample imbalanced partitions with the desired sampling probability. 
  39.           val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains
  40.           val seed = byteswap32(-rdd.id - 1) 
  41.           //基于RDD获取采样数据 
  42.           val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect() 
  43.           val weight = (1.0 / fraction).toFloat 
  44.           candidates ++= reSampled.map(x => (x, weight)) 
  45.         } 
  46.         RangePartitioner.determineBounds(candidates, partitions) 
  47.       } 
  48.     } 
  49.   } 
  50.    
  51.   def sketch[K : ClassTag]( 
  52.       rdd: RDD[K], 
  53.       sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { 
  54.     val shift = rdd.id 
  55.     // val classTagK = classTag[K] // to avoid serializing the entire partitioner object 
  56.     val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => 
  57.       val seed = byteswap32(idx ^ (shift << 16)) 
  58.       val (sample, n) = SamplingUtils.reservoirSampleAndCount( 
  59.         iter, sampleSizePerPartition, seed) 
  60.       //包装成三元组,(索引号,分区的内容个数,抽样的内容) 
  61.       Iterator((idx, n, sample)) 
  62.     }.collect() 
  63.     val numItems = sketched.map(_._2).sum 
  64.     //返回(数据条数,(索引号,分区的内容个数,抽样的内容)) 
  65.     (numItems, sketched) 
  66.   } 
  67.    

真正的抽样算法在SamplingUtils中,由于在Spark中是需要一次性取多个值的,因此直接去前n个数值,然后依次概率替换即可:

  1. def reservoirSampleAndCount[T: ClassTag]( 
  2.       input: Iterator[T], 
  3.       k: Int
  4.       seed: Long = Random.nextLong()) 
  5.     : (Array[T], Long) = { 
  6.     //创建临时数组 
  7.     val reservoir = new Array[T](k) 
  8.     // Put the first k elements in the reservoir. 
  9.     // 取出前k个数,并把对应的rdd中的数据放入对应的序号的数组中 
  10.     var i = 0 
  11.     while (i < k && input.hasNext) { 
  12.       val item = input.next() 
  13.       reservoir(i) = item 
  14.       i += 1 
  15.     } 
  16.  
  17.     // If we have consumed all the elements, return them. Otherwise do the replacement. 
  18.     // 如果全部的元素,比要抽取的采样数少,那么直接返回 
  19.     if (i < k) { 
  20.       // If input size < k, trim the array to return only an array of input size
  21.       val trimReservoir = new Array[T](i) 
  22.       System.arraycopy(reservoir, 0, trimReservoir, 0, i) 
  23.       (trimReservoir, i) 
  24.  
  25.     // 否则开始抽样替换 
  26.     } else { 
  27.       // If input size > k, continue the sampling process. 
  28.       // 从刚才的序号开始,继续遍历 
  29.       var l = i.toLong 
  30.       // 随机数 
  31.       val rand = new XORShiftRandom(seed) 
  32.       while (input.hasNext) { 
  33.         val item = input.next() 
  34.         // 随机一个数与当前的l相乘,如果小于采样数k,就替换。(越到后面,替换的概率越小...) 
  35.         val replacementIndex = (rand.nextDouble() * l).toLong 
  36.         if (replacementIndex < k) { 
  37.           reservoir(replacementIndex.toInt) = item 
  38.         } 
  39.         l += 1 
  40.       } 
  41.       (reservoir, l) 
  42.     } 
  43.   } 

确定边界

最后就可以通过获取的样本数据,确定边界了。

  1. def determineBounds[K : Ordering : ClassTag]( 
  2.       candidates: ArrayBuffer[(K, Float)], 
  3.       partitions: Int): Array[K] = { 
  4.     val ordering = implicitly[Ordering[K]] 
  5.     // 数据格式为(key,权重) 
  6.     val ordered = candidates.sortBy(_._1) 
  7.     val numCandidates = ordered.size 
  8.     val sumWeights = ordered.map(_._2.toDouble).sum 
  9.     val step = sumWeights / partitions 
  10.     var cumWeight = 0.0 
  11.     var target = step 
  12.     val bounds = ArrayBuffer.empty[K] 
  13.     var i = 0 
  14.     var j = 0 
  15.     var previousBound = Option.empty[K] 
  16.     while ((i < numCandidates) && (j < partitions - 1)) { 
  17.       val (key, weight) = ordered(i) 
  18.       cumWeight += weight 
  19.       if (cumWeight >= target) { 
  20.         // Skip duplicate values
  21.         if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { 
  22.           bounds += key 
  23.           target += step 
  24.           j += 1 
  25.           previousBound = Some(key
  26.         } 
  27.       } 
  28.       i += 1 
  29.     } 
  30.     bounds.toArray 
  31.   } 

直接看代码,还是有些晦涩难懂,我们举个例子,一步一步解释下:

http://s5.51cto.com/wyfs02/M00/92/32/wKiom1j9Vwuxpq9DAAB9PiPL9xE156.jpg

按照上面的算法流程,大致可以理解:

  1. 抽样-->确定边界(排序) 

首先对spark有一定了解的都应该知道,在spark中每个RDD可以理解为一组分区,这些分区对应了内存块block,他们才是数据最终的载体。那么一个RDD由不同的分区组成,这样在处理一些map,filter等算子的时候,就可以直接以分区为单位并行计算了。直到遇到shuffle的时候才需要和其他的RDD配合。

在上面的图中,如果我们不特殊设置的话,一个RDD由3个分区组成,那么在对它进行groupbykey的时候,就会按照3进行分区。

按照上面的算法流程,如果分区数为3,那么采样的大小为:

  1. val sampleSize = math.min(20.0 * partitions, 1e6) 

即采样数为60,每个分区取60个数。但是考虑到数据倾斜的情况,有的分区可能数据很多,因此在实际的采样时,会按照3倍大小采样:

  1. val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt 

也就是说,最多会取60个样本数据。

然后就是遍历每个分区,取对应的样本数。

  1. val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => 
  2.       val seed = byteswap32(idx ^ (shift << 16)) 
  3.       val (sample, n) = SamplingUtils.reservoirSampleAndCount( 
  4.         iter, sampleSizePerPartition, seed) 
  5.       //包装成三元组,(索引号,分区的内容个数,抽样的内容) 
  6.       Iterator((idx, n, sample)) 
  7.     }.collect() 

然后检查,是否有分区的样本数过多,如果多于平均值,则继续采样,这时直接用sample 就可以了

  1. sketched.foreach { case (idx, n, sample) => 
  2.           if (fraction * n > sampleSizePerPartition) { 
  3.             imbalancedPartitions += idx 
  4.           } else { 
  5.             // The weight is 1 over the sampling probability. 
  6.             val weight = (n.toDouble / sample.size).toFloat 
  7.             for (key <- sample) { 
  8.               candidates += ((key, weight)) 
  9.             } 
  10.           } 
  11.         } 
  12.         if (imbalancedPartitions.nonEmpty) { 
  13.           // Re-sample imbalanced partitions with the desired sampling probability. 
  14.           val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains
  15.           val seed = byteswap32(-rdd.id - 1) 
  16.           //基于RDD获取采样数据 
  17.           val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect() 
  18.           val weight = (1.0 / fraction).toFloat 
  19.           candidates ++= reSampled.map(x => (x, weight)) 
  20.         } 

取出样本后,就到了确定边界的时候了。

注意每个key都会有一个权重,这个权重是 【分区的数据总数/样本数】

  1. RangePartitioner.determineBounds(candidates, partitions) 

首先排序val ordered = candidates.sortBy(_._1),然后确定一个权重的步长

  1. val sumWeights = ordered.map(_._2.toDouble).sum 
  2. val step = sumWeights / partitions 

基于该步长,确定边界,最后就形成了几个范围数据。

然后分区器形成二叉树,遍历该数确定每个key对应的分区id

  1. partition = binarySearch(rangeBounds, k) 

实践 —— 自定义分区器

自定义分区器,也是很简单的,只需要实现对应的两个方法就行:

  1. public class MyPartioner extends Partitioner { 
  2.     @Override 
  3.     public int numPartitions() { 
  4.         return 1000; 
  5.     } 
  6.  
  7.     @Override 
  8.     public int getPartition(Object key) { 
  9.         String k = (String) key
  10.         int code = k.hashCode() % 1000; 
  11.         System.out.println(k+":"+code); 
  12.         return  code < 0?code+1000:code; 
  13.     } 
  14.  
  15.     @Override 
  16.     public boolean equals(Object obj) { 
  17.         if(obj instanceof MyPartioner){ 
  18.             if(this.numPartitions()==((MyPartioner) obj).numPartitions()){ 
  19.                 return true
  20.             } 
  21.             return false
  22.         } 
  23.         return super.equals(obj); 
  24.     } 

使用的时候,可以直接new一个对象即可。

  1. pairRdd.groupbykey(new MyPartitioner()) 

这样自定义分区器就完成了。


  • x
  • 常规:

点评 回复

跳转到指定楼层
建赟
建赟  专家 发表于 2017-7-31 17:26:22 已赞(0) 赞(0)

顶一个!
  • x
  • 常规:

点评 回复

发表回复
您需要登录后才可以回帖 登录 | 注册

警告 内容安全提示:尊敬的用户您好,为了保障您、社区及第三方的合法权益,请勿发布可能给各方带来法律风险的内容,包括但不限于政治敏感内容,涉黄赌毒内容,泄露、侵犯他人商业秘密的内容,侵犯他人商标、版本、专利等知识产权的内容,侵犯个人隐私的内容等。也请勿向他人共享您的账号及密码,通过您的账号执行的所有操作,将视同您本人的行为,由您本人承担操作后果。详情请参看“隐私声明
如果附件按钮无法使用,请将Adobe Flash Player 更新到最新版本!
登录参与交流分享

登录参与交流分享

登录