Databricks notebook source exported at Tue, 28 Jun 2016 11:17:28 UTC

Analyzing Golden State Warriors’ passing network using GraphFrames

**This notebook was created by Yuki Katoh and is a modified version of the article originally posted to Opiate for the masses.**

**Dataset: Golden State Warriors’ pass data from the 2015-16 regular season, provided by NBA.com**

**Source: http://stats.nba.com/**

*This notebook requires Spark 1.6+.*

View the HTML source URL of this Databricks notebook.

WARNING: Install the GraphFrames library before running the following commands. For instructions, see here.


from graphframes import *
import pandas as pd
import os
import json


# Player IDs (as used by stats.nba.com) of the Golden State Warriors players
# whose pass data is analysed below.
playerids = [201575, 201578, 2738, 202691, 101106, 2760, 2571, 203949, 203546, 203110, 201939, 203105, 2733, 1626172, 203084]


# Call the stats.nba.com API and save each player's pass data to a local
# JSON file named "<playerid>.json".
for playerid in playerids:
    # --fail makes curl exit non-zero on an HTTP error instead of writing the
    # error page into the JSON file. The exit status is checked below, so a
    # failed download stops here instead of surfacing later as a confusing
    # JSON parse error.
    status = os.system('curl --fail --silent --show-error '
        '"http://stats.nba.com/stats/playerdashptpass?'
        'DateFrom=&'
        'DateTo=&'
        'GameSegment=&'
        'LastNGames=0&'
        'LeagueID=00&'
        'Location=&'
        'Month=0&'
        'OpponentTeamID=0&'
        'Outcome=&'
        'PerMode=Totals&'
        'Period=0&'
        'PlayerID={playerid}&'
        'Season=2015-16&'
        'SeasonSegment=&'
        'SeasonType=Regular+Season&'
        'TeamID=0&'
        'VsConference=&'
        'VsDivision=" > {playerid}.json'.format(playerid=playerid))
    if status != 0:
        raise RuntimeError(
            'Downloading pass data for player {0} failed '
            '(os.system status {1})'.format(playerid, status))


# Parse the downloaded JSON files into a single pandas DataFrame.
# The per-player frames are collected in a list and concatenated once.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, and
# appending inside a loop re-copies the accumulated frame on every iteration.
frames = []
for playerid in playerids:
    with open("{playerid}.json".format(playerid=playerid)) as json_file:
        # Only the first result set is used; it supplies the
        # PLAYER_NAME_LAST_FIRST, PASS_TO and PASS columns read below.
        parsed = json.load(json_file)['resultSets'][0]
    frames.append(pd.DataFrame(parsed['rowSet'], columns=parsed['headers']))
raw = pd.concat(frames, ignore_index=True)

raw = raw.rename(columns={'PLAYER_NAME_LAST_FIRST': 'PLAYER'})
# Vertex ids are the player names with the ", " separator removed.
raw['id'] = raw['PLAYER'].str.replace(', ', '')


# Create passes: keep only rows whose receiver (PASS_TO) is also one of the
# downloaded players, i.e. passes between the players in `playerids`.
receiver_is_teammate = raw['PASS_TO'].isin(raw['PLAYER'])
passes = raw.loc[receiver_is_teammate, ['PLAYER', 'PASS_TO', 'PASS']]


# Make raw vertices: one row per player, holding the display name and the
# comma-free id that the edges refer to.
pandas_vertices = (raw[['PLAYER', 'id']]
                   .drop_duplicates()
                   .rename(columns={'PLAYER': 'name'}))


# Make raw edges: one (src, dst) row per individual pass between two of the
# downloaded players, so a pair with PASS == n contributes n identical edges.
# A single vectorized repeat replaces appending one frame per pair. The old
# approach used the removed DataFrame.append, was quadratic, failed when a pair
# occurred twice (int() of a 2-element array), and crashed on empty input when
# assigning .columns. Column order is now explicit, not taken from a dict.
pass_rows = raw.loc[raw['PASS_TO'].isin(raw['PLAYER']), ['id', 'PASS_TO', 'PASS']]
pair_edges = pd.DataFrame({
    'src': pass_rows['id'].values,
    # dst uses the same ", "-free form as the vertex ids.
    'dst': pass_rows['PASS_TO'].str.replace(', ', '').values,
}, columns=['src', 'dst'])
pass_counts = pass_rows['PASS'].astype(int).values
pandas_edges = (pair_edges.loc[pair_edges.index.repeat(pass_counts)]
                .reset_index(drop=True))


# Bring the local pandas vertices and edges into Spark.
vertices, edges = [sqlContext.createDataFrame(local_frame)
                   for local_frame in (pandas_vertices, pandas_edges)]


# Build the passing graph.
g = GraphFrame(vertices, edges)


# Show the vertices, then the edges.
for graph_part in (g.vertices, g.edges):
    graph_part.show()


