Objective
A mining company extracts rare minerals from various locations worldwide. It maintains two DataFrames to track its operations: mines, which holds metadata about each location, and extraction, which logs the daily mineral output.
Task
Write a PySpark function that calculates the total quantity of each mineral extracted per location.
The total_quantity column should contain the sum of all quantities of a given mineral extracted at a given location, and it must be of type Double. Sort the resulting rows by location in ascending order, then by mineral in ascending order. Save your result as result_df.
File Paths
- Mines dataset: /home/interview/mines.csv
- Extraction dataset: /home/interview/extraction.csv
- Starter script: /home/interview/mining_aggregation.py
Schema
mines.csv
| Column Name | Data Type |
|---|---|
| id | Integer |
| name | String |
| location | String |
extraction.csv
| Column Name | Data Type |
|---|---|
| mine_id | Integer |
| date | Date |
| mineral | String |
| quantity | Double |
Expected Output Schema
| Column Name | Data Type |
|---|---|
| location | String |
| mineral | String |
| total_quantity | Double |
Example
Given this sample input:
mines
| id | name | location |
|---|---|---|
| 1 | Mine Alpha | Australia |
| 2 | Mine Beta | Canada |
| 3 | Mine Gamma | South Africa |
extraction
| mine_id | date | mineral | quantity |
|---|---|---|---|
| 1 | 2023-06-30 | Gold | 1000.0 |
| 2 | 2023-06-30 | Silver | 1200.0 |
| 3 | 2023-06-30 | Diamond | 800.0 |
| 1 | 2023-06-29 | Gold | 900.0 |
| 2 | 2023-06-29 | Silver | 1300.0 |
| 3 | 2023-06-29 | Diamond | 750.0 |
The expected output would be:
| location | mineral | total_quantity |
|---|---|---|
| Australia | Gold | 1900.0 |
| Canada | Silver | 2500.0 |
| South Africa | Diamond | 1550.0 |
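The expected totals can be checked without Spark. This plain-Python sketch mirrors the same join, group, and sort logic on the sample rows above; it is only an arithmetic check, not part of the required solution.

```python
from collections import defaultdict

# Sample rows copied from the example tables
mines = [
    (1, "Mine Alpha", "Australia"),
    (2, "Mine Beta", "Canada"),
    (3, "Mine Gamma", "South Africa"),
]
extraction = [
    (1, "2023-06-30", "Gold", 1000.0),
    (2, "2023-06-30", "Silver", 1200.0),
    (3, "2023-06-30", "Diamond", 800.0),
    (1, "2023-06-29", "Gold", 900.0),
    (2, "2023-06-29", "Silver", 1300.0),
    (3, "2023-06-29", "Diamond", 750.0),
]

# "Join": look up each mine's location by its id
location_by_id = {mine_id: location for mine_id, _name, location in mines}

# "Group by" (location, mineral) and sum the quantities
totals = defaultdict(float)
for mine_id, _date, mineral, quantity in extraction:
    if mine_id in location_by_id:  # inner-join semantics: skip unmatched mine ids
        totals[(location_by_id[mine_id], mineral)] += quantity

# Sort by location, then by mineral
result = sorted((loc, mineral, total) for (loc, mineral), total in totals.items())
for row in result:
    print(row)
```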
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("PrepareshSpark").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
# Read the datasets
mines = spark.read.csv("/home/interview/mines.csv", header=True, inferSchema=True)
extraction = spark.read.csv("/home/interview/extraction.csv", header=True, inferSchema=True)
# Join the DataFrames on the mine identifier
joined_df = mines.join(extraction, mines.id == extraction.mine_id, "inner")
# Group by location and mineral and compute the total quantity as a Double
result_df = joined_df.groupBy("location", "mineral") \
.agg(F.sum("quantity").cast("double").alias("total_quantity"))
# Fix the output column order, then sort by location and mineral (both ascending)
result_df = result_df.select("location", "mineral", "total_quantity") \
.orderBy(F.col("location").asc(), F.col("mineral").asc())
# --- Do not edit below this line ---
result_df.coalesce(1).write.csv("/home/interview/output", header=True, mode="overwrite")
spark.stop()
Explanation
Step 1: Joining the DataFrames
joined_df = mines.join(extraction, mines.id == extraction.mine_id, "inner")
To total quantities per location, every extraction record must be matched to the location of the mine it came from. An inner join on mines.id == extraction.mine_id does this. Because the join is inner, mines with no extraction records, and extraction records whose mine_id matches no mine, are dropped from the result.
Step 2: Grouping by Multiple Columns
result_df = joined_df.groupBy("location", "mineral")
PySpark can group by more than one column. Passing both "location" and "mineral" to .groupBy() creates one group for each distinct (location, mineral) pair. For example, if a single location extracts both Gold and Silver, it produces two separate groups instead of one combined total. Conversely, if two mines share a location, their output of the same mineral falls into the same group and is summed together, which is exactly what a per-location total requires.
Step 3: Aggregating and Casting Types
.agg(F.sum("quantity").cast("double").alias("total_quantity"))
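Within .agg(), F.sum() totals the quantities. The prompt requires total_quantity to be of type Double, and the result type of sum() depends on the input column's type. When inferSchema reads values such as 1000.0, quantity is typed as a double and the sum is already a double. If every quantity in the CSV is written without a decimal point, however, inferSchema types the column as an integer and sum() returns a bigint (LongType). Chaining .cast("double") guarantees the required type in either case.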
Step 4: Formatting and Multi-Level Sorting
result_df = result_df.select("location", "mineral", "total_quantity") \
.orderBy(F.col("location").asc(), F.col("mineral").asc())
The .select() call lists the columns explicitly so the output order always matches the prompt's schema. After groupBy().agg(), the columns already come out as location, mineral, total_quantity, so this step is a safeguard rather than a necessity. Finally, .orderBy() sorts the aggregated DataFrame. The conditions passed to .orderBy() are applied in sequence: rows are ordered by location first, and mineral only breaks ties between rows that share the same location.