Objective
As an analyst working in a Consumer Goods company, you've been provided a DataFrame with sales data for various stores. This DataFrame includes multiple fields tracking the store, product, category, units sold, and a text description.
Task
The Description column contains text information about the product. Some products have a special discount tagged inside square brackets within this column (e.g., [10% off]).
Write a PySpark function that extracts this discount information and creates a new column called Discount. The discount must be expressed as a decimal (e.g., 0.10 for a 10% discount). If no discount is present in the text, the value should be 0.0.
Keep all other columns as they are and save the resulting DataFrame as result_df. Ensure the columns match the exact order specified in the schema.
File Path
- Dataset: /home/interview/sales.csv
- Starter script: /home/interview/discount_parser.py
Schema
sales.csv
| Column Name | Data Type |
| --- | --- |
| StoreID | String |
| ProductName | String |
| Category | String |
| SoldUnits | Integer |
| Description | String |
Expected Output Schema
| Column Name | Data Type |
| --- | --- |
| StoreID | String |
| ProductName | String |
| Category | String |
| SoldUnits | Integer |
| Description | String |
| Discount | Float |
Example
Given this sample input:
df
| StoreID | ProductName | Category | SoldUnits | Description |
| --- | --- | --- | --- | --- |
| S101 | Biscuits | Food | 120 | Tasty Biscuits [10% off] |
| S102 | Shampoo | Hygiene | 85 | Smoothens Hair [5% off] |
| S103 | Banana | Food | 150 | Fresh Bananas |
| S101 | Toothpaste | Hygiene | 300 | Protects Teeth |
| S102 | Shirt | Clothes | 65 | Cotton Shirts [20% off] |
The expected output would be:
| StoreID | ProductName | Category | SoldUnits | Description | Discount |
| --- | --- | --- | --- | --- | --- |
| S101 | Biscuits | Food | 120 | Tasty Biscuits [10% off] | 0.1 |
| S102 | Shampoo | Hygiene | 85 | Smoothens Hair [5% off] | 0.05 |
| S103 | Banana | Food | 150 | Fresh Bananas | 0.0 |
| S101 | Toothpaste | Hygiene | 300 | Protects Teeth | 0.0 |
| S102 | Shirt | Clothes | 65 | Cotton Shirts [20% off] | 0.2 |
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("PrepareshSpark").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
df = spark.read.csv("/home/interview/sales.csv", header=True, inferSchema=True)
# Step 1: Extract the numeric discount using a regular expression capture group
result_df = df.withColumn(
    "extracted_str",
    F.regexp_extract(F.col("Description"), r"\[(\d+)%\s*off\]", 1)
)
# Step 2: Convert the extracted string to a float, handling empty strings (no discount) safely
result_df = result_df.withColumn(
    "Discount",
    F.when(F.col("extracted_str") == "", 0.0)
    .otherwise(F.col("extracted_str").cast("float") / 100.0)
)
# Step 3: Select only the required columns in the exact requested order
result_df = result_df.select(
    "StoreID", "ProductName", "Category", "SoldUnits", "Description", "Discount"
)
# --- Do not edit below this line ---
result_df.coalesce(1).write.csv("/home/interview/output", header=True, mode="overwrite")
spark.stop()
Explanation
Step 1: Extracting the Pattern with Regex
result_df = df.withColumn(
    "extracted_str",
    F.regexp_extract(F.col("Description"), r"\[(\d+)%\s*off\]", 1)
)
To pull the number out of the string, we use PySpark's regexp_extract(). The pattern r"\[(\d+)%\s*off\]" matches a literal opening bracket [, one or more digits (\d+), a percent sign, optional whitespace (\s*), the word "off", and a literal closing bracket ]. Wrapping \d+ in parentheses creates a capture group, and passing 1 as the final argument tells PySpark to return only the digits captured by that group. The match is case-sensitive and accepts whole-number percentages only, so a tag such as [10% OFF] or [2.5% off] would not be extracted.
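The capture-group behaviour can be tried outside Spark with Python's built-in re module, which treats this particular pattern the same way. This is only an illustrative sketch: the helper name extract_digits and the extra sample strings are hypothetical, and returning "" on a miss is done by hand here to imitate what regexp_extract does.

```python
import re

# Same pattern as in the PySpark solution
DISCOUNT_PATTERN = re.compile(r"\[(\d+)%\s*off\]")


def extract_digits(description: str) -> str:
    """Return the digits from capture group 1, or "" when no tag matches."""
    match = DISCOUNT_PATTERN.search(description)
    return match.group(1) if match else ""


# Hypothetical sample strings: a normal tag, a tag with no space before "off",
# a description without a tag, and a tag the pattern deliberately rejects
samples = [
    "Tasty Biscuits [10% off]",
    "Smoothens Hair [5%off]",
    "Fresh Bananas",
    "Limited [10 % off]",
]
for text in samples:
    print(repr(text), "->", repr(extract_digits(text)))
```

Because \s* sits only between the percent sign and "off", a tag with whitespace before the percent sign is not matched.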
Step 2: Safe Casting and Math
result_df = result_df.withColumn(
    "Discount",
    F.when(F.col("extracted_str") == "", 0.0)
    .otherwise(F.col("extracted_str").cast("float") / 100.0)
)
If a row does not contain the discount tag (like "Fresh Bananas"), regexp_extract returns an empty string "". What happens when that empty string is cast depends on Spark's ANSI setting: with spark.sql.ansi.enabled set to true (the default starting with Spark 4.0), .cast("float") on an empty string raises a number-format error and fails the job, while with ANSI mode off the cast silently yields null. Neither outcome is the 0.0 the task requires.
To handle this, we use F.when() to catch empty strings and output 0.0 directly. If a number was extracted, we cast it to a float and divide by 100.0 to convert the percentage into a decimal (for example, 10 becomes 0.10).
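The same guard can be mirrored in plain Python, which may be handy for computing expected values when unit-testing the Spark job. This is a sketch only, not part of the solution script, and the helper name to_discount is hypothetical.

```python
def to_discount(extracted: str) -> float:
    """Convert captured digits to a decimal fraction; "" means no discount tag."""
    if extracted == "":
        # Guard first: float("") would raise ValueError in plain Python
        return 0.0
    return float(extracted) / 100.0


# The empty string stands for a description without a discount tag
for digits in ["10", "5", ""]:
    print(repr(digits), "->", to_discount(digits))
```

Checking for the empty string before converting plays the same role as the F.when() branch: the conversion is attempted only when there are digits to convert.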
Step 3: Formatting the Output Schema
result_df = result_df.select(
    "StoreID", "ProductName", "Category", "SoldUnits", "Description", "Discount"
)
Because we created an intermediate column (extracted_str) to make our code readable, we use .select() at the end of the script to drop it and ensure the final columns exactly match the requested schema order.