# Show in-degrees, out-degrees and total degrees, each highest first.
for degree_frame, degree_col in ((g.inDegrees, 'inDegree'),
                                 (g.outDegrees, 'outDegree'),
                                 (g.degrees, 'degree')):
    degree_frame.sort(degree_col, ascending=False).show()


# Out-degree and in-degree side by side per player (inner join on id).
# g.outDegrees / g.inDegrees are Spark DataFrames, which pd.merge cannot
# consume, so they are first collected to pandas. Both are small here.
pd.merge(left=g.outDegrees.toPandas(), right=g.inDegrees.toPandas(), on='id')



# %fs rm -r /FileStore/groups


# Community detection: label propagation limited to 5 iterations.
lp = g.labelPropagation(maxIter=5)
lp.show()


# PageRank (reset probability 0.15, tolerance 0.01); keep the resulting vertex
# table ordered from highest to lowest score.
pagerank_graph = g.pageRank(resetProbability=0.15, tol=0.01)
pr = pagerank_graph.vertices.sort('pagerank', ascending=False)
pr.show()


# Create a network: one row per passer/receiver pair, annotated with the
# passer's label-propagation label and PageRank score.
# `passes` is converted to Spark only while it is still a pandas DataFrame.
# Re-running this cell would otherwise call createDataFrame on a Spark
# DataFrame, which PySpark rejects.
if isinstance(passes, pd.DataFrame):
    passes = sqlContext.createDataFrame(passes)
network = passes.join(lp, passes.PLAYER == lp.name, "inner")
network = network.join(pr, network.PLAYER == pr.name, "inner")
network = network.select('PLAYER', 'PASS_TO', 'label', 'PASS', 'pagerank')
# Brings all rows to the driver; in the notebook this is the cell's output.
network.collect()


# Make network available as a SQL table. registerTempTable is kept because
# this notebook targets Spark 1.6, which predates createOrReplaceTempView.
network.registerTempTable("network")


%sql SELECT * FROM network


%scala
// Fail fast on clusters older than Spark 1.6. The branch name is compared as a
// string (lexicographically), which orders "1.5" < "1.6" < "2.x" correctly.
if (org.apache.spark.BuildInfo.sparkBranch < "1.6") {
  sys.error("Attach this notebook to a cluster running Spark 1.6+")
}


%scala
package d3
// We use a package object so that we can define top level classes like Edge that need to be used in other cells

import org.apache.spark.sql._
import com.databricks.backend.daemon.driver.EnhancedRDDFunctions.displayHTML

/**
 * One row of the `network` table: `PASS` passes from `PLAYER` to `PASS_TO`,
 * together with the passer's label-propagation `label` and `pagerank` score.
 */
case class Edge(
  PLAYER: String,
  PASS_TO: String,
  PASS: Long,
  label: Long,
  pagerank: Double
)

/** A graph node: player name, community label (node colour) and PageRank (node radius). */
case class Node(
  name: String,
  label: Long,
  pagerank: Double
)

/** A link between two positions in the node list; `value` sets the stroke width. */
case class Link(
  source: Int,
  target: Int,
  value: Long
)

/** The payload serialised to JSON for d3: every node plus the links between them. */
case class Graph(
  nodes: Seq[Node],
  links: Seq[Link]
)

object graphs {
  val sqlContext = SQLContext.getOrCreate(org.apache.spark.SparkContext.getOrCreate())
  import sqlContext.implicits._

  /**
   * Draws `network` as a d3 force-directed graph in the notebook output.
   *
   * Nodes are the distinct (PLAYER, label, pagerank) triples. Each row becomes a
   * link from PLAYER to PASS_TO with value `PASS / 20 + 1`. Rows whose PLAYER or
   * PASS_TO has no matching node are skipped. Emitting them with index -1 (what a
   * failed `indexWhere` returns) would hand d3 a link to a non-existent node.
   *
   * @param network edges to draw; collected to the driver in full
   * @param height  SVG height in pixels
   * @param width   SVG width in pixels
   */
  def force(network: Dataset[Edge], height: Int = 100, width: Int = 960): Unit = {
    val data = network.collect()
    val nodes = data.map(t => Node(t.PLAYER, t.label, t.pagerank)).distinct
    // Link endpoints are positions in `nodes`. The pairs are reversed before
    // toMap so the FIRST node with a given name wins (same as the indexWhere
    // lookup this replaces). Each lookup is O(1) instead of a scan per row.
    val indexByName: Map[String, Int] = nodes.map(_.name).zipWithIndex.reverse.toMap
    val links = data.flatMap { t =>
      for {
        source <- indexByName.get(t.PLAYER)
        target <- indexByName.get(t.PASS_TO)
      } yield Link(source, target, t.PASS / 20 + 1)
    }
    showGraph(height, width, Seq(Graph(nodes, links)).toDF().toJSON.first())
  }

