04_EvaluateFrenchPOSModelBySparkMultiClassMetrics(Scala)

Evaluate French POS Model

by Spark's Multi-class Metrics

CoNLL: Conference on Computational Natural Language Learning

The Language-Independent Named Entity Recognition task introduced at CoNLL-2003 measures the performance of the systems in terms of precision, recall, and F1-score, where:

Precision is the percentage of named entities found by the learning system that are correct. Recall is the percentage of named entities present in the corpus that are found by the system. A named entity is correct only if it is an exact match of the corresponding entity in the data file.

  I. Surface string and entity type match
  II. System hypothesized an entity
  III. System misses an entity
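To make these definitions concrete, here is a minimal sketch in Scala, using made-up counts (they are not from this notebook), of how precision, recall, and F1 relate to true positives, false positives, and false negatives:

// hypothetical counts, just to illustrate the formulas
val truePositives  = 80.0  // entities found by the system that exactly match the gold data
val falsePositives = 10.0  // entities hypothesized by the system that are not in the gold data
val falseNegatives = 20.0  // gold entities the system missed

val precision = truePositives / (truePositives + falsePositives)  // 0.888...
val recall    = truePositives / (truePositives + falseNegatives)  // 0.8
val f1        = 2 * precision * recall / (precision + recall)     // harmonic mean, ~0.842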

spark.version
res0: String = 2.4.0

Download the French pretrained PerceptronModel from Spark NLP

Model:

  • pos_ud-gsd_fr

Versions:

  • 1.8.x
  • 2.0.1

This model was trained against UD_French-GSD, more specifically using only the training split:

  • fr-ud-train.conll:
    • 14,554 sentences
    • 356,638 words

Download the pretrained French POS model

%sh 
curl -o /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307.zip https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/models/pos_ud-gsd_fr_2.0.0_2.4_1553029753307.zip
  • Check where the file was saved
  • Extract the model
%sh
ls -l "/dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307.zip"

-rw-r--r-- 1 root root 3545833 Jun 17 22:11 /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307.zip
%sh
unzip -o /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307.zip -d /dbfs/FileStore/tables/

Archive:  /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307.zip
   creating: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/
   creating: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/fields/
   creating: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/fields/POS Model/
  inflating: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/fields/POS Model/part-00000
   creating: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/metadata/
  inflating: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/metadata/part-00000
  (remaining entries are macOS __MACOSX companions and hidden .crc/.DS_Store files)
%sh
ls -l /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/

total 0
drwxr-xr-x 1 root root 0 Jan 1 1970 fields
drwxr-xr-x 1 root root 0 Jan 1 1970 metadata

So this is the path to our pretrained French POS model: /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/
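Note that /dbfs/FileStore/... is the local FUSE mount of DBFS used by %sh cells; from Spark (and from Spark NLP's load calls later in this notebook) the same directory is addressed as /FileStore/.... As an optional sanity check, assuming the standard Databricks dbutils and display utilities are available:

// list the extracted model directory through DBFS
// (same folder that %sh sees as /dbfs/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/)
display(dbutils.fs.ls("/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/"))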

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.evaluation.MultilabelMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml._
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, concat_ws, lit, split, udf}
import org.apache.spark.ml.feature.NGram
import org.apache.spark.ml.Pipeline

import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import com.johnsnowlabs.nlp.{DocumentAssembler, Finisher}
import com.johnsnowlabs.nlp.annotators.{Normalizer, Stemmer, Tokenizer}
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.util.Benchmark

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._

import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.spark.rdd.RDD
import scala.collection.mutable.ArrayBuffer

import scala.util.control.Breaks._


Define UDFs to extract tokens and labels from the test dataset

Format: CoNLL-U. Multi-word tokens are kept in their original format:

  • There are several closed classes of contractions that are treated as multi-word tokens and segmented into individual syntactic words. For instance, au -> à + le, auquel -> à + lequel. Note that du and des are ambiguous and can be split or not depending on their usage (see the example fragment below).
  • This corpus contains 11025 multi-word tokens. On average, one multi-word token consists of 2.00 syntactic words.
  • There are 9 types of multi-word tokens. Examples: du, des, au, aux, auquel, duquel, auxquels, desquelles, auxquelles.
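For example, here is a made-up CoNLL-U fragment (not taken from the actual test file) illustrating how a multi-word token appears. The range line "3-4 du" carries the surface form, while lines 3 and 4 carry the syntactic words; the UDFs defined below keep the surface token "du" (column 2 of the range line) and take its tag from the first syntactic word that follows:

# hypothetical fragment; trailing CoNLL-U columns omitted
1    Il      il      PRON   ...
2    parle   parler  VERB   ...
3-4  du      _       _      ...
3    de      de      ADP    ...
4    le      le      DET    ...
5    projet  projet  NOUN   ...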

These UDFs will help to extract tokens and labels from a CoNLL-U test dataset

// Extract the surface tokens (FORM column) of one CoNLL-U sentence.
// A multi-word token range line (e.g. "3-4  du") contributes its surface form,
// and the syntactic word lines it covers (3 and 4) are skipped.
def extractTokens = udf { docs: Seq[String] =>
    val posTagsArray = ArrayBuffer[String]()
    var previousTokenIds = Array[String]()
    for (e <- docs) {
      val columns = e.split("\t")
      val currentTokenIds = columns(0).split("-")

      if (currentTokenIds.length > 1) {
        // multi-word token range: keep the surface form and remember the covered ids
        previousTokenIds = currentTokenIds
        posTagsArray += columns(1)
      } else if (previousTokenIds.contains(currentTokenIds(0))) {
        // syntactic word already covered by the preceding range line: skip it
      } else {
        posTagsArray += columns(1)
      }
    }
    posTagsArray
  }

// Extract the gold UPOS tags (4th column) aligned with the tokens above.
// For a multi-word token range, take the tag of its first syntactic word.
def extractTags = udf { docs: Seq[String] =>
    val posTagsArray = ArrayBuffer[String]()
    var previousTokenIds = Array[String]()
    for ((e, i) <- docs.zipWithIndex) {
      val columns = e.split("\t")
      val currentTokenIds = columns(0).split("-")

      if (currentTokenIds.length > 1) {
        // multi-word token range: use the tag of the next line (its first syntactic word)
        previousTokenIds = currentTokenIds
        val nextSyntacticWord = docs(i + 1).split("\t")
        posTagsArray += nextSyntacticWord(3)
      } else if (previousTokenIds.contains(currentTokenIds(0))) {
        // syntactic word already covered by the preceding range line: skip it
      } else {
        posTagsArray += columns(3)
      }
    }
    posTagsArray
  }

// These are used for some sanity checks afterwards
def extractMissingTokens = udf { (testTokens: Seq[String], predictTokens: Seq[String]) =>
    val missingTokensArray = ArrayBuffer[String]()
    for (e <- testTokens) {
      if (!predictTokens.contains(e)) {
        missingTokensArray += e
      }
    }
    missingTokensArray
  }

// Length of an array column, used to compare token/tag counts
def calLengthOfArray = udf { docs: Seq[String] =>
    docs.length
  }
extractTokens: org.apache.spark.sql.expressions.UserDefinedFunction
extractTags: org.apache.spark.sql.expressions.UserDefinedFunction
extractMissingTokens: org.apache.spark.sql.expressions.UserDefinedFunction
calLengthOfArray: org.apache.spark.sql.expressions.UserDefinedFunction

We will be using the gold tokenization, so we don't need the Tokenizer annotator. Instead, we convert the gold tokens into annotations of AnnotatorType.TOKEN.

// Build Spark NLP token annotations directly from the gold tokens, so the POS
// model is evaluated on the gold tokenization rather than its own Tokenizer output.
def customTokenizer: UserDefinedFunction = udf { (tokens: Seq[String], text: String, sentenceIndex: String) =>
  val tokenTagAnnotation: ArrayBuffer[Annotation] = ArrayBuffer()
  def annotatorType: String = AnnotatorType.TOKEN
  var lastIndex = 0

  for (e <- tokens) {
    // locate the token in the reconstructed text to get its character offsets
    val beginOfToken = text.indexOfSlice(e, lastIndex)
    val endOfToken = (beginOfToken + e.length) - 1

    val fullTokenAnnotatorStruct = new Annotation(
      annotatorType = annotatorType,
      begin = beginOfToken,
      end = endOfToken,
      result = e,
      metadata = Map("sentence" -> sentenceIndex)
    )
    tokenTagAnnotation += fullTokenAnnotatorStruct
    // continue searching after this token so repeated tokens get distinct offsets
    lastIndex = endOfToken + 1
  }
  tokenTagAnnotation
}

// Attach Spark NLP's annotatorType metadata so downstream annotators accept the column
def wrapColumnMetadata(col: Column, annotatorType: String, outPutColName: String): Column = {
    val metadataBuilder: MetadataBuilder = new MetadataBuilder()
    metadataBuilder.putString("annotatorType", annotatorType)
    col.as(outPutColName, metadataBuilder.build)
}
customTokenizer: org.apache.spark.sql.expressions.UserDefinedFunction
wrapColumnMetadata: (col: org.apache.spark.sql.Column, annotatorType: String, outPutColName: String)org.apache.spark.sql.Column

Let's download the test dataset from the UD_French-GSD repository

%sh
wget -N https://github.com/UniversalDependencies/UD_French-GSD/raw/master/fr_gsd-ud-test.conllu -P /dbfs/FileStore/tables/
--2019-06-17 22:12:09--  https://github.com/UniversalDependencies/UD_French-GSD/raw/master/fr_gsd-ud-test.conllu
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/UniversalDependencies/UD_French-GSD/master/fr_gsd-ud-test.conllu [following]
--2019-06-17 22:12:10--  https://raw.githubusercontent.com/UniversalDependencies/UD_French-GSD/master/fr_gsd-ud-test.conllu
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.188.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.188.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 623344 (609K) [text/plain]
Saving to: ‘/dbfs/FileStore/tables/fr_gsd-ud-test.conllu’
2019-06-17 22:12:10 (13.8 MB/s) - ‘/dbfs/FileStore/tables/fr_gsd-ud-test.conllu’ saved [623344/623344]
%sh
ls -l /dbfs/FileStore/tables/fr_gsd-ud-test.conllu
-rw-r--r-- 1 root root 623344 Jun 17 22:12 /dbfs/FileStore/tables/fr_gsd-ud-test.conllu

Now let's put it all together

import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

val pathCoNNLFile = "/FileStore/tables/fr_gsd-ud-test.conllu"

// We need to change Hadoop's default record delimiter from one newline to two newlines.
// This way each record is a whole CoNLL-U sentence (sentences are separated by a blank line).
// This will show a deprecation warning; don't mind it for the time being :)
// NOTE: we make these @transient to avoid "task not serializable" issues
@transient val conf = new org.apache.hadoop.mapreduce.Job().getConfiguration
conf.set("textinputformat.record.delimiter", "\n\n")

@transient val usgRDD = sc.newAPIHadoopFile(pathCoNNLFile, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], conf).map{ case (_, v) => v.toString }
  
@transient val conllSentencesDF = usgRDD.map(s => s.split("\n").filter(x => !x.startsWith("#"))).filter(x => x.length > 0).toDF("sentence")

conf.set("textinputformat.record.delimiter", "\n")

// testTokensTagsDF holds the test dataset's gold tokens, gold tags, and reconstructed text per sentence
val testTokensTagsDF = conllSentencesDF
.withColumn("id", monotonically_increasing_id)
.withColumn("testTokens", extractTokens($"sentence"))
.withColumn("testTags", extractTags($"sentence"))
.withColumn("text", concat_ws(" ", $"testTokens"))
.drop("sentence")

testTokensTagsDF.show()
+---+--------------------+--------------------+--------------------+
| id|          testTokens|            testTags|                text|
+---+--------------------+--------------------+--------------------+
|  0|[Je, sens, qu', e...|[PRON, VERB, SCON...|Je sens qu' entre...|
|  1|[On, pourra, touj...|[PRON, VERB, ADV,...|On pourra toujour...|
|  2|[«, Il, a, été, l...|[PUNCT, PRON, AUX...|« Il a été largem...|
|  3|[1er, :, début, d...|[NUM, PUNCT, NOUN...|1er : début de la...|
|  4|[Et, pourtant, ,,...|[CCONJ, ADV, PUNC...|Et pourtant , lor...|
|  5|[Les, spéculation...|[DET, NOUN, ADV, ...|Les spéculations ...|
|  6|[Ils, ne, citent,...|[PRON, ADV, VERB,...|Ils ne citent pas...|
|  7|[Il, y, en, a, qu...|[PRON, PRON, PRON...|Il y en a qui cro...|
|  8|[Son, discours, ,...|[DET, NOUN, PUNCT...|Son discours , in...|
|  9|[Pour, lui, ,, je...|[ADP, PRON, PUNCT...|Pour lui , je cro...|
| 10|[Est, -ce, le, fa...|[AUX, PRON, DET, ...|Est -ce le fait d...|
| 11|[Maintenant, ,, a...|[ADV, PUNCT, ADP,...|Maintenant , avec...|
| 12|[Royale, donc, ,,...|[ADJ, ADV, PUNCT,...|Royale donc , cet...|
| 13|[Pour, pouvoir, m...|[ADP, VERB, VERB,...|Pour pouvoir mene...|
| 14|[Belle, ,, grande...|[ADJ, PUNCT, ADJ,...|Belle , grande , ...|
| 15|[Et, si, officiel...|[CCONJ, SCONJ, AD...|Et si officiellem...|
| 16|[Trois, ans, plus...|[NUM, NOUN, ADV, ...|Trois ans plus ta...|
| 17|[Je, suis, déchir...|[PRON, AUX, ADJ, ...| Je suis déchirée .|
| 18|[Selon, lui, ,, c...|[ADP, PRON, PUNCT...|Selon lui , cette...|
| 19|[Le, 20, juin, ,,...|[DET, NUM, NOUN, ...|Le 20 juin , le S...|
+---+--------------------+--------------------+--------------------+
only showing top 20 rows

notebook:12: warning: constructor Job in class Job is deprecated: see corresponding Javadoc for more information.
@transient val conf = new org.apache.hadoop.mapreduce.Job().getConfiguration
                          ^
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
pathCoNNLFile: String = /FileStore/tables/fr_gsd-ud-test.conllu
conf: org.apache.hadoop.conf.Configuration = Configuration: core-default.xml, core-site.xml, mapred-default.xml, mapred-site.xml, yarn-default.xml, yarn-site.xml, hdfs-default.xml, hdfs-site.xml
usgRDD: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[655] at map at command-1915956987763339:15
conllSentencesDF: org.apache.spark.sql.DataFrame = [sentence: array<string>]
testTokensTagsDF: org.apache.spark.sql.DataFrame = [id: bigint, testTokens: array<string> ... 2 more fields]

As you can see, we have extracted the tokens, labels, and original text from the fr_gsd-ud-test.conllu file. Now let's use the text and tokens to predict the labels with the pretrained POS model.

// Annotate the documents
// We don't need a SentenceDetector or Tokenizer since in this example we use the gold sentences/tokens

val documentAssembler = new DocumentAssembler()
    .setInputCol("text")
    .setOutputCol("document")
    .transform(testTokensTagsDF)

// create the token column from the gold tokens
// we only keep id, document, and token to feed into our POS model for prediction
val goldenTokenizer = documentAssembler
.withColumn("documentText", $"document.result"(0))
.withColumn("sentenceIndex", lit("0"))
.withColumn("token", customTokenizer($"testTokens", $"documentText", $"sentenceIndex"))
.withColumn("token", wrapColumnMetadata($"token", AnnotatorType.TOKEN, "token"))
.select("id", "document", "token")

// Let's make sure the token column carries the right annotator metadata
val datasetSchemaFields = goldenTokenizer.schema.fields
        .find(f => f.metadata.contains("annotatorType") && f.metadata.getString("annotatorType") == AnnotatorType.TOKEN)

val tokenColumn = datasetSchemaFields.map(_.name).get

// Let's load our pretrained POS Model
// We only use the document and token columns to produce the prediction column called pos
val perceptronModel = PerceptronModel.load("/FileStore/tables/pos_ud-gsd_fr_2.0.0_2.4_1553029753307/")
    .setInputCols(Array("document", "token"))
    .setOutputCol("pos")
    .transform(goldenTokenizer)
    .select(
        $"id",
        $"document",
        $"token",
        $"token.result".alias("predictedTokens"),
        $"pos.result".alias("predictedTags")
    )

perceptronModel.show

+---+--------------------+--------------------+--------------------+--------------------+
| id|            document|               token|     predictedTokens|       predictedTags|
+---+--------------------+--------------------+--------------------+--------------------+
|  0|[[document, 0, 15...|[[token, 0, 1, Je...|[Je, sens, qu', e...|[PRON, NOUN, SCON...|
|  1|[[document, 0, 76...|[[token, 0, 1, On...|[On, pourra, touj...|[PRON, VERB, ADV,...|
|  2|[[document, 0, 31...|[[token, 0, 0, «,...|[«, Il, a, été, l...|[PUNCT, PRON, AUX...|
|  3|[[document, 0, 19...|[[token, 0, 2, 1e...|[1er, :, début, d...|[ADJ, PUNCT, NOUN...|
|  4|[[document, 0, 20...|[[token, 0, 1, Et...|[Et, pourtant, ,,...|[CCONJ, ADV, PUNC...|
|  5|[[document, 0, 55...|[[token, 0, 2, Le...|[Les, spéculation...|[DET, NOUN, ADV, ...|
|  6|[[document, 0, 15...|[[token, 0, 2, Il...|[Ils, ne, citent,...|[PRON, ADV, VERB,...|
|  7|[[document, 0, 25...|[[token, 0, 1, Il...|[Il, y, en, a, qu...|[PRON, ADV, PRON,...|
|  8|[[document, 0, 69...|[[token, 0, 2, So...|[Son, discours, ,...|[DET, NOUN, PUNCT...|
|  9|[[document, 0, 15...|[[token, 0, 3, Po...|[Pour, lui, ,, je...|[ADP, PRON, PUNCT...|
| 10|[[document, 0, 12...|[[token, 0, 2, Es...|[Est, -ce, le, fa...|[VERB, PRON, DET,...|
| 11|[[document, 0, 19...|[[token, 0, 9, Ma...|[Maintenant, ,, a...|[ADV, PUNCT, ADP,...|
| 12|[[document, 0, 10...|[[token, 0, 5, Ro...|[Royale, donc, ,,...|[ADJ, ADV, PUNCT,...|
| 13|[[document, 0, 18...|[[token, 0, 3, Po...|[Pour, pouvoir, m...|[ADP, VERB, VERB,...|
| 14|[[document, 0, 86...|[[token, 0, 4, Be...|[Belle, ,, grande...|[ADJ, PUNCT, ADJ,...|
| 15|[[document, 0, 16...|[[token, 0, 1, Et...|[Et, si, officiel...|[CCONJ, SCONJ, AD...|
| 16|[[document, 0, 56...|[[token, 0, 4, Tr...|[Trois, ans, plus...|[NUM, NOUN, ADV, ...|
| 17|[[document, 0, 17...|[[token, 0, 1, Je...|[Je, suis, déchir...|[PRON, AUX, VERB,...|
| 18|[[document, 0, 27...|[[token, 0, 4, Se...|[Selon, lui, ,, c...|[ADP, PRON, PUNCT...|
| 19|[[document, 0, 14...|[[token, 0, 1, Le...|[Le, 20, juin, ,,...|[DET, NUM, NOUN, ...|
+---+--------------------+--------------------+--------------------+--------------------+
only showing top 20 rows

documentAssembler: org.apache.spark.sql.DataFrame = [id: bigint, testTokens: array<string> ... 3 more fields]
goldenTokenizer: org.apache.spark.sql.DataFrame = [id: bigint, document: array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>,sentence_embeddings:array<float>>> ... 1 more field]
datasetSchemaFields: Option[org.apache.spark.sql.types.StructField] = Some(StructField(token,ArrayType(StructType(StructField(annotatorType,StringType,true), StructField(begin,IntegerType,false), StructField(end,IntegerType,false), StructField(result,StringType,true), StructField(metadata,MapType(StringType,StringType,true),true), StructField(embeddings,ArrayType(FloatType,false),true), StructField(sentence_embeddings,ArrayType(FloatType,false),true)),true),true))
tokenColumn: String = token
perceptronModel: org.apache.spark.sql.DataFrame = [id: bigint, document: array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>,sentence_embeddings:array<float>>> ... 3 more fields]

Now it's time to run some checks to make sure the test token/tag counts match the predicted token/tag counts (since we are using the gold tokenization, they should be identical).

val joinedDF = perceptronModel
  .join(testTokensTagsDF, Seq("id"))
  .withColumn("predictedTokensLength", calLengthOfArray($"predictedTokens"))
  .withColumn("predictedTagsLength", calLengthOfArray($"predictedTags"))
  .withColumn("testTokensLength", calLengthOfArray($"testTokens"))
  .withColumn("testTagsLength", calLengthOfArray($"testTags"))
  .withColumn("tokensDiffFromTest", $"testTokensLength" - $"predictedTokensLength")
  .withColumn("tagsDiffFromTest", $"testTagsLength" - $"predictedTagsLength")
  .withColumn("missingTokens", extractMissingTokens($"testTokens", $"predictedTokens"))
  .withColumn("missingTags", extractMissingTokens($"testTags", $"predictedTags"))
  .withColumn("equalTags", col("predictedTagsLength") === col("testTagsLength"))

joinedDF.agg(
    sum("predictedTokensLength").as("p_token_lengh"),
    sum("predictedTagsLength").as("p_tags_lengh"),
    sum("testTokensLength").as("t_token_lengh"),
    sum("predictedTagsLength").as("t_tags_lengh"),
    sum("tokensDiffFromTest").as("tokensDiffFromTest"),
    sum("tagsDiffFromTest").as("tagsDiffFromTest")
    ).show
+--------------+-------------+--------------+-------------+------------------+----------------+
|p_token_length|p_tags_length|t_token_length|t_tags_length|tokensDiffFromTest|tagsDiffFromTest|
+--------------+-------------+--------------+-------------+------------------+----------------+
|          9739|         9739|          9739|         9739|                 0|               0|
+--------------+-------------+--------------+-------------+------------------+----------------+

joinedDF: org.apache.spark.sql.DataFrame = [id: bigint, document: array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>,sentence_embeddings:array<float>>> ... 15 more fields]

We have 9,739 tokens and, of course, the same number of tags. Now let's evaluate the model and look at some metrics.

spark.catalog.clearCache
joinedDF.count
res7: Long = 416

These are our UD-POS labels

ADJ: adjective
ADP: adposition
ADV: adverb
AUX: auxiliary
CCONJ: coordinating conjunction
DET: determiner
INTJ: interjection
NOUN: noun
NUM: numeral
PART: particle
PRON: pronoun
PROPN: proper noun
PUNCT: punctuation
SCONJ: subordinating conjunction
SYM: symbol
VERB: verb
X: other

In order to use Apache Spark's MulticlassMetrics, we need exactly two columns (prediction, label) of type Double. The following code explodes the arrays of test and predicted tags and maps each tag to a unique double index.

case class NewColumns(label: String, prediction: String, labelIndex: Double, predictionIndex: Double)

val finalPredictionDF = joinedDF.select("testTokens", "testTags", "predictedTokens", "predictedTags").flatMap(row => {
    val labelsMap: List[(String, Double)] = List(
        ("ADJ",1.0), ("ADP",2.0), ("ADV",3.0),
        ("AUX",4.0), ("CCONJ",5.0), ("DET",6.0),
        ("INTJ",7.0), ("NOUN",8.0), ("NUM",9.0),
        ("PART",10.0), ("PRON",11.0), ("PROPN",12.0),
        ("PUNCT",13.0), ("SCONJ",14.0), ("SYM",15.0),
        ("VERB",16.0), ("X",17.0)
    )
    val finalRow: ArrayBuffer[Seq[NewColumns]] = ArrayBuffer()
    val newRow: ArrayBuffer[NewColumns] = ArrayBuffer()

    val testTags = row.getSeq(1).asInstanceOf[Seq[String]]
    val predictTags = row.getSeq(3).asInstanceOf[Seq[String]]

    // walk the gold and predicted tags in parallel and map each tag to its double index
    for ((t, i) <- testTags.zipWithIndex) {
        val labelIndex = labelsMap.find(_._1 == t).map(_._2).getOrElse(0.0)
        val predictionIndex = labelsMap.find(_._1 == predictTags(i)).map(_._2).getOrElse(0.0)

        newRow += NewColumns(t, predictTags(i), labelIndex, predictionIndex)
    }
    finalRow.append(newRow)
    finalRow
})
.select(explode($"value").as("valueArray"))
.withColumn("label", $"valueArray.label")
.withColumn("prediction", $"valueArray.prediction")
.withColumn("labelIndex", $"valueArray.labelIndex")
.withColumn("predictionIndex", $"valueArray.predictionIndex")
.drop("value", "valueArray")

finalPredictionDF.count
finalPredictionDF.show(2, false)
finalPredictionDF.printSchema  
+-----+----------+----------+---------------+
|label|prediction|labelIndex|predictionIndex|
+-----+----------+----------+---------------+
|ADP  |ADP       |2.0       |2.0            |
|PRON |PRON      |11.0      |11.0           |
+-----+----------+----------+---------------+
only showing top 2 rows

root
 |-- label: string (nullable = true)
 |-- prediction: string (nullable = true)
 |-- labelIndex: double (nullable = true)
 |-- predictionIndex: double (nullable = true)

defined class NewColumns
finalPredictionDF: org.apache.spark.sql.DataFrame = [label: string, prediction: string ... 2 more fields]

Spark MulticlassMetrics

// convert our DataFrame to an RDD of (prediction, label) pairs with Double values
val predictionLabelsRDD = finalPredictionDF.select("predictionIndex", "labelIndex").map(r => (r.getDouble(0), r.getDouble(1)))
val metrics = new MulticlassMetrics(predictionLabelsRDD.rdd)

predictionLabelsRDD: org.apache.spark.sql.Dataset[(Double, Double)] = [_1: double, _2: double]
metrics: org.apache.spark.mllib.evaluation.MulticlassMetrics = org.apache.spark.mllib.evaluation.MulticlassMetrics@17fc00fd
// we use this to convert the index back to its label
val labelsMap: List[(String, Double)] = List(
        ("ADJ",1.0), ("ADP",2.0), ("ADV",3.0),
        ("AUX",4.0), ("CCONJ",5.0), ("DET",6.0),
        ("INTJ",7.0), ("NOUN",8.0), ("NUM",9.0),
        ("PART",10.0), ("PRON",11.0), ("PROPN",12.0),
        ("PUNCT",13.0), ("SCONJ",14.0), ("SYM",15.0),
        ("VERB",16.0), ("X",17.0)
)
labelsMap: List[(String, Double)] = List((ADJ,1.0), (ADP,2.0), (ADV,3.0), (AUX,4.0), (CCONJ,5.0), (DET,6.0), (INTJ,7.0), (NOUN,8.0), (NUM,9.0), (PART,10.0), (PRON,11.0), (PROPN,12.0), (PUNCT,13.0), (SCONJ,14.0), (SYM,15.0), (VERB,16.0), (X,17.0))
// Overall Statistics
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")

Summary Statistics
Accuracy = 0.9565663825854811
accuracy: Double = 0.9565663825854811

Now let's see Precision, Recall and F1 metrics for each label

val labels = metrics.labels
labels: Array[Double] = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0)
// Precision by label
labels.foreach { l =>
  val labelStr = labelsMap.find(_._2 == l).map(_._1).getOrElse("Not defined").asInstanceOf[String]
  println(s"Precision($labelStr) = " + metrics.precision(l))
}
Precision(ADJ) = 0.9038142620232172
Precision(ADP) = 0.9786524349566378
Precision(ADV) = 0.9468503937007874
Precision(AUX) = 0.94579945799458
Precision(CCONJ) = 1.0
Precision(DET) = 0.9614754098360656
Precision(INTJ) = 0.5
Precision(NOUN) = 0.9575625680087051
Precision(NUM) = 0.9523809523809523
Precision(PART) = 1.0
Precision(PRON) = 0.9733333333333334
Precision(PROPN) = 0.8623188405797102
Precision(PUNCT) = 0.9949664429530202
Precision(SCONJ) = 0.9173553719008265
Precision(SYM) = 1.0
Precision(VERB) = 0.952020202020202
Precision(X) = 0.3333333333333333
// Recall by label
labels.foreach { l =>
  val labelStr = labelsMap.find(_._2 == l).map(_._1).getOrElse("Not defined").asInstanceOf[String]
  println(s"Recall($labelStr) = " + metrics.recall(l))
}
Recall(ADJ) = 0.902317880794702
Recall(ADP) = 0.9925575101488497
Recall(ADV) = 0.939453125
Recall(AUX) = 0.9775910364145658
Recall(CCONJ) = 1.0
Recall(DET) = 0.9750623441396509
Recall(INTJ) = 0.6666666666666666
Recall(NOUN) = 0.9534127843986999
Recall(NUM) = 0.9606986899563319
Recall(PART) = 0.5714285714285714
Recall(PRON) = 0.930783242258652
Recall(PROPN) = 0.9242718446601942
Recall(PUNCT) = 1.0
Recall(SCONJ) = 0.8740157480314961
Recall(SYM) = 0.7631578947368421
Recall(VERB) = 0.9161603888213852
Recall(X) = 0.07142857142857142
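Since F1 is mentioned above as well, here is a short sketch (not part of the original run, so no output is shown) of how the per-label F1 and the weighted summary statistics could be obtained from the same metrics object:

// F1 by label (harmonic mean of the precision and recall printed above)
labels.foreach { l =>
  val labelStr = labelsMap.find(_._2 == l).map(_._1).getOrElse("Not defined")
  println(s"F1($labelStr) = " + metrics.fMeasure(l))
}

// Weighted summary statistics across all labels
println(s"Weighted precision = ${metrics.weightedPrecision}")
println(s"Weighted recall = ${metrics.weightedRecall}")
println(s"Weighted F1 = ${metrics.weightedFMeasure}")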