2016-2020, Ivan Sadikov and Raazesh Sainudiin

We extend the Shortest Paths algorithm in Spark's GraphX library to allow for user-specified edge weights, stored as the edge attribute.

This is part of Project MEP (Meme Evolution Programme) and is supported by the Databricks Academic Partners Program.

The analysis is available in the following Databricks notebook:

Let's modify the shortest-paths algorithm to allow for user-specified edge weights.

Update the shortest-paths algorithm to work over an edge attribute that holds the edge weight. Key concepts:
- we increment each distance map by a delta, which is `edge.attr`;
- the edge attribute must be a `Double` (the algorithm takes a `Graph[VD, Double]` and was tested on `Double` weights);
- the "infinity" value is not a true infinity, but `Integer.MAX_VALUE`.

import scala.reflect.ClassTag
import org.apache.spark.graphx._

 /**
  * Computes shortest weighted paths to the given set of landmark vertices, returning a graph where each
  * vertex attribute is a map containing the shortest-path distance to each reachable landmark.
  * Currently supports only Graph of [VD, Double], where VD is an arbitrary vertex type.
  *
  * Paths follow edge direction: the map stored at vertex `v` holds, for every landmark `l` reachable
  * from `v`, the length of the shortest directed path `v -> ... -> l` (messages travel from an edge's
  * destination back to its source). Landmarks that are not reachable are simply absent from the map.
  *
  * Edge weights are assumed to be non-negative: with a negative-weight cycle the distances keep
  * decreasing on every superstep, so Pregel would only stop at `maxIterations`.
  */
object GraphXShortestWeightedPaths extends Serializable {
  /** Stores a map from the vertex id of a landmark to the distance to that landmark. */
  type SPMap = Map[VertexId, Double]

  // Distance of a landmark to itself.
  private val INITIAL = 0.0
  // Distance assumed for a landmark that is absent from a map (not reachable, or not reached yet).
  // This must be a true infinity: a finite sentinel such as Int.MaxValue.toDouble makes `addMaps`
  // silently cap every genuine distance larger than that value whenever the other map lacks the key.
  private val INFINITY = Double.PositiveInfinity

  private def makeMap(x: (VertexId, Double)*): SPMap = Map(x: _*)

  /** Extends every path recorded in `spmap` by one edge of length `delta`. */
  private def incrementMap(spmap: SPMap, delta: Double): SPMap =
    spmap.map { case (v, d) => v -> (d + delta) }

  /** Merges two maps, keeping the smaller distance for landmarks present in both. */
  private def addMaps(spmap1: SPMap, spmap2: SPMap): SPMap =
    (spmap1.keySet ++ spmap2.keySet).map { k =>
      k -> math.min(spmap1.getOrElse(k, INFINITY), spmap2.getOrElse(k, INFINITY))
    }.toMap

  /**
    * Runs the weighted shortest-paths computation with Pregel.
    *
    * @param graph         input graph whose edge attribute is the edge weight (distance)
    * @param landmarks     ids of the landmark vertices to compute distances to
    * @param maxIterations upper bound on Pregel supersteps; the default `Int.MaxValue` matches
    *                      Pregel's own default, i.e. run until no vertex improves
    * @tparam VD           vertex attribute type of the input graph (at this point it does not
    *                      really matter what it is: it is replaced by the distance map)
    * @return a graph with the same edges whose vertex attribute maps each reachable landmark
    *         to its shortest distance
    */
  def run[VD](
      graph: Graph[VD, Double],
      landmarks: Seq[VertexId],
      maxIterations: Int = Int.MaxValue): Graph[SPMap, Double] = {
    // Every landmark starts at distance 0.0 from itself; all other vertices know nothing yet.
    val spGraph = graph.mapVertices { (vid, _) =>
      if (landmarks.contains(vid)) makeMap(vid -> INITIAL) else makeMap()
    }

    val initialMessage = makeMap()

    // Fold the (already min-merged) incoming message into the vertex's current map.
    def vertexProgram(id: VertexId, attr: SPMap, msg: SPMap): SPMap =
      addMaps(attr, msg)

    // Relax edge src -> dst: src can reach every landmark known to dst at dst's distance plus the
    // edge weight. Only send when that improves src's map, so Pregel stops once nothing changes.
    def sendMessage(edge: EdgeTriplet[SPMap, Double]): Iterator[(VertexId, SPMap)] = {
      val newAttr = incrementMap(edge.dstAttr, edge.attr)
      if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr))
      else Iterator.empty
    }

    Pregel(spGraph, initialMessage, maxIterations)(vertexProgram, sendMessage, addMaps)
  }
}