  /**
   * Displays a force-directed graph using d3 v3 (loaded from d3js.org).
   * input: {"nodes": [{"name": "...", "label": 0, "pagerank": 1.0}],
   *         "links": [{"source": 1, "target": 2, "value": 0}]}
   * A node's `label` picks its colour and its `pagerank` sets its radius;
   * a link's `value` sets its stroke width.
   */
  def showGraph(height: Int, width: Int, graph: String): Unit = {

displayHTML(s"""
<!DOCTYPE html>
<html>
<head>
  <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
  <title>Golden State Warriors Passing Network</title>
  <meta charset="utf-8">
<style>

.node_circle {
  stroke: #777;
  stroke-width: 1.3px;
}

.node_label {
  pointer-events: none;
}

.link {
  stroke: #777;
  stroke-opacity: .2;
}

.node_count {
  stroke: #777;
  stroke-width: 1.0px;
  fill: #999;
}

text.legend {
  font-family: Verdana;
  font-size: 13px;
  fill: #000;
}

.node text {
  font-family: "Helvetica Neue","Helvetica","Arial",sans-serif;
  font-size: 17px;
  font-weight: 200;
}

</style>
</head>

<body>
<script src="//d3js.org/d3.v3.min.js"></script>
<script>

var graph = $graph;

var width = $width,
    height = $height;

var color = d3.scale.category10();

var force = d3.layout.force()
    .charge(-700)
    .linkDistance(350)
    .size([width, height]);

var svg = d3.select("body").append("svg")
    .attr("width", width)
    .attr("height", height);
    
force
    .nodes(graph.nodes)
    .links(graph.links)
    .start();

var link = svg.selectAll(".link")
    .data(graph.links)
    .enter().append("line")
    .attr("class", "link")
    .style("stroke-width", function(d) { return Math.sqrt(d.value); });

var node = svg.selectAll(".node")
    .data(graph.nodes)
    .enter().append("g")
    .attr("class", "node")
    .call(force.drag);

node.append("circle")
    .attr("r", function(d) { return d.pagerank*10+4 ;})
    .style("fill", function(d) { return color(d.label);})
    .style("opacity", 0.5)

node.append("text")
      .attr("dx", 10)
      .attr("dy", ".35em")
      .text(function(d) { return d.name });
      
//Now we are giving the SVGs co-ordinates - the force layout is generating the co-ordinates which this code is using to update the attributes of the SVG elements
force.on("tick", function () {
    link.attr("x1", function (d) {
        return d.source.x;
    })
        .attr("y1", function (d) {
        return d.source.y;
    })
        .attr("x2", function (d) {
        return d.target.x;
    })
        .attr("y2", function (d) {
        return d.target.y;
    });
    d3.selectAll("circle").attr("cx", function (d) {
        return d.x;
    })
        .attr("cy", function (d) {
        return d.y;
    });
    d3.selectAll("text").attr("x", function (d) {
        return d.x;
    })
        .attr("y", function (d) {
        return d.y;
    });
});
</script>
</body>
</html>
""")
  }

  /**
   * Prints usage instructions for `force`. The documented Edge field types and
   * the `network` parameter name match the actual definitions.
   */
  def help() = {
displayHTML("""
<p>
Produces a force-directed graph given a collection of edges of the following form:<br/>
<tt><font color="#a71d5d">case class</font> <font color="#795da3">Edge</font>(<font color="#ed6a43">PLAYER</font>: <font color="#a71d5d">String</font>, <font color="#ed6a43">PASS_TO</font>: <font color="#a71d5d">String</font>, <font color="#ed6a43">PASS</font>: <font color="#a71d5d">Long</font>, <font color="#ed6a43">label</font>: <font color="#a71d5d">Long</font>, <font color="#ed6a43">pagerank</font>: <font color="#a71d5d">Double</font>)</tt>
</p>
<p>Usage:<br/>
<tt>%scala</tt><br/>
<tt><font color="#a71d5d">import</font> <font color="#ed6a43">d3._</font></tt><br/>
<tt><font color="#795da3">graphs.force</font>(<br/>
&nbsp;&nbsp;<font color="#ed6a43">height</font> = <font color="#795da3">500</font>,<br/>
&nbsp;&nbsp;<font color="#ed6a43">width</font> = <font color="#795da3">500</font>,<br/>
&nbsp;&nbsp;<font color="#ed6a43">network</font> = <font color="#795da3">Dataset</font>[<font color="#795da3">Edge</font>])</tt>
</p>""")
  }
}


%scala
// Show the usage instructions of the d3 graphing helpers.
import d3.graphs
graphs.help()


%scala
import d3._

// Read the annotated passing network from the `network` SQL table and draw it
// as a 1000x800 force-directed graph.
val passNetwork = sql(
  "SELECT PLAYER, PASS_TO, PASS, label, pagerank FROM network"
).as[Edge]

graphs.force(network = passNetwork, height = 800, width = 1000)


Updated: