Spark SQL的Dataset基本操作
https://www.toutiao.com/i6631318012546793992/
版本:Spark 2.3.1,hadoop-2.7.4,jdk1.8为例讲解
基本简介和准备工作
Dataset接口是在spark 1.6引入的,受益于RDD(强类型,可以使用强大的lambda函数),同时也可以享受Spark SQL优化执行引擎的优点。Dataset的可以从jvm 对象创建,然后就可以使用转换函数(map,flatmap,filter等)。
1.6版本之前常用的Dataframe是一种特殊的Dataset,也即是1
type DataFrame = Dataset[Row]
结构化数据文件(json,csv,orc等),hive表,外部数据库,已有RDD。
下面进行测试,大家可能都会说缺少数据,实际上Spark源码里跟我们提供了丰富的测试数据。源码的examples路径下:examples/src/main/resources。
首先要创建一个SparkSession:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19val sparkConf = new SparkConf()
.setAppName(this.getClass.getName)
.setMaster("local[*]")
.set("yarn.resourcemanager.hostname", "localhost")
// executor的实例数
.set("spark.executor.instances","2")
.set("spark.default.parallelism","4")
// sql shuffle的并行度,由于是本地测试,所以设置较小值,避免产生过多空task,实际上要根据生产数据量进行设置。
.set("spark.sql.shuffle.partitions","4")
.setJars(List("/Users/meitu/Desktop/sparkjar/bigdata.jar"
,"/opt/jars/spark-streaming-kafka-0-10_2.11-2.3.1.jar"
,"/opt/jars/kafka-clients-0.10.2.2.jar"
,"/opt/jars/kafka_2.11-0.10.2.2.jar"
))
val spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate()
创建dataset1
2
3
4
5
6val sales = spark.createDataFrame(
Seq( ("Warsaw", 2016, 100), ("Warsaw", 2017, 200), ("Warsaw", 2015, 100), ("Warsaw", 2017, 200), ("Beijing", 2017, 200), ("Beijing", 2016, 200),
("Beijing", 2015, 200), ("Beijing", 2014, 200), ("Warsaw", 2014, 200),
("Boston", 2017, 50), ("Boston", 2016, 50), ("Boston", 2015, 50),
("Boston", 2014, 150)))
.toDF("city", "year", "amount")
使用函数的时候要导入包:1
import org.apache.spark.sql.functions.{col,expr}
select
列名称可以是字符串,这种形式无法对列名称使用表达式进行逻辑操作。
使用col函数,可以直接对列进行一些逻辑操作。1
2sales.select("city","year","amount").show(1)
sales.select(col("city"),col("amount")+1).show(1)
selectExpr
参数是字符串,且直接可以使用表达式。
也可以使用select+expr函数来替代。1
2sales.selectExpr("city","year as date","amount+1").show(10)
sales.select(expr("city"),expr("year as date"),expr("amount+1")).show(10)
filter
参数可以是与col结合的表达式,参数类型为row返回值为boolean的函数,字符串表达式。1
2
3sales.filter(col("amount")>150).show()
sales.filter(row=>{ row.getInt(2)>150}).show(10)
sales.filter("amount > 150 ").show(10)
where
类似于fliter,参数可以是与col函数结合的表达式也可以是直接使用表达式字符串。1
2sales.where(col("amount")>150).show()
sales.where("amount > 150 ").show()
group by
主要是以count和agg聚合函数为例讲解groupby函数。1
2sales.groupBy("city").count().show(10)
sales.groupBy(col("city")).agg(sum("amount").as("total")).show(10)
union
两个dataset的union操作这个等价于union all操作,所以要实现传统数据库的union操作,需要在其后使用distinct进行去重操作。1
sales.union(sales).groupBy("city").count().show()
join
join操作相对比较复杂,具体如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30// 相同的列进行join
sales.join(sales,"city").show(10)
// 多列join
sales.join(sales,Seq("city","year")).show()
/* 指定join类型, join 类型可以选择:
`inner`, `cross`, `outer`, `full`, `full_outer`,
`left`, `left_outer`, `right`, `right_outer`,
`left_semi`, `left_anti`.
*/
// 内部join
sales.join(sales,Seq("city","year"),"inner").show()
/* join条件 :
可以在join方法里放入join条件,
也可以使用where,这两种情况都要求字段名称不一样。
*/
sales.join(sales, col("city").alias("city1") === col("city")).show() sales.join(sales).where(col("city").alias("city1") === col("city")).show()
/*
dataset的self join 此处使用where作为条件,
需要增加配置.set("spark.sql.crossJoin.enabled","true")
也可以加第三个参数,join类型,可以选择如下:
`inner`, `cross`, `outer`, `full`, `full_outer`, `left`,
`left_outer`, `right`, `right_outer`, `left_semi`, `left_anti`
*/
sales.join(sales,sales("city") === sales("city")).show() sales.join(sales).where(sales("city") === sales("city")).show()
/* joinwith,可以指定第三个参数,join类型,
类型可以选择如下:
`inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`,
`right`, `right_outer`。
*/
sales.joinWith(sales,sales("city") === sales("city"),"inner").show()
输出结果:
order by
orderby 全局有序,其实用的还是sort1
2sales.orderBy(col("year").desc,col("amount").asc).show()
sales.orderBy("city","year").show()
sort
全局排序,直接替换掉8小结的orderby即可。
sortwithinpartition
在分区内部进行排序,局部排序。1
2sales.sortWithinPartitions(col("year").desc,col("amount").asc).show()
sales.sortWithinPartitions("city","year").show()
可以看到,city为背景的应该是分配到不同的分区,然后每个分区内部year都是有序的。
withColumn
1 | /* withColumn |
foreach
这个跟rdd的foreach一样,元素类型是row。1
2
3sales.foreach(row=>{
println(row.getString(0))
})
foreachPartition
跟RDD的foreachPartition一样,针对分区进行计算,对于输出到数据库,kafka等数据相对于使用foreach可以大量减少连接数。1
2
3
4
5
6
7
8sales.foreachPartition(partition=>{
//打开数据库链接等
partition.foreach(each=>{
println(each.getString(0))
//插入数据库
})
//关闭数据库链接
})
distinct
针对dataset的行去重,返回的是所有行都不重复的dataset。1
sales.distinct().show(10)
dropDuplicates
这个适用于dataset有唯一的主键,然后对主键进行去重。1
2
3
4val before = sales.count()
val after = sales.dropDuplicates("city").count()
println("before ====> " +before)
println("after ====> "+after)
drop
删除一列,或者多列,这是一个变参数算子。1
2
3
4
5sales.drop("city").show()
打印出来schema信息如下:
root
|-- year: integer (nullable = false)
|-- amount: integer (nullable = false)
printSchema
输出dataset的schema信息1
2
3
4
5
6
7
8
9
10
11sales.printSchema()
输出结果如下:
root
|-- city: string (nullable = true)
|-- year: integer (nullable = false)
|-- amount: integer (nullable = false)
explain()
打印执行计划,这个便于调试,了解spark sql引擎的优化执行的整个过程1
2
3
4
5
6sales.orderBy(col("year").desc,col("amount").asc).explain()
执行计划输出如下:
== Physical Plan ==
*(1) Sort [year#7 DESC NULLS LAST, amount#8 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(year#7 DESC NULLS LAST, amount#8 ASC NULLS FIRST, 3)
+- LocalTableScan [city#6, year#7, amount#8]
Spark SQL入门到精通之第二篇Dataset的复杂操作
本文是Spark SQL入门到精通系列第二弹,数据仓库常用的操作:
cube,rollup,pivot操作。
cube
简单的理解就是维度及度量组成的数据体。1
2
3
4sales.cube("city","year")
.agg(sum("amount"))
.sort(col("city").desc_nulls_first,col("year").desc_nulls_first)
.show()
举个简单的例子,上面的纬度city,year两个列是纬度,然后amount是要进行聚合的度量。
实际上就相当于,(year,city),(year),(city),() 分别分组然后对amount求sum,最终输出结果,代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15val city_year = sales
.groupBy("city","year").
agg(sum("amount"))
val city = sales.groupBy("city")
.agg(sum("amount") as "amount")
.select(col("city"), lit(null) as "year", col("amount"))
val year = sales.groupBy("year")
.agg(sum("amount") as "amount")
.select( lit(null) as "city",col("year"), col("amount"))
val none = sales .groupBy()
.agg(sum("amount") as "amount")
.select(lit(null) as "city", lit(null) as "year", col("amount"))
city_year.union(city).union(year).union(none)
.sort(desc_nulls_first("city"), desc_nulls_first("year"))
.show()
rollup
这里也是以案例开始,代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13val expenses = spark.createDataFrame(Seq(
((2012, Month.DECEMBER, 12), 5),
((2016, Month.AUGUST, 13), 10),
((2017, Month.MAY, 27), 15))
.map { case ((yy, mm, dd), a) => (LocalDate.of(yy, mm, dd), a) }
.map { case (d, a) => (d.toString, a) }
.map { case (d, a) => (Date.valueOf(d), a) }).toDF("date", "amount")
// rollup time!
val res = expenses
.rollup(year(col("date")) as "year", month(col("date")) as "month")
.agg(sum("amount") as "amount")
.sort(col("year").asc_nulls_last, col("month").asc_nulls_last)
.show()
这个等价于分别对(year,month),(year),()进行 groupby 对amount求sum,然后再进行union操作,也即是可以按照下面的实现:1
2
3
4
5
6
7
8
9
10
11
12val year_month = expenses
.groupBy(year(col("date")) as "year", month(col("date")) as "month")
.agg(sum("amount") as "amount")
val yearOnly = expenses.groupBy(year(col("date")) as "year")
.agg(sum("amount") as "amount")
.select(col("year"), lit(null) as "month", col("amount"))
val none = expenses.groupBy()
.agg(sum("amount") as "amount")
.select(lit(null) as "year", lit(null) as "month", col("amount"))
year_month.union(yearOnly).union(none)
.sort(col("year").asc_nulls_last, col("month").asc_nulls_last)
.show()
pivot
旋转操作
https://mp.weixin.qq.com/s/Jky2q3FW5wqQ-prpsW2OEw
Pivot 算子是 spark 1.6 版本开始引入的,在 spark2.4版本中功能做了增强,还是比较强大的,做过数据清洗ETL工作的都知道,行列转换是一个常见的数据整理需求。spark 中的Pivot 可以根据枢轴点(Pivot Point) 把多行的值归并到一行数据的不同列,这个估计不太好理解,我们下面使用例子说明,看看pivot 这个算子在处理复杂数据时候的威力。1
2
3
4
5
6
7
8
9
10
11
12
13val sales = spark.createDataFrame(Seq(
("Warsaw", 2016, 100,"Warsaw"),
("Warsaw", 2017, 200,"Warsaw"),
("Warsaw", 2016, 100,"Warsaw"),
("Warsaw", 2017, 200,"Warsaw"),
("Boston", 2015, 50,"Boston"),
("Boston", 2016, 150,"Boston"),
("Toronto", 2017, 50,"Toronto")))
.toDF("city", "year", "amount","test")
sales.groupBy("year")
.pivot("city",Seq("Warsaw","Boston","Toronto"))
.agg(sum("amount") as "amount")
.show()
思路就是首先对year分组,然后旋转city字段,只取:
“Warsaw”,”Boston”,”Toronto”
然后对amount进行聚合操作
案例:
使用Pivot 来统计天气走势。
下面是西雅图的天气数据表,每行代表一天的天气最高值:1
2
3
4
5
6
7
8
9
10 Date Temp (°F)
07-22-2018 86
07-23-2018 90
07-24-2018 91
07-25-2018 92
07-26-2018 92
07-27-2018 88
07-28-2018 85
07-29-2018 94
07-30-2018 89
如果我们想看下最近几年的天气走势,如果这样一天一行数据,是很难看出趋势来的,最直观的方式是 按照年来分行,然后每一列代表一个月的平均天气,这样一行数据,就可以看到这一年12个月的一个天气走势,下面我们使用 pivot 来构造这样一个查询结果:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23SELECT
*
FROM
(
SELECT
year(date) year,
month(date) month,
temp
FROM
high_temps
WHERE
date between DATE '2015-01-01'
and DATE '2018-08-31'
) PIVOT (
CAST (
AVG(temp) as DECIMAL(4, 1)
) FOR month in (
1 JAN, 2 FEB, 3 MAR, 4 APR, 5 MAY, 6 JUN, 7 JUL,
8 AUG, 9 SEP, 10 OCT, 11 NOV, 12 DEC
)
)
ORDER BY
year DESC;
结果如下图:
是不是很直观,第一行就代表 2018 年,从1月到12月的平均天气,能看出一年的天气走势。
我们来看下这个 sql 是怎么玩的,首先是一个 子查询语句,我们最终关心的是 年份,月份,和最高天气值,所以先使用子查询对原始数据进行处理,从日期里面抽取出来年份和月份。
下面在子查询的结果上使用 pivot 语句,pivot 第一个参数是一个聚合语句,这个代表聚合出来一个月30天的一个平均气温,第二个参数是 FOR month,这个是指定以哪个列为枢轴列,第三个 In 子语句指定我们需要进行行列转换的具体的 枢轴点(Pivot Point)的值,上面的例子中 1到12月份都包含了,而且给了一个别名,如果只指定 1到6月份,结果就如下了:
上面sql语句里面有个特别的点需要注意, 就是聚合的时候有个隐含的维度字段,就是 年份,按理来讲,我们没有写 group-by year, 为啥结果表里面不同年份区分在了不同的行,原因是,FORM 子查询出来的每行有 3个列, year,month,tmp,如果一个列既不出现在进行聚合计算的列中(temp 是聚合计算的列), 也不作为枢轴列 , 就会作为聚合的时候一个隐含的维度。我们的例子中算平均值聚合操作的维度是 (year, month),一个是隐含维度,一个是 枢轴列维度, 这一点一定要注意,如果不需要按照 year 来区分,FORM 查询的时候就不要加上这个列。
指定多个聚合语句
上文中只有一个聚合语句,就是计算平均天气,其实是可以加多个聚合语句的,比如我们需要看到 7,8,9 月份每个月的最大气温和平均气温,就可以用以下SQL语句。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21SELECT
*
FROM
(
SELECT
year(date) year,
month(date) month,
temp
FROM
high_temps
WHERE
date between DATE '2015-01-01'
and DATE '2018-08-31'
) PIVOT (
CAST (
AVG(temp) as DECIMAL(4, 1)
) avg,
max(temp) max FOR month in (6 JUN, 7 JUL, 8 AUG, 9 SEP)
)
ORDER BY
year DESC;
上文中指定了两个聚合语句,查询后, 枢轴点(Pivot Point) 和 聚合语句 的笛卡尔积作为结果的不同列,也就是
聚合列(Grouping Columns)和 枢轴列(Pivot Columns)的不同之处
现在假如我们有西雅图每天的最低温数据,我们需要把最高温和最低温放在同一张表里面看对比着看.1
2
3
4
5
6
7
8
9Date Temp (°F)
… …
08-01-2018 59
08-02-2018 58
08-03-2018 59
08-04-2018 58
08-05-2018 59
08-06-2018 59
… …
我们使用 UNION ALL 把两张表做一个合并:1
2
3
4
5
6
7SELECT date, temp, 'H' as flag
FROM
high_temps
UNION ALL
SELECT date, temp, 'L' as flag
FROM
low_temps;
现在使用 pivot 来进行处理:1
2
3
4
5
6
7
8
9
10
11
12SELECT * FROM (
SELECT date,temp,'H' as flag
FROM high_temps
UNION ALL
SELECT date,temp,'L' as flag
FROM low_temps
)
WHERE date BETWEEN DATE '2015-01-01' AND DATE '2018-08-31'
PIVOT(
CAST(avg(temp) as DECIMAL(4,1))
FOR month in (6 JUN,7 JUL,8 AUG , 9 SEP)
) ORDER BY year DESC ,`H/L` ASC;
我们统计了 4年中 7,8,9 月份最低温和最高温的平均值,这里要注意的是,我们把 year 和 一个最低最高的标记(H/L)都作为隐含维度,不然算出来的就是最低最高温度在一起的平均值了。结果如下图:
上面的查询中,我们把最低最高的标记(H/L)也作为了一个隐含维度,group-by 的维度就变成了 (year, H/L, month), 但是year 和 H/L 体现在了不同行上面,month 体现在了不同的列上面,这就是 聚合列(Grouping Columns)和 枢轴列(Pivot Columns)的不同之处。
再接再厉,我们把 H/L 作为 枢轴列(Pivot Columns) 进行查询:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17SELECT * FROM (
SELECT date,temp,'H' as flag
FROM high_temps
UNION ALL
SELECT date,temp,'L' as flag
FROM low_temps
)
WHERE date BETWEEN DATE '2015-01-01' AND DATE '2018-08-31'
PIVOT(
CAST(avg(temp) as DECIMAL(4,1))
FOR (month,flag) in (
(6,'H') JUN_hi,(6,'L') JUN_lo,
(7,'H') JUl_hi,(7,'L') JUl_lo,
(8,'H') AUG_hi,(8,'L') AUG_lo,
(9,'H') SEP_hi,(9,'L') SEP_lo
)
) ORDER BY year DESC;
结果的展现方式就和上面的不同了,虽然每个单元格值都是相同的,但是把 H/L 和 month 笛卡尔积作为列,H/L 维度体现在了列上面了:
源码:Spark SQL 分区特性第一弹
常见RDD分区
Spark Core 中的RDD的分区特性大家估计都很了解,这里说的分区特性是指从数据源读取数据的第一个RDD或者Dataset的分区,而后续再介绍转换过程中分区的变化。
举几个浪尖在星球里分享比较多的例子,比如:
Spark Streaming 与kafka 结合 DirectDstream 生成的微批RDD(kafkardd)分区数和kafka分区数一样。
Spark Streaming 与kafka结合 基于receiver的方式,生成的微批RDD(blockRDD),分区数就是block数。
普通的文件RDD,那么分可分割和不可分割,通常不可分割的分区数就是文件数。可分割需要计算而且是有条件的,在星球里分享过了。
这些都很简单,那么今天咱们要谈的是Spark DataSet的分区数的决定因素。
准备数据
首先是由Seq数据集合生成一个Dataset1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26val sales = spark.createDataFrame(Seq(
("Warsaw", 2016, 110),
("Warsaw", 2017, 10),
("Warsaw", 2015, 100),
("Warsaw", 2015, 50),
("Warsaw", 2015, 80),
("Warsaw", 2015, 100),
("Warsaw", 2015, 130),
("Warsaw", 2015, 160),
("Warsaw", 2017, 200),
("Beijing", 2017, 100),
("Beijing", 2016, 150),
("Beijing", 2015, 50),
("Beijing", 2015, 30),
("Beijing", 2015, 10),
("Beijing", 2014, 200),
("Beijing", 2014, 170),
("Boston", 2017, 50),
("Boston", 2017, 70),
("Boston", 2017, 110),
("Boston", 2017, 150),
("Boston", 2017, 180),
("Boston", 2016, 30),
("Boston", 2015, 200),
("Boston", 2014, 20)
)).toDF("city", "year", "amount")
将Dataset存处为partquet格式的hive表,分两种情况:
用city和year字段分区1
sales.write.partitionBy("city","year").mode(SaveMode.Overwrite).saveAsTable("ParquetTestCityAndYear")
用city字段分区1
sales.write.partitionBy("city").mode(SaveMode.Overwrite).saveAsTable("ParquetTestCity")
读取数据采用的是1
val res = spark.read.parquet("/user/hive/warehouse/parquettestcity")
直接展示,结果发现结果会随着spark.default.parallelism变化而变化。文章里只读取city字段分区的数据,特点就是只有单个分区字段。
1. spark.default.parallelism =40
Dataset的分区数是由参数:
目录数和生成的FileScanRDD的分区数分别数下面截图的第一行和第二行
这个分区数目正好是文件数,那么假如不了解细节的话,肯定会认为分区数就是由文件数决定的,其实不然。
2. spark.default.parallelism =4
Dataset的分区数是由参数:1
println("partition size = "+res.rdd.partitions.length)
目录数和生成的FileScanRDD的分区数分别数下面截图的第一行和第二行。
那么数据源生成的Dataset的分区数到底是如何决定的呢?
我们这种情况,我只能告诉你是由下面的函数在生成FileScanRDD的时候计算得到的,具体计算细节可以仔细阅读该函数。该函数是类FileSourceScanExec的方法。
那么数据源生成的Dataset的分区数到底是如何决定的呢?
我们这种情况,我只能告诉你是由下面的函数在生成FileScanRDD的时候计算得到的,具体计算细节可以仔细阅读该函数。该函数是类FileSourceScanExec的方法。
那么数据源生成的Dataset的分区数到底是如何决定的呢?
我们这种情况,我只能告诉你是由下面的函数在生成FileScanRDD的时候计算得到的,具体计算细节可以仔细阅读该函数。该函数是类FileSourceScanExec的方法。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91private def createNonBucketedReadRDD(
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Seq[PartitionDirectory],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
/*
selectedPartitions 的大小代表目录数目
*/
println("selectedPartitions.size : "+ selectedPartitions.size)
val defaultMaxSplitBytes =
fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
// spark.default.parallelism
val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism
// 计算文件总大小,单位字节数
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
//计算平均每个并行度读取数据大小
val bytesPerCore = totalBytes / defaultParallelism
// 首先spark.sql.files.openCostInBytes 该参数配置的值和bytesPerCore 取最大值
// 然后,比较spark.sql.files.maxPartitionBytes 取小者
val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
s"open cost is considered as scanning $openCostInBytes bytes.")
// 这对目录遍历
val splitFiles = selectedPartitions.flatMap { partition =>
partition.files.flatMap { file =>
val blockLocations = getBlockLocations(file)
//判断文件类型是否支持分割,以parquet为例,是支持分割的
if (fsRelation.fileFormat.isSplitable(
fsRelation.sparkSession, fsRelation.options, file.getPath)) {
// eg. 0 until 2不包括 2。相当于
// println(0 until(10) by 3) 输出 Range(0, 3, 6, 9)
(0L until file.getLen by maxSplitBytes).map { offset =>
// 计算文件剩余的量
val remaining = file.getLen - offset
// 假如剩余量不足 maxSplitBytes 那么就剩余的作为一个分区
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
// 位置信息
val hosts = getBlockHosts(blockLocations, offset, size)
PartitionedFile(
partition.values, file.getPath.toUri.toString, offset, size, hosts)
}
} else {
// 不可分割的话,那即是一个文件一个分区
val hosts = getBlockHosts(blockLocations, 0, file.getLen)
Seq(PartitionedFile(
partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
}
}
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
val partitions = new ArrayBuffer[FilePartition]
val currentFiles = new ArrayBuffer[PartitionedFile]
var currentSize = 0L
/** Close the current partition and move to the next. */
def closePartition(): Unit = {
if (currentFiles.nonEmpty) {
val newPartition =
FilePartition(
partitions.size,
currentFiles.toArray.toSeq) // Copy to a new Array.
partitions += newPartition
}
currentFiles.clear()
currentSize = 0
}
// Assign files to partitions using "Next Fit Decreasing"
splitFiles.foreach { file =>
if (currentSize + file.length > maxSplitBytes) {
closePartition()
}
// Add the given file to the current partition.
currentSize += file.length + openCostInBytes
currentFiles += file
}
closePartition()
println("FileScanRDD partitions size : "+partitions.size)
new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
}
找到一列中的中位数
https://spark.apache.org/docs/latest/api/sql/index.html
df函数: approxQuantile
sql函数: percentile_approx
自定义数据源
ServiceLoader
https://mp.weixin.qq.com/s/QpYvqJpw7TnFAY8rD6Jf3w
ServiceLoader是SPI的是一种实现,所谓SPI,即Service Provider Interface,用于一些服务提供给第三方实现或者扩展,可以增强框架的扩展或者替换一些组件。
要配置在相关项目的固定目录下:
resources/META-INF/services/接口全称。
这个在大数据的应用中颇为广泛,比如Spark2.3.1 的集群管理器插入:
SparkContext类1
2
3
4
5
6
7
8
9
10
11 private def getClusterManager(url: String): Option[ExternalClusterManager] = {
val loader = Utils.getContextOrSparkClassLoader
val serviceLoaders =
ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url))
if (serviceLoaders.size > 1) {
throw new SparkException(
s"Multiple external cluster managers registered for the url $url: $serviceLoaders")
}
serviceLoaders.headOption
}
}
配置是在
spark sql数据源的接入,新增数据源插入的时候可以采用这种方式,要实现的接口是DataSourceRegister。
简单测试
首先实现一个接口1
2
3
4
5
6
7package bigdata.spark.services;
public interface DoSomething {
//可以制定实现类名加载
public String shortName();
public void doSomeThing();
}
然后将接口配置在resources/META-INF/services/
bigdata.spark.services.DoSomething文件
内容:
实现该接口1
2
3
4
5
6
7
8
9
10
11
12
13package bigdata.spark.services;
public class SayHello implements DoSomething {
public String shortName() {
return "SayHello";
}
public void doSomeThing() {
System.out.println("hello !!!");
}
}
测试1
2
3
4
5
6
7
8
9
10
11
12
13package bigdata.spark.services;
import java.util.ServiceLoader;
public class test {
static ServiceLoader<DoSomething> loader = ServiceLoader.load(DoSomething.class);
public static void main(String[] args){
for(DoSomething sayhello : loader){
//要加载的类名称我们可以制定
if(sayhello.shortName().equalsIgnoreCase("SayHello")){
sayhello.doSomeThing();
}
}
}
}
这个主要是为讲自定义数据源作准备。
https://articles.zsxq.com/id_702s32f46zet.html
首先要搞明白spark是如何支持多数据源的,昨天说了是通过serverloader加载的。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
// 指定路径加载,默认加载类名是DefaultSource
val provider2 = s"$provider1.DefaultSource"
val loader = Utils.getContextOrSparkClassLoader
// ServiceLoader加载
val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
try {
serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
// the provider format did not match any given registered aliases
case Nil =>
try {
Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
case Success(dataSource) =>
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
if (provider1.toLowerCase == "orc" ||
provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw new AnalysisException(
"The ORC data source must be used with Hive support enabled")
} else if (provider1.toLowerCase == "avro" ||
provider1 == "com.databricks.spark.avro") {
throw new AnalysisException(
s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " +
"package at http://spark.apache.org/third-party-projects.html")
} else {
throw new ClassNotFoundException(
s"Failed to find data source: $provider1. Please find packages at " +
"http://spark.apache.org/third-party-projects.html",
error)
}
}
} catch {
case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal
// NoClassDefFoundError's class name uses "/" rather than "." for packages
val className = e.getMessage.replaceAll("/", ".")
if (spark2RemovedClasses.contains(className)) {
throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " +
"Please check if your library is compatible with Spark 2.0", e)
} else {
throw e
}
}
case head :: Nil =>
// there is exactly one registered alias
head.getClass
case sources =>
// There are multiple registered aliases for the input
sys.error(s"Multiple sources found for $provider1 " +
s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
"please specify the fully qualified class name.")
}
} catch {
case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] =>
// NoClassDefFoundError's class name uses "/" rather than "." for packages
val className = e.getCause.getMessage.replaceAll("/", ".")
if (spark2RemovedClasses.contains(className)) {
throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " +
"Please remove the incompatible library from classpath or upgrade it. " +
s"Error: ${e.getMessage}", e)
} else {
throw e
}
}
}
其实,从这个点,你可以思考一下,自己能学到多少东西:类加载,可扩展的编程思路。
主要思路是:
实现DefaultSource。
实现工厂类。
实现具体的数据加载类。
首先,主要是有三个实现吧,需要反射加载的类DefaultSource,这个名字很固定的:1
2
3
4
5
6
7
8package bigdata.spark.SparkSQL.DataSources
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
class DefaultSource extends DataSourceV2 with ReadSupport {
def createReader(options: DataSourceOptions) = new SimpleDataSourceReader()
}
然后是,要实现DataSourceReader,负责创建阅读器工厂:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16package bigdata.spark.SparkSQL.DataSources
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
class SimpleDataSourceReader extends DataSourceReader {
def readSchema() = StructType(Array(StructField("value", StringType)))
def createDataReaderFactories = {
val factoryList = new java.util.ArrayList[DataReaderFactory[Row]]
factoryList.add(new SimpleDataSourceReaderFactory())
factoryList
}
}
数据源的具体实现类:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22package bigdata.spark.SparkSQL.DataSources
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
class SimpleDataSourceReaderFactory extends
DataReaderFactory[Row] with DataReader[Row] {
def createDataReader = new SimpleDataSourceReaderFactory()
val values = Array("1", "2", "3", "4", "5")
var index = 0
def next = index < values.length
def get = {
val row = Row(values(index))
index = index + 1
row
}
def close() = Unit
}
使用我们默认的数据源:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28package bigdata.spark.SparkSQL.DataSources
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object App {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setAppName(this.getClass.getName).setMaster("local[*]")
.set("yarn.resourcemanager.hostname", "mt-mdh.local")
.set("spark.executor.instances","2")
.setJars(List("/opt/sparkjar/bigdata.jar"
,"/opt/jars/spark-streaming-kafka-0-10_2.11-2.3.1.jar"
,"/opt/jars/kafka-clients-0.10.2.2.jar"
,"/opt/jars/kafka_2.11-0.10.2.2.jar"))
val spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate()
val simpleDf = spark.read
.format("bigdata.spark.SparkSQL.DataSources")
.load()
simpleDf.show()
spark.stop()
}
}
format里面指定的是包路径,然后加载的时候会加上默认类名:DefaultSource。
删除多个字段
1 | def dropColumns(columns: Seq[String]):DataFAME={ |
spark sql 窗口函数
https://mp.weixin.qq.com/s/A5CiLWPdg1nkjuNImetaAw
最近理了下spark sql 中窗口函数的知识,打算开几篇文章讲讲,窗口函数有很多应用场景,比如说炒股的时候有个5日移动均线,或者让你对一个公司所有销售所有部门按照销售业绩进行排名等等,都是要用到窗口函数,在spark sql中,窗口函数和聚合函数的区别,之前文章中也提到过,就是聚合函数是按照你聚合的维度,每个分组中算出来一个聚合值,而窗口函数是对每一行,都根据当前窗口(5日均线就是今日往前的5天组成的窗口),都聚合出一个值。
1 开胃菜,spark sql 窗口函数的基本概念和使用姿势
2 spark sql 中窗口函数深入理解
3 不是UDF,也不是UDAF,教你自定义一个窗口函数(UDWF)
4 从 spark sql 源码层面理解窗口函数
什么是简单移动平均值
简单移动平均(英语:Simple Moving Average,SMA)是某变数之前n个数值的未作加权算术平均。例如,收市价的10日简单移动平均指之前10日收市价的平均数。
例子1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31val spark = SparkSession
.builder()
.master("local")
.appName("DateFrameFromJsonScala")
.config("spark.some.config.option", "some-value")
.getOrCreate()
import spark.implicits._
val df = List(
("站点1", "201902025", 50),
("站点1", "201902026", 90),
("站点1", "201902026", 100),
("站点2", "201902027", 70),
("站点2", "201902028", 60),
("站点2", "201902029", 40))
.toDF("site", "date", "user_cnt")
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
/**
* 这个 window spec 中,数据根据用户(customer)来分去。
* 每一个用户数据根据时间排序。然后,窗口定义从 -1(前一行)到 1(后一行)
* ,每一个滑动的窗口总用有3行
*/
val wSpce = Window.partitionBy("site").orderBy("date").rowsBetween(-1, 1)
df.withColumn("movinAvg", avg("user_cnt").over(wSpce)).show()
spark.stop()
结果:
2~3~21
2
3
4
5
6
7
8
9
10+----+---------+--------+------------------+
|site| date|user_cnt| movinAvg|
+----+---------+--------+------------------+
| 站点1|201902025| 50| 70.0|
| 站点1|201902026| 90| 80.0|
| 站点1|201902026| 100| 95.0|
| 站点2|201902027| 70| 65.0|
| 站点2|201902028| 60|56.666666666666664|
| 站点2|201902029| 40| 50.0|
+----+---------+--------+------------------+
窗口函数和窗口特征定义
正如上述例子中,窗口函数主要包含两个部分:
指定窗口特征(wSpec):
“partitionyBY” 定义数据如何分组;在上面的例子中,是用户 site
“orderBy” 定义分组中的排序
“rowsBetween” 定义窗口的大小
指定窗口函数函数:
指定窗口函数函数,你可以使用 org.apache.spark.sql.functions 的“聚合函数(Aggregate Functions)”和”窗口函数(Window Functions)“类别下的函数.
累计汇总
1 |
|
结果1
2
3
4
5
6
7
8
9
10+----+---------+--------+------+
|site| date|user_cnt|cumsum|
+----+---------+--------+------+
| 站点1|201902025| 50| 140|
| 站点1|201902026| 90| 240|
| 站点1|201902026| 100| 240|
| 站点2|201902027| 70| 130|
| 站点2|201902028| 60| 170|
| 站点2|201902029| 40| 170|
+----+---------+--------+------+
前一行数据
1 | val spark = SparkSession |
结果1
2
3
4
5
6
7
8
9
10+----+---------+--------+----------+
|site| date|user_cnt|preUserCnt|
+----+---------+--------+----------+
| 站点1|201902025| 50| null|
| 站点1|201902026| 90| 50|
| 站点1|201902026| 100| 90|
| 站点2|201902027| 70| null|
| 站点2|201902028| 60| 70|
| 站点2|201902029| 40| 60|
+----+---------+--------+----------+
如果计算环比的时候,是不是特别有用啊?!
在介绍几个常用的行数:
first/last(): 提取这个分组特定排序的第一个最后一个,在获取用户退出的时候,你可能会用到
lag/lead(field, n): lead 就是 lag 相反的操作,这个用于做数据回测特别用,结果回推条件
排名
1 | val spark = SparkSession |
结果1
2
3
4
5
6
7
8
9
10+----+---------+--------+----+
|site| date|user_cnt|rank|
+----+---------+--------+----+
| 站点1|201902025| 50| 1|
| 站点1|201902026| 90| 2|
| 站点1|201902026| 100| 2|
| 站点2|201902027| 70| 1|
| 站点2|201902028| 60| 2|
| 站点2|201902029| 40| 3|
+----+---------+--------+----+
这个数据在提取每个分组的前n项时特别有用,省了不少麻烦。
自定义窗口函数
https://mp.weixin.qq.com/s/SMNX5lVPb0DWRf27QCcjIQ
https://github.com/zheniantoushipashi/spark-udwf-session
背景
在使用 spark sql 的时候,有时候默认提供的sql 函数可能满足不了需求,这时候可以自定义一些函数,可以自定义 UDF 或者UDAF。
UDF 在sql中只是简单的处理转换一些字段,类似默认的trim 函数把一个字符串类型的列的头尾空格去掉, UDAF函数不同于UDF,是在sql聚合语句中使用的函数,必须配合 GROUP BY 一同使用,类似默认提供的count,sum函数,但是还有一种自定义函数叫做 UDWF, 这种一般人就不知道了,这种叫做窗口自定义函数,不了解窗口函数的,可以参考 1 开胃菜,spark sql 窗口函数的基本概念和使用姿势 ,或者官方的介绍 https://databricks.com/blog/2015/07/15/introducing-window-functions-in-spark-sql.html
窗口函数是 SQL 中一类特别的函数。和聚合函数相似,窗口函数的输入也是多行记录。不同的是,聚合函数的作用于由 GROUP BY 子句聚合的组,而窗口函数则作用于一个窗口
这里怎么理解一个窗口呢,spark君在这里得好好的解释解释,一个窗口是怎么定义的,
窗口语句中,partition by用来指定分区的列,在同一个分区的行属于同一个窗口
order by用来指定窗口内的多行,如何排序
windowing_clause 用来指定开窗方式,在spark sql 中开窗方式有那么几种
一个分区中的所有行作为一个窗口:UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING(上下都没有边界),这种情况下,spark sql 会把所有行作为一个输入,进行一次求值
Growing frame:UNBOUNDED PRECEDING AND ….(上无边界), 这种就是不断的把当前行加入的窗口中,而不删除, 例子:.rowsBetween(Long.MinValue, 0) :窗口的大小是按照排序从最小值到当前行,在数据迭代过程中,不断的把当前行加入的窗口中。
Shrinking frame:… AND UNBOUNDED FOLLOWING(下无边界)和Growing frame 相反,窗口不断的把迭代到的当前行从窗口中删除掉。
Moving frame:滑动的窗口,举例:.rowsBetween(-1, 1) 就是指窗口定义从 -1(当前行前一行)到 1(当前行后一行) ,每一个滑动的窗口总用有3行
Offset frame: 窗口中只有一条数据,就是偏移当前行一定距离的那一行,举例:lag(field, n): 就是取从当前行往前的n个行的字段值
这里就针对窗口函数就介绍这么多,如果不懂请参考相关文档,加强理解,我们在平时使用 spark sql 的过程中,会发现有很多教你自定义 UDF 和 UDAF 的教程,却没有针对UDWF的教程,这是为啥呢,这是因为 UDF 和UDAF 都作为上层API暴露给用户了,使用scala很简单就可以写一个函数出来,但是UDWF没有对上层用户暴露,只能使用 Catalyst expressions. 也就是Catalyst框架底层的表达式语句才可以定义,如果没有对源码有很深入的研究,根本就搞不出来。spark 君在工作中写了一些UDWF的函数,但是都比较复杂,不太好单独抽出来作为一个简明的例子给大家讲解,挑一个网上的例子来进行说明,这个例子 spark君亲测可用。
窗口函数的使用场景
我们来举个实际例子来说明 窗口函数的使用场景,在网站的统计指标中,有一个概念叫做用户会话,什么叫做用户会话呢,我来说明一下,我们在网站服务端使用用户session来管理用户状态,过程如下
1) 服务端session是用户第一次访问应用时,服务器就会创建的对象,代表用户的一次会话过程,可以用来存放数据。服务器为每一个session都分配一个唯一的sessionid,以保证每个用户都有一个不同的session对象。
2)服务器在创建完session后,会把sessionid通过cookie返回给用户所在的浏览器,这样当用户第二次及以后向服务器发送请求的时候,就会通过cookie把sessionid传回给服务器,以便服务器能够根据sessionid找到与该用户对应的session对象。
3)session通常有失效时间的设定,比如1个小时。当失效时间到,服务器会销毁之前的session,并创建新的session返回给用户。但是只要用户在失效时间内,有发送新的请求给服务器,通常服务器都会把他对应的session的失效时间根据当前的请求时间再延长1个小时。
也就是说如果用户在1个超过一个小时不产生用户事件,当前会话就结束了,如果后续再产生用户事件,就当做新的用户会话,我们现在就使用spark sql 来统计用户的会话数,首先我们先加一个列,作为当前列的session,然后再 distinct 这个列就得到用户的会话数了,所以关键是怎么加上这个newSession 这个列,这种场景就很适合使用窗口函数来做统计,因为判断当前是否是一个新会话的依据,需要依赖当前行的前一行的时间戳和当前行的时间戳的间隔来判断,下面的表格可以帮助你理解这个概念,例子中有3列数据,用户,event字段代表用户访问了一个页面产生了一个用户事件,time字段代表访问页面的时间戳:
user | event | time | session | |
---|---|---|---|---|
user1 | page1 | 10:12 | session1(new session) | |
user1 | page2 | 10:20 | session1(same session,8 minutes from last event) | |
user1 | page1 | 11:13 | session1(same session,53 minutes from last event) | |
user1 | page3 | 14:12 | session1(new session,3 minutes from last event) |
上面只有一个用户,如果多个用户,可以使用 partition by 来进行分区。
深入研究
构造数据1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19case class UserActivityData(user:String, ts:Long, session:String)
val st = System.currentTimeMillis()
val one_minute = 60 * 1000
val d = Array[UserActivityData](
UserActivityData("user1", st, "f237e656-1e53-4a24-9ad5-2b4576a4125d"),
UserActivityData("user2", st + 5*one_minute, null),
UserActivityData("user1", st + 10*one_minute, null),
UserActivityData("user1", st + 15*one_minute, null),
UserActivityData("user2", st + 15*one_minute, null),
UserActivityData("user1", st + 140*one_minute, null),
UserActivityData("user1", st + 160*one_minute, null))
"a CustomWindowFunction" should "correctly create a session " in {
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(sc.parallelize(d))
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)
val res = df.withColumn( "newsession", MyUDWF.calculateSession(f.col("ts"), f.col("session")) over specs)
怎么使用 spark sql 来统计会话数目呢,因为不同用户产生的是不同的会话,首先使用user字段进行分区,然后按照时间戳进行排序,然后我们需要一个自定义函数来加一个列,这个列的值的逻辑如下:1
2
3
4IF (no previous event) create new session
ELSE (if cerrent event was past session window)
THEN create new session
ELSE use current session
运行结果如下:
我们使用 UUID 来作为会话id, 当后一行的时间戳和前一行的时间戳间隔大于1小时的时候,就创建一个新的会话id作为列值,否则使用老的会话id作为列值。
这种就涉及到状态,我们在内部需要维护的状态数据
当前的session ID
当前session的最后活动事件的时间戳
下面我们就看看这个怎样自定义 caculateSession 这个窗口函数,先自定义一个静态对象:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47import java.util.UUID
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, AttributeReference, Expression, If, IsNotNull, LessThanOrEqual, Literal, ScalaUDF, Subtract}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
object MyUDWF {
val defaultMaxSessionLengthms = 3600 * 1000
case class SessionUDWF(timestamp:Expression, session:Expression,
sessionWindow:Expression = Literal(defaultMaxSessionLengthms)) extends AggregateWindowFunction {
self: Product =>
override def children: Seq[Expression] = Seq(timestamp, session)
override def dataType: DataType = StringType
protected val zero = Literal( 0L )
protected val nullString = Literal(null:String)
protected val curentSession = AttributeReference("currentSession", StringType, nullable = true)()
protected val previousTs = AttributeReference("lastTs", LongType, nullable = false)()
override val aggBufferAttributes: Seq[AttributeReference] = curentSession :: previousTs :: Nil
protected val assignSession = If(LessThanOrEqual(Subtract(timestamp, aggBufferAttributes(1)), sessionWindow),
aggBufferAttributes(0), // if
ScalaUDF( createNewSession, StringType, children = Nil))
override val initialValues: Seq[Expression] = nullString :: zero :: Nil
override val updateExpressions: Seq[Expression] =
If(IsNotNull(session), session, assignSession) ::
timestamp ::
Nil
override val evaluateExpression: Expression = aggBufferAttributes(0)
override def prettyName: String = "makeSession"
}
protected val createNewSession = () => org.apache.spark.unsafe.types.UTF8String.fromString(UUID.randomUUID().toString)
def calculateSession(ts:Column,sess:Column): Column = withExpr { SessionUDWF(ts.expr,sess.expr, Literal(defaultMaxSessionLengthms)) }
def calculateSession(ts:Column,sess:Column, sessionWindow:Column): Column = withExpr { SessionUDWF(ts.expr,sess.expr, sessionWindow.expr) }
private def withExpr(expr: Expression): Column = new Column(expr)
}
代码说明:
状态保存在 Seq[AttributeReference]中
重写 initialValues方法进行初始化
updateExpressions 函数针对每一行数据都会调用
spark sql 在迭代处理每一行数据的时候,都会调用 updateExpressions 函数来处理,根据当后一行的时间戳和前一行的时间戳间隔大于1小时来进行不同的逻辑处理,如果不大于,就使用 aggBufferAttributes(0) 中保存的老的sessionid,如果大于,就把 createNewSession 包装为一个scalaUDF作为一个子表达式来创建一个新的sessionID,并且每次都把当前行的时间戳作为用户活动的最后时间戳。
最后包装为静态对象的方法,就可以在spark sql中使用这个自定义窗口函数了,下面是两个重载的方法,一个最大间隔时间使用默认值,一个可以运行用户自定义,perfect
现在,我们就可以拿来用在我们的main函数中了。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17val st = System.currentTimeMillis()
val one_minute = 60 * 1000
val d = Array[UserActivityData](
UserActivityData("user1", st, "f237e656-1e53-4a24-9ad5-2b4576a4125d"),
UserActivityData("user2", st + 5*one_minute, null),
UserActivityData("user1", st + 10*one_minute, null),
UserActivityData("user1", st + 15*one_minute, null),
UserActivityData("user2", st + 15*one_minute, null),
UserActivityData("user1", st + 140*one_minute, null),
UserActivityData("user1", st + 160*one_minute, null))
val df = spark.createDataFrame(sc.parallelize(d))
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)
val res = df.withColumn( "newsession", MyUDWF.calculateSession(f.col("ts"), f.col("session")) over specs)
df.show(20)
res.show(20, false)
如果我们学会了自定义 spark 窗口函数,原理是就可以处理一切这种对前后数据依赖的统计需求,不过这种自定义毕竟需要对 spark 源码有很深入的研究才可以,这就需要功力了,希望 spark君的读者可以跟着spark君日益精进。文章中的代码可能不太清楚,完整demo参考 https://github.com/zheniantoushipashi/spark-udwf-session。