Friday, 28 July 2017

[Scala / Threads / Parallel collections] How to configure the number of threads for a parallel collection?

I've already written an article about controlling the pool size when using Java's parallel streams. Recently I needed exactly the same feature in Scala. A Scala collection can be turned into a parallel collection by invoking the par method (introduced by the Parallelizable trait):
trait Parallelizable[+A, +ParRepr <: Parallel] extends Any
So, assuming you don't break the basic rules of functional programming (immutability, statelessness, etc.), you can often make your algorithms much faster by adding .par. Unfortunately, the .par method (exactly like .parallelStream() in Java) doesn't take any parameter that could control the size of the pool:
/** Returns a parallel implementation of this collection.
   *
   *  For most collection types, this method creates a new parallel collection by copying
   *  all the elements. For these collection, `par` takes linear time. Mutable collections
   *  in this category do not produce a mutable parallel collection that has the same
   *  underlying dataset, so changes in one collection will not be reflected in the other one.
   *
   *  Specific collections (e.g. `ParArray` or `mutable.ParHashMap`) override this default
   *  behaviour by creating a parallel collection which shares the same underlying dataset.
   *  For these collections, `par` takes constant or sublinear time.
   *
   *  All parallel collections return a reference to themselves.
   *
   *  @return  a parallel implementation of this collection
   */
  def par: ParRepr = {
    val cb = parCombiner
    for (x <- seq) cb += x
    cb.result()
  }
However, you can set the number of threads using ForkJoinTaskSupport like this:
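import java.util.concurrent.ForkJoinPool // assumes Scala 2.12
import scala.collection.parallel.ForkJoinTaskSupport

object ParExample extends App {
  val friends = List("rachel", "ross", "joey", "chandler", "pheebs", "monica").par
  // Run operations on this collection in a dedicated pool of 6 threads.
  friends.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(6))
}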
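A quick way to confirm that the setting took effect is to read it back through tasksupport.parallelismLevel. The sketch below assumes Scala 2.12 (where ForkJoinTaskSupport accepts a java.util.concurrent.ForkJoinPool); the ParallelismCheck object and its methods are illustrative names, not part of any library.

```scala
import java.util.concurrent.ForkJoinPool
import scala.collection.parallel.ForkJoinTaskSupport

// Illustrative helper (hypothetical name), assuming Scala 2.12.
object ParallelismCheck {
  // A fresh parallel collection on every call, so the checks do not affect each other.
  private def friends = List("rachel", "ross", "joey", "chandler", "pheebs", "monica").par

  // Parallelism level of a collection that keeps the default task support.
  def defaultLevel: Int = friends.tasksupport.parallelismLevel

  // Parallelism level after installing a dedicated pool of `threads` workers.
  def configuredLevel(threads: Int): Int = {
    val pool = new ForkJoinPool(threads)
    try {
      val collection = friends
      collection.tasksupport = new ForkJoinTaskSupport(pool)
      collection.tasksupport.parallelismLevel
    } finally pool.shutdown() // release the worker threads once we are done
  }

  def main(args: Array[String]): Unit = {
    println(s"default level:    $defaultLevel")
    println(s"configured level: ${configuredLevel(6)}")
  }
}
```

With the default task support, the reported level typically matches the number of available processors, which is why the default pool can be smaller than the collection.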
Let's prove it works with a timing test. First, a baseline run with the default pool:
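import java.util.concurrent.TimeUnit.MILLISECONDS
import com.google.common.base.Stopwatch // assumed to be Guava's Stopwatch

object ParExample extends App {
  val friends = List("rachel", "ross", "joey", "chandler", "pheebs", "monica", "gunther", "mike", "richard").par
  val stopwatch = Stopwatch.createStarted()
  // Each element blocks its worker thread for 5 seconds.
  friends.foreach(_ => Thread.sleep(5000))
  println(stopwatch.elapsed(MILLISECONDS))
}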
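It prints 10025. Each of the 9 elements sleeps for 5 seconds, so a run of about 10 seconds means that some threads had to process more than one element: the default pool must have been smaller than the list. Now I'll set the pool size: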
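import java.util.concurrent.ForkJoinPool // assumes Scala 2.12
import java.util.concurrent.TimeUnit.MILLISECONDS
import scala.collection.parallel.ForkJoinTaskSupport
import com.google.common.base.Stopwatch // assumed to be Guava's Stopwatch

object ParExample extends App {
  val friends = List("rachel", "ross", "joey", "chandler", "pheebs", "monica", "gunther", "mike", "richard").par
  // One thread per element, so all 9 sleeps can run at the same time.
  friends.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(9))
  val stopwatch = Stopwatch.createStarted()
  friends.foreach(_ => Thread.sleep(5000))
  println(stopwatch.elapsed(MILLISECONDS))
}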
This time it took only 5024 milliseconds because each element of the list was processed by a separate thread. The solution works for me, but I prefer concise one-liners, so I've created an implicit conversion from ParSeq to RichParSeq that adds a parallelismLevel(numberOfThreads: Int) method.
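import java.util.concurrent.ForkJoinPool // assumes Scala 2.12
import scala.collection.parallel.{ForkJoinTaskSupport, ParSeq}
import scala.language.implicitConversions

class RichParSeq[T](val p: ParSeq[T]) {
  // Installs a new pool of the given size on the wrapped collection and returns it for chaining.
  def parallelismLevel(numberOfThreads: Int): ParSeq[T] = {
    p.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(numberOfThreads))
    p
  }
}

object ConfigurableParallelism {
  implicit def parSeqToRichParSeq[T](p: ParSeq[T]): RichParSeq[T] = new RichParSeq[T](p)
}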
Now you can use it like this:
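import java.util.concurrent.TimeUnit.MILLISECONDS
import com.google.common.base.Stopwatch // assumed to be Guava's Stopwatch
import ConfigurableParallelism._

object ParExample extends App {
  val stopwatch = Stopwatch.createStarted()
  // foreach returns Unit, so the result is not assigned to a val.
  List("rachel", "ross", "joey", "chandler", "pheebs", "monica", "gunther", "mike", "richard").par
    .parallelismLevel(9)
    .foreach(_ => Thread.sleep(5000))
  println(stopwatch.elapsed(MILLISECONDS))
}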
It took 5231 milliseconds.
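One caveat with this helper is that every call to parallelismLevel creates a new ForkJoinPool that is never explicitly shut down. A possible variation, sketched here under the assumption of Scala 2.12, scopes the pool to a single block and shuts it down afterwards. The names PooledParallelism and withParallelism are hypothetical and only serve the illustration.

```scala
import java.util.concurrent.ForkJoinPool
import scala.collection.parallel.{ForkJoinTaskSupport, ParSeq}

// Hypothetical helper, assuming Scala 2.12.
object PooledParallelism {
  // Runs `body` against `seq` on a dedicated pool of `numberOfThreads` workers,
  // then restores the previous task support and shuts the pool down.
  def withParallelism[T, R](seq: ParSeq[T], numberOfThreads: Int)(body: ParSeq[T] => R): R = {
    val previous = seq.tasksupport
    val pool = new ForkJoinPool(numberOfThreads)
    try {
      seq.tasksupport = new ForkJoinTaskSupport(pool)
      body(seq)
    } finally {
      seq.tasksupport = previous
      pool.shutdown()
    }
  }

  def main(args: Array[String]): Unit = {
    val friends = List("rachel", "ross", "joey").par
    // Sum of the name lengths, computed on a pool of 3 threads.
    val totalLength = withParallelism(friends, 3)(_.map(_.length).sum)
    println(totalLength)
  }
}
```

All parallel work should finish inside the block, because collections derived from seq there may still reference the pool after it has been shut down.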