008_DiamondsPipeline_01ETLEDA(Scala)


ScaDaMaLe Course site and book

Diamonds ML Pipeline Workflow - DataFrame ETL and EDA Part

This covers the Spark SQL parts of an end-to-end example of a Machine Learning (ML) workflow, focusing on the extract-transform-load (ETL) and exploratory data analysis (EDA) parts.

Why are we using DataFrames? Because of the announcement in the Spark MLlib Main Guide for Spark 2.2 (https://spark.apache.org/docs/latest/ml-guide.html) that the "DataFrame-based API is primary API".

This notebook is a scalarific breakdown of the pythonic 'Diamonds ML Pipeline Workflow' from the Databricks Guide.

We will see this example again in the sequel.

For this example, we analyze the Diamonds dataset from the R Datasets hosted on DBC.

Later on, we will use the DecisionTree algorithm to predict the price of a diamond from its characteristics.

Here is an outline of our pipeline:

  • Step 1. Load data: Load data as DataFrame
  • Step 2. Understand the data: Compute statistics and create visualizations to get a better understanding of the data.
  • Step 3. Hold out data: Split the data randomly into training and test sets. We will not look at the test data until after learning.
  • Step 4. On the training dataset:
    • Extract features: We will index categorical (String-valued) features so that DecisionTree can handle them.
    • Learn a model: Run DecisionTree to learn how to predict a diamond's price from a description of the diamond.
    • Tune the model: Tune the tree depth (complexity) using the training data. (This process is also called model selection.)
  • Step 5. Evaluate the model: Now look at the test dataset. Compare the initial model with the tuned model to see the benefit of tuning parameters.
  • Step 6. Understand the model: We will examine the learned model and results to gain further insight.

In this notebook, we will only cover Steps 1 and 2 above. The other steps will be revisited in the sequel.
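For orientation only, here is a minimal sketch of how the hold-out split of Step 3 might look with Spark's randomSplit, assuming the cleaned diamondsDF we construct below; the 0.7/0.3 weights, the seed value, and the names trainingDF and testDF are illustrative choices, not part of the original pipeline:

// Hypothetical preview of Step 3: randomly split the data into training and test sets.
// The weights are normalized by Spark; fixing a seed makes the split reproducible.
val Array(trainingDF, testDF) = diamondsDF.randomSplit(Array(0.7, 0.3), seed = 12345L)
println(s"training: ${trainingDF.count()} rows, test: ${testDF.count()} rows")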

Step 1. Load data as DataFrame

This section loads a dataset as a DataFrame and examines a few rows of it to understand the schema.

For more info, see the DB guide on importing data.

// We'll use the Diamonds dataset from the R datasets hosted on DBC.
val diamondsFilePath = "dbfs:/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv"
diamondsFilePath: String = dbfs:/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv
sc.textFile(diamondsFilePath).take(2) // looks like a csv file as it should
res0: Array[String] = Array("","carat","cut","color","clarity","depth","table","price","x","y","z", "1",0.23,"Ideal","E","SI2",61.5,55,326,3.95,3.98,2.43)
val diamondsRawDF = sqlContext.read    // we can use sqlContext instead of SparkSession for backwards compatibility to 1.x
    .format("com.databricks.spark.csv") // use spark.csv package
    .option("header", "true") // Use first line of all files as header
    .option("inferSchema", "true") // Automatically infer data types
    //.option("delimiter", ",") // Specify the delimiter as comma or ',' DEFAULT
    .load(diamondsFilePath)
diamondsRawDF: org.apache.spark.sql.DataFrame = [_c0: int, carat: double ... 9 more fields]
//There are 10 columns.  We will try to predict the price of diamonds, treating the other 9 columns as features.
diamondsRawDF.printSchema()
root
 |-- _c0: integer (nullable = true)
 |-- carat: double (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: double (nullable = true)
 |-- table: double (nullable = true)
 |-- price: integer (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)

Note: (nullable = true) simply means that the value is allowed to be null.
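Since every column is declared nullable, we can check whether nulls actually occur. Here is a minimal sketch for a single column; price is just an illustrative choice, and the same pattern applies to any column:

// Count the rows where price is null; for this clean dataset the count should be 0.
diamondsRawDF.filter($"price".isNull).count()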

Let us count the number of rows in diamondsRawDF.

diamondsRawDF.count() // Ctrl+Enter
res3: Long = 53940

So there are 53940 records or rows in the DataFrame.

Use the show(n) method to see the first n (default is 20) rows of the DataFrame, as follows:

diamondsRawDF.show(10)
+---+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|_c0|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+---+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|  1| 0.23|    Ideal|    E|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|
|  2| 0.21|  Premium|    E|    SI1| 59.8| 61.0|  326|3.89|3.84|2.31|
|  3| 0.23|     Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|
|  4| 0.29|  Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|
|  5| 0.31|     Good|    J|    SI2| 63.3| 58.0|  335|4.34|4.35|2.75|
|  6| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|  336|3.94|3.96|2.48|
|  7| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|  336|3.95|3.98|2.47|
|  8| 0.26|Very Good|    H|    SI1| 61.9| 55.0|  337|4.07|4.11|2.53|
|  9| 0.22|     Fair|    E|    VS2| 65.1| 61.0|  337|3.87|3.78|2.49|
| 10| 0.23|Very Good|    H|    VS1| 59.4| 61.0|  338| 4.0|4.05|2.39|
+---+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 10 rows

If you look at the schema of diamondsRawDF, you will see that the automatic schema inference of the sqlContext.read method has cast the values in the price column as integer.

To clean up:

  • let's recast the price column as double for the downstream ML tasks, and
  • let's also get rid of the first column of row indices.
import org.apache.spark.sql.types.DoubleType
// we convert the price column from int to double to be able to model, fit and predict in the downstream ML task
val diamondsDF = diamondsRawDF.select($"carat", $"cut", $"color", $"clarity", $"depth", $"table",$"price".cast(DoubleType).as("price"), $"x", $"y", $"z")
diamondsDF.cache() // let's cache it for reuse
diamondsDF.printSchema // print schema
root
 |-- carat: double (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: double (nullable = true)
 |-- table: double (nullable = true)
 |-- price: double (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)
import org.apache.spark.sql.types.DoubleType
diamondsDF: org.apache.spark.sql.DataFrame = [carat: double, cut: string ... 8 more fields]
diamondsDF.show(10,false) // notice that price column has Double values that end in '.0' now
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|cut      |color|clarity|depth|table|price|x   |y   |z   |
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|0.23 |Ideal    |E    |SI2    |61.5 |55.0 |326.0|3.95|3.98|2.43|
|0.21 |Premium  |E    |SI1    |59.8 |61.0 |326.0|3.89|3.84|2.31|
|0.23 |Good     |E    |VS1    |56.9 |65.0 |327.0|4.05|4.07|2.31|
|0.29 |Premium  |I    |VS2    |62.4 |58.0 |334.0|4.2 |4.23|2.63|
|0.31 |Good     |J    |SI2    |63.3 |58.0 |335.0|4.34|4.35|2.75|
|0.24 |Very Good|J    |VVS2   |62.8 |57.0 |336.0|3.94|3.96|2.48|
|0.24 |Very Good|I    |VVS1   |62.3 |57.0 |336.0|3.95|3.98|2.47|
|0.26 |Very Good|H    |SI1    |61.9 |55.0 |337.0|4.07|4.11|2.53|
|0.22 |Fair     |E    |VS2    |65.1 |61.0 |337.0|3.87|3.78|2.49|
|0.23 |Very Good|H    |VS1    |59.4 |61.0 |338.0|4.0 |4.05|2.39|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 10 rows
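As an aside, the same cleanup can also be sketched with withColumn and drop, which avoids listing all ten columns by hand; diamondsDFAlt is just an illustrative name:

// Equivalent cleanup: cast the price column to double in place and drop the row-index column _c0.
// DoubleType comes from the import in the cell above.
val diamondsDFAlt = diamondsRawDF
  .withColumn("price", $"price".cast(DoubleType))
  .drop("_c0")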
//View DataFrame in databricks
// note this 'display' is a databricks notebook specific command that is quite powerful for visual interaction with the data
// other notebooks like zeppelin have similar commands for interactive visualisation
display(diamondsDF) 
 
(display output: an interactive table of diamondsDF with columns carat, cut, color, clarity, depth, table, price, x, y, z. Truncated results, showing first 1000 rows.)
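Outside Databricks, where display is unavailable, the "compute statistics" part of Step 2 can still be done directly on the DataFrame; a minimal sketch:

// Summary statistics (count, mean, stddev, min, max) for a few numeric columns.
diamondsDF.describe("carat", "depth", "table", "price").show()
// Frequencies of the categorical levels of cut.
diamondsDF.groupBy("cut").count().orderBy($"count".desc).show()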