// Print how to invoke the weighted shortest-paths computation on a graph with Double edge weights.
println("Usage: val result = GraphXShortestWeightedPaths.run(graph, Seq(4L, 0L, 9L))")
Usage: val result =, Seq(4L, 0L, 9L))
defined object GraphXShortestWeightedPaths

Generate test graph

Generate a simple graph with `Double` weights on its edges.

import scala.util.Random

import org.apache.spark.graphx.{Graph, VertexId}
import org.apache.spark.graphx.util.GraphGenerators

// A graph with edge attributes containing distances.
// The weights are produced inside a lazy RDD transformation, so they must be reproducible: with an
// unseeded `Random.nextDouble()`, every re-evaluation of the lineage (the Pregel run, each later
// `.show()`, ...) draws different weights, and the printed edges no longer match the distances that
// were computed from them. Seeding one RNG per edge partition yields the same weights on every
// evaluation (given the same partitioning).
val edgeWeightSeed = 123L
val graph: Graph[Long, Double] =
  GraphGenerators.logNormalGraph(sc, numVertices = 10, seed = 123L).mapEdges {
    (pid: Int, edges: Iterator[org.apache.spark.graphx.Edge[Int]]) =>
      val rng = new Random(edgeWeightSeed + pid)
      edges.map { e =>
        // to make things nicer we assign 0 distance to itself
        if (e.srcId == e.dstId) 0.0 else rng.nextDouble()
      }
  }
graph: org.apache.spark.graphx.Graph[Long,Double] = org.apache.spark.graphx.impl.GraphImpl@6a98b503
// Landmark vertices whose distances we want from every other vertex.
val landMarkVertexIds = Seq(4L, 0L, 9L)
// Each vertex of `result` carries a map: landmark id -> shortest weighted distance to it.
val result = GraphXShortestWeightedPaths.run(graph, landMarkVertexIds)
landMarkVertexIds: Seq[Long] = List(4, 0, 9)
result: org.apache.spark.graphx.Graph[GraphXShortestWeightedPaths.SPMap,Double] = org.apache.spark.graphx.impl.GraphImpl@763902d8
// Found shortest paths
(4,Map(4 -> 0.0, 0 -> 0.4978771374144447, 9 -> 0.0039659650390762025))
(0,Map(0 -> 0.0, 4 -> 0.6026164718020462, 9 -> 0.6065824368411225))
(6,Map(0 -> 0.18329441371794564, 4 -> 0.7501236631976119, 9 -> 0.5876411577224736))
(8,Map(0 -> 1.0175525403647847, 4 -> 0.8260424329285025, 9 -> 0.6635599274533642))
(2,Map(0 -> 0.27424380755189526, 4 -> 0.43261521868768604, 9 -> 0.43658118372676225))
(1,Map(0 -> 0.8561112576964152, 4 -> 1.4229405071760814, 9 -> 1.2604580017009432))
(3,Map(0 -> 0.45813682659496957, 4 -> 0.2666267191586873, 9 -> 0.10414421368354898))
(7,Map(4 -> 0.2073486603716349, 9 -> 0.04486615489649659, 0 -> 0.42470420866790193))
(9,Map(9 -> 0.0, 0 -> 0.5045977394480895, 4 -> 0.3302516601623805))
(5,Map(0 -> 0.7760991789332677, 4 -> 0.278222041518823, 9 -> 0.2821880065578992))
// edges with weights, make sure to check couple of shortest paths from above
|srcId|dstId|               attr|
|    0|    0|                0.0|
|    0|    1| 0.3685042379337995|
|    0|    4| 0.6026164718020462|
|    1|    1|                0.0|
|    1|    6| 0.6728168439784695|
|    1|    6| 0.7852936724902796|
|    2|    0|0.27424380755189526|
|    2|    1| 0.7870284262414232|
|    2|    4|0.43261521868768604|
|    2|    6| 0.3595658109787053|
|    2|    6|0.27647602206577304|
|    2|    9|  0.841095651455322|
|    3|    2| 0.1838930190430743|
|    3|    3|                0.0|
|    3|    5|0.29401237129990443|
|    3|    7|0.05927805878705239|
|    3|    8| 0.6194270974815277|
|    4|    0| 0.4978771374144447|
|    4|    3|  0.895324471945953|
|    4|    5| 0.3082168662674215|
only showing top 20 rows // this is the directed weighted edge of the graph
|srcId|dstId|                attr|
|    0|    0|                 0.0|
|    0|    1| 0.18179342573305257|
|    0|    4| 0.10499884753471289|
|    1|    1|                 0.0|
|    1|    6| 0.06902878615829067|
|    1|    6| 0.31916888663632115|
|    2|    0|   0.994117390558936|
|    2|    1|  0.3483496903814788|
|    2|    4|  0.6637120585508148|
|    2|    6|  0.7903591616520668|
|    2|    6| 0.15547330775766222|
|    2|    9| 0.20548813724247617|
|    3|    2| 0.30236814876856666|
|    3|    3|                 0.0|
|    3|    5|  0.8097988524182909|
|    3|    7|  0.7296892168314497|
|    3|    8|  0.7622160701710208|
|    4|    0| 0.19303017905390474|
|    4|    3| 0.09317663461190051|
|    4|    5|0.013527517923307197|
only showing top 20 rows
// now let us collect the shortest distance between every vertex and every landmark vertex
// each vertex attribute of `result` is a Scala Map (landmark id -> distance); flatten it so that
// every (vertex, landmark) pair becomes one record
val shortestDistsVertex2Landmark = result.vertices.flatMap { case (vertexId, spMap) =>
  // triples: vertex, landmarkVertex, shortest_distance
  spMap.map { case (landmarkId, dist) => (vertexId, landmarkId, dist) }
}
shortestDistsVertex2Landmark: org.apache.spark.rdd.RDD[(org.apache.spark.graphx.VertexId, org.apache.spark.graphx.VertexId, Double)] = MapPartitionsRDD[112] at flatMap at command-2971213210276591:4
res6: String =

Let's make a DataFrame for visualizing pairwise matrix plots

In this example we want 4 columns, as follows (note that the actual values change with each realisation of the graph!):

landmarkId1 ("0"),    landmarkId2 ("4"),  landmarkId3 ("9"),   srcVertexId
0.0,                  0.7425..,          0.8718,                0
0.924...,             1.2464..,          1.0472,                1
// we need this to make sure that the maps are ordered by the keys for ensuring unique column values
import scala.collection.immutable.ListMap
import sqlContext.implicits._
import scala.collection.immutable.ListMap
import sqlContext.implicits._
 // recall our landmark vertices in landMarkVertexIds. let's use their Strings for names
val unorderedNamedLandmarkVertices = landMarkVertexIds.map(id => (id, id.toString))
// ListMap keeps insertion order, so sorting first fixes the column order by landmark id
val orderedNamedLandmarkVertices = ListMap(unorderedNamedLandmarkVertices.sortBy(_._1): _*)
val orderedLandmarkVertexNames = orderedNamedLandmarkVertices.toSeq.map(_._2)
orderedLandmarkVertexNames.mkString(", ")
unorderedNamedLandmarkVertices: Seq[(Long, String)] = List((4,4), (0,0), (9,9))
orderedNamedLandmarkVertices: scala.collection.immutable.ListMap[Long,String] = ListMap(0 -> 0, 4 -> 4, 9 -> 9)
orderedLandmarkVertexNames: Seq[String] = Vector(0, 4, 9)
res7: String = 0, 4, 9
// Column names for the final table: one per landmark (in landmark-id order), then the source vertex.
val columnNames: Seq[String] = orderedLandmarkVertexNames ++ Seq("srcVertexId")
columnNames: Seq[String] = Vector(0, 4, 9, srcVertexId)
/**
  * One row of the per-vertex distance table, used to build a DataFrame quickly from `result`.
  *
  * @param shortestDistances distances from `srcVertexId` to the landmarks, one entry per landmark
  * @param srcVertexId       the vertex the distances are measured from
  */
