pyspark-typedschema documentation#

This is a minimally intrusive library for typing or annotating pyspark data frames.

There are existing projects that try to change how you interact with pyspark, but that is not the goal of this library. Goals:

  • Create a simple way to define a schema for a pyspark DataFrame.

  • Supply some utility functions to test if the DataFrame adheres to a predefined schema.

  • Enable schema column autocompletion in your editor.

  • Re-use existing pyspark building blocks, such as StructField.

In pyspark you have three ways of referencing a column:

df.col
df['col']
F.col('col')

All have their advantages and disadvantages, but dealing with column names, pyspark columns, and schemas in a consistent way feels patchy. typedschema tries to make this a bit more natural.

Define a schema:

import pyspark.sql.functions as F
from pyspark.sql.types import StringType

from typedschema import Column, Schema


class MySchema(Schema):
    name = Column(StringType(), nullable=False)
    # you can decouple the attribute name from the column name
    favourite_animal = Column(StringType(), nullable=False, name="favourite animal")


myschema = MySchema()

df = spark.createDataFrame(
    [
        ("Carl", "Cat"),
        ("Homer", "Pig"),
    ],
    schema=myschema.spark_schema,
)

Because MySchema is a normal Python class, almost all editors can auto-complete it. Additionally, myschema.name is also a string, so you can use it more naturally:

df.select(myschema.name)  # used as string
df.select(myschema.name.col)  # used as pyspark column
df.withColumn(f"{myschema.name}_uc", F.upper(myschema.name.col))
df.withColumnRenamed(myschema.name, f"{myschema.name}_old")
df.show()

df.createOrReplaceTempView("friends")
# in spark.sql, use backticks (`) for column names that contain spaces
sqldf = spark.sql(f"SELECT `{myschema.favourite_animal}` FROM friends")
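
typedschema also provides checks to verify that a DataFrame matches a schema (covered in detail in the Tour below). A minimal sketch, reusing the myschema and df objects defined above:

myschema <= df.schema  # True: df was created from myschema.spark_schema
myschema >= df.schema  # True: df has no extra columns either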

Installation#

pip install pyspark-typedschema

Tour#

Definition#

You define a new schema by inheriting from typedschema.Schema:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime, date
from pyspark.sql.types import (
    DoubleType,
    StringType,
    LongType,
    DateType,
    TimestampType,
)

from typedschema import Column, Schema, diff_schemas, generate_schema_def


class MySchema(Schema):
    a = Column(LongType(), nullable=False)
    b = Column(DoubleType(), nullable=False)
    c = Column(StringType(), nullable=False)
    d = Column(DateType(), nullable=False)
    e = Column(TimestampType(), nullable=False)


myschema = MySchema()

spark = SparkSession.builder.master("local[*]").appName("typedspark").getOrCreate()

Now you can use myschema to generate a DataFrame:

df1 = spark.createDataFrame(
    [
        (1, 2.0, "string1", date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
        (2, 3.0, "string2", date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
        (3, 4.0, "string3", date(2000, 3, 1), datetime(2000, 1, 3, 12, 0)),
    ],
    schema=myschema.spark_schema,
)
df1.show()
df1.printSchema()
+---+---+-------+----------+-------------------+
|a  |b  |c      |d         |e                  |
+---+---+-------+----------+-------------------+
|1  |2.0|string1|2000-01-01|2000-01-01 12:00:00|
|2  |3.0|string2|2000-02-01|2000-01-02 12:00:00|
|3  |4.0|string3|2000-03-01|2000-01-03 12:00:00|
+---+---+-------+----------+-------------------+
root
 |-- a: long (nullable = false)
 |-- b: double (nullable = false)
 |-- c: string (nullable = false)
 |-- d: date (nullable = false)
 |-- e: timestamp (nullable = false)

A Column is also a string, so it behaves like one. If needed, you can also get a pyspark.sql.column.Column object via .col or .c.

df1.select(F.col(myschema.a)).show()
df1.select(myschema.a).show()
df1.select(myschema.a.col).show()
df1.select(myschema.a.c).show()
df1 = (
    spark.range(3)
    .withColumnsRenamed({"id": myschema.a})
    .withColumn(myschema.b, F.upper(F.concat(myschema.a.col, F.lit("_"), myschema.a.col)))
)
df1.show()
+---+---+
|a  |b  |
+---+---+
|0  |0_0|
|1  |1_1|
|2  |2_2|
+---+---+

Comparing Schemas#

You can compare schemas using Python’s set operations.

The operations issubset, issuperset, isequal, and contains are implemented.

df2 = spark.createDataFrame(
    [
        (1, 2.0, "string1", date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
        (2, 3.0, "string2", date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
        (3, 4.0, "string3", date(2000, 3, 1), datetime(2000, 1, 3, 12, 0)),
    ],
    schema="a long, z double, c string, d date, e timestamp",
)
myschema <= df2.schema  # False, col b missing
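
These named methods mirror the operators. A short sketch, assuming issuperset and isequal accept a pyspark schema just like issubset does (shown further below with strict_null):

myschema.issubset(df2.schema)    # False, col b missing (same as <=)
myschema.issuperset(df2.schema)  # False, df2 has col z, which myschema does not
myschema.isequal(df2.schema)     # False, the schemas differ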

Attention

Make sure that the typed schema object (or class) is on the left:

myschema >= df.schema # WORKS
myschema <= df.schema # WORKS

df.schema >= myschema # will not work
df.schema <= myschema # will not work

It can be difficult to see what exactly differs, so diff_schemas() is available:

differences = diff_schemas(myschema, df2.schema)

for change, my, other in differences:
    print(f"{change} {my} {other}")
- StructField('b', DoubleType(), False) None
+ None StructField('z', DoubleType(), True)
! StructField('a', LongType(), False) StructField('a', LongType(), True)
! StructField('c', StringType(), False) StructField('c', StringType(), True)
! StructField('d', DateType(), False) StructField('d', DateType(), True)
! StructField('e', TimestampType(), False) StructField('e', TimestampType(), True)

You can also dump the differences into a DataFrame if you want:

differences_df = spark.createDataFrame(
    [(t, str(a), str(b)) for t, a, b in differences],
    schema="type string, myschema string, df2 string",
)
differences_df.show(truncate=False)
+----+----------------------------------------+---------------------------------------+
|type|myschema                                |df2                                    |
+----+----------------------------------------+---------------------------------------+
|-   |StructField('b', DoubleType(), False)   |None                                   |
|+   |None                                    |StructField('z', DoubleType(), True)   |
|!   |StructField('a', LongType(), False)     |StructField('a', LongType(), True)     |
|!   |StructField('c', StringType(), False)   |StructField('c', StringType(), True)   |
|!   |StructField('d', DateType(), False)     |StructField('d', DateType(), True)     |
|!   |StructField('e', TimestampType(), False)|StructField('e', TimestampType(), True)|
+----+----------------------------------------+---------------------------------------+

Often nullability is not that important; you can ignore it via a function parameter (this won’t work with the operators, though):

df3 = spark.createDataFrame(
    [
        (1, 2.0, "string1", date(2000, 1, 1), datetime(2000, 1, 1, 12, 0), 10),
        (2, 3.0, "string2", date(2000, 2, 1), datetime(2000, 1, 2, 12, 0), 20),
        (3, 4.0, "string3", date(2000, 3, 1), datetime(2000, 1, 3, 12, 0), 30),
    ],
    schema="a long, b double, c string, d date, e timestamp, f long",
)
myschema <= df3.schema # False, differences in nullable
myschema.issubset(df3.schema, strict_null=False) # True, nullable ignored

Code Generation#

Most developers are relatively lazy, so typedschema can generate a class definition from an existing schema or data frame:

class_def = generate_schema_def(df3, name="CustomerDataSchema")
print(class_def)

will output:

class CustomerDataSchema(Schema):
    a = Column(LongType(), True)
    b = Column(DoubleType(), True)
    c = Column(StringType(), True)
    d = Column(DateType(), True)
    e = Column(TimestampType(), True)
    f = Column(LongType(), True)
customer_data_schema = CustomerDataSchema()
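
The generated class behaves like a hand-written one. A small sketch, assuming the definition above has been pasted into your code next to the Schema and Column imports:

# customer_data_schema comes from the generated code above
customer_data_schema.issubset(df3.schema, strict_null=False)  # True, it was generated from df3
df3.select(customer_data_schema.f.col).show()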

API Reference#