case class SeqOfDoublesAndsrcVertexId(
  shortestDistances: Seq[Double],
  srcVertexId: VertexId
)
defined class SeqOfDoublesAndsrcVertexId
// Landmark ids in exactly the order used for `columnNames` (ascending id).
val orderedLandmarkIds: Seq[VertexId] = orderedNamedLandmarkVertices.keys.toVector
val shortestDistsSeqFromVertex2Landmark2DF = result.vertices.map { case (vertexId, spMap) =>
  // Look each distance up by landmark id instead of sorting the vertex's own map: a vertex that
  // cannot reach some landmark has no entry for it, which would shorten its sequence and shift the
  // remaining distances into the wrong landmark columns. Unreachable landmarks get +Infinity.
  val dists = orderedLandmarkIds.map(id => spMap.getOrElse(id, Double.PositiveInfinity))
  SeqOfDoublesAndsrcVertexId(dists, vertexId)
}.toDF()
shortestDistsSeqFromVertex2Landmark2DF: org.apache.spark.sql.DataFrame = [shortestDistances: array<double>, srcVertexId: bigint] // but this dataframe needs the first column exploded into 3 columns
|   shortestDistances|srcVertexId|
|[0.49787713741444...|          4|
|[0.0, 0.602616471...|          0|
|[0.18329441371794...|          6|
|[1.01755254036478...|          8|
|[0.27424380755189...|          2|
|[0.85611125769641...|          1|
|[0.45813682659496...|          3|
|[0.42470420866790...|          7|
|[0.50459773944808...|          9|
|[0.77609917893326...|          5|

Now we want a separate column for each distance in the sequence stored in column 'shortestDistances'.

Let us use the following ideas for this:

// this is from
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{lit, udf}

// UDF returning the element at position `index` of an array<double> column
// (the generic example this is based on used Seq[Int]; here it is adapted to Doubles).
val elem = udf { (values: Seq[Double], index: Int) => values(index) }

// Method to apply the 'elem' UDF to each position of an array column; requires knowing the length of
// the sequence in advance. Each resulting column is named "<col>(<i>)", e.g. shortestDistances(0).
def split(col: Column, len: Int): Seq[Column] =
  (0 until len).map(i => elem(col, lit(i)).as(s"$col($i)"))

// Implicit conversion to make things nicer to use, e.g.
// select(Column, Seq[Column], Column) is converted into select(Column*) flattening sequences
implicit class DataFrameSupport(df: DataFrame) {
  /**
    * Selects columns given as any mix of single `Column`s and sequences of `Column`s.
    *
    * @throws IllegalArgumentException if an argument (or a sequence element) is not a `Column`
    */
  def select(cols: Any*): DataFrame = {
    def unsupported(other: Any): Nothing =
      throw new IllegalArgumentException(s"select expects Column or Seq[Column], got: $other")

    val flattened: Seq[Column] = cols.flatMap {
      // element types are erased at runtime, so check every element explicitly
      case seq: Seq[_] => seq.map {
        case c: Column => c
        case other     => unsupported(other)
      }
      case c: Column => Seq(c)
      case other     => unsupported(other)
    }
    df.select(flattened: _*)
  }
}
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{lit, udf}
elem: org.apache.spark.sql.expressions.UserDefinedFunction = SparkUserDefinedFunction($Lambda$6726/1752149086@9f2dc3d,DoubleType,List(Some(class[value[0]: array<double>]), Some(class[value[0]: int])),Some(class[value[0]: double]),None,false,true)
split: (col: org.apache.spark.sql.Column, len: Int)Seq[org.apache.spark.sql.Column]
defined class DataFrameSupport
// Explode the distance array into one column per landmark (the count follows the landmark list).
val shortestDistsFromVertex2Landmark2DF = shortestDistsSeqFromVertex2Landmark2DF.select(
  split($"shortestDistances", orderedLandmarkVertexNames.size),
  $"srcVertexId"
)
shortestDistsFromVertex2Landmark2DF: org.apache.spark.sql.DataFrame = [shortestDistances(0): double, shortestDistances(1): double ... 2 more fields]
|                  0|                  4|                   9|srcVertexId|
| 0.4978771374144447|                0.0|0.003965965039076...|          4|
|                0.0| 0.6026164718020462|  0.6065824368411225|          0|
|0.18329441371794564| 0.7501236631976119|  0.5876411577224736|          6|
| 1.0175525403647847| 0.8260424329285025|  0.6635599274533642|          8|
|0.27424380755189526|0.43261521868768604| 0.43658118372676225|          2|
| 0.8561112576964152| 1.4229405071760814|  1.2604580017009432|          1|
|0.45813682659496957| 0.2666267191586873| 0.10414421368354898|          3|
|0.42470420866790193| 0.2073486603716349| 0.04486615489649659|          7|
| 0.5045977394480895| 0.3302516601623805|                 0.0|          9|
| 0.7760991789332677|  0.278222041518823|  0.2821880065578992|          5|
// now let's give it our names based on the landmark vertex Ids
val shortestDistsFromVertex2Landmark2DF = shortestDistsSeqFromVertex2Landmark2DF.select(
  split($"shortestDistances", orderedLandmarkVertexNames.size),
  $"srcVertexId"
).toDF(columnNames: _*)
shortestDistsFromVertex2Landmark2DF: org.apache.spark.sql.DataFrame = [0: double, 4: double ... 2 more fields]
|                  0|                  4|                   9|srcVertexId|
| 0.4978771374144447|                0.0|0.003965965039076...|          4|
|                0.0| 0.6026164718020462|  0.6065824368411225|          0|
|0.18329441371794564| 0.7501236631976119|  0.5876411577224736|          6|
| 1.0175525403647847| 0.8260424329285025|  0.6635599274533642|          8|
|0.27424380755189526|0.43261521868768604| 0.43658118372676225|          2|
| 0.8561112576964152| 1.4229405071760814|  1.2604580017009432|          1|
|0.45813682659496957| 0.2666267191586873| 0.10414421368354898|          3|
|0.42470420866790193| 0.2073486603716349| 0.04486615489649659|          7|
| 0.5045977394480895| 0.3302516601623805|                 0.0|          9|
| 0.7760991789332677|  0.278222041518823|  0.2821880065578992|          5|
0 4 9
0.4978771374144447 0.0 3.9659650390762025e-3
0.0 0.6026164718020462 0.6065824368411225
0.18329441371794564 0.7501236631976119 0.5876411577224736
1.0175525403647847 0.8260424329285025 0.6635599274533642
0.27424380755189526 0.43261521868768604 0.43658118372676225
0.8561112576964152 1.4229405071760814 1.2604580017009432
0.45813682659496957 0.2666267191586873 0.10414421368354898
0.42470420866790193 0.2073486603716349 4.486615489649659e-2
0.5045977394480895 0.3302516601623805 0.0
0.7760991789332677 0.278222041518823 0.2821880065578992