Evaluate the model and plot the predicted probability vs. emission rate
Evaluation results of the model
This notebook reports the statistics of the model evaluated against two baselines: the MBMP threshold model and the CH4Net model of Vaughan et al. (2024). It reproduces Figures 2 and 3 and several other figures of the supplementary material. It also computes the statistics mentioned in the text (precision, recall, average precision, false positive rate, ...).
This notebook reads the cached results from the Hugging Face repository. To re-evaluate the model, see README.md.
Install the marss2l package
pip install marss2l
%%time
import matplotlib
from marss2l.utils import setup_stream_logger, get_remote_filesystem, pathjoin
from marss2l import loaders
import logging
import os
import uuid
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import PrecisionRecallDisplay, precision_score, recall_score, accuracy_score,\
average_precision_score, confusion_matrix
from sklearn.metrics import RocCurveDisplay, roc_curve
import numpy as np
from marss2l.plot.colors import C0, C1, C2, C3, C4
import seaborn as sns
import json
from huggingface_hub import hf_file_system
from marss2l.huggingface import REPO_ID
# Remote filesystem for the Hugging Face Hub and a logger that prints to the notebook.
fs = hf_file_system.HfFileSystem()
logger = logging.getLogger(__name__)
setup_stream_logger(logger)

# Output folder for every figure saved by this notebook.
os.makedirs("figures", exist_ok=True)

# Metadata of all validated images (all splits), with extra analysis columns,
# case-study labels and location types.
csv_path = f"datasets/{REPO_ID}/validated_images_all.csv"
dataframe_data_traintest = loaders.read_csv(
    csv_path,
    fs=fs,
    split="all",
    add_columns_for_analysis=True,
    add_case_study=True,
    add_loc_type=True,
)
dataframe_data_traintest.shape

# Sanity check: the flux rate is always defined and non-negative (expected: empty table)
dataframe_data_traintest.loc[
    dataframe_data_traintest.ch4_fluxrate.lt(0) | dataframe_data_traintest.ch4_fluxrate.isna(),
    ["location_name", "satellite", "tile_date", "observability", "isplume",
     "ch4_fluxrate", "country", "wind_speed"],
]

# Sanity check: images without a plume have zero flux rate (expected sum: 0)
dataframe_data_traintest.loc[~dataframe_data_traintest.isplume, "ch4_fluxrate"].sum()

# Sanity check: every plume has a strictly positive flux rate (expected: empty table)
dataframe_data_traintest[dataframe_data_traintest.isplume & dataframe_data_traintest.ch4_fluxrate.eq(0)]
Load CSV files with evaluation results
In the next cell we load the evaluation results from the pre-computed CSV files in the `basefolder_experiments` path. If you re-evaluated the model following the instructions in README.md, change `basefolder_experiments` to your local path.
pd.options.display.float_format = "{:,.2f}".format
from importlib import reload
from marss2l import validation_utils

# Location of the cached evaluation outputs. If you re-ran the evaluation yourself
# (see README.md), point this to your local folder instead, e.g.:
# basefolder_experiments = "xxxx/MARS-S2L/train_logs_revision/"
basefolder_experiments = f"datasets/{REPO_ID}/trained_models/"

# Each entry: (experiment folder, model display name, predictions CSV name).
# NOTE(review): the "MARS-S2L" and "MBMP" entries use "preds_test_2023th100" while the
# others use "preds_test_2023thr100"; presumably these are the real file names in the
# repository, but confirm this is not a typo.
expload = [
    ("CH4Netsim_20250605", "CH4Net (sim)", "preds_test_2023thr100"),
    ("MARSS2Lnosim_20250605", "MARS-S2L (no sim)", "preds_test_2023thr100"),
    ("MARSS2L_20250326", "MARS-S2L", "preds_test_2023th100"),
    ("MARSS2L_off_20250523", "MARS-S2L (offshore)", "preds_test_2023thr100"),
    ("CH4Net_20250329", "CH4Net", "preds_test_2023thr100"),
    ("MBMP", "MBMP", "preds_test_2023th100"),
]

config_experiments = dict()
outs = []
ids_per_experiment = []
for train_folder, model_name, csv_file in expload:
    output, config = validation_utils.load_stats_and_config(
        train_folder,
        model_name,
        csv_file=csv_file,
        basefolder_experiments=basefolder_experiments,
        fs=fs,
        logger=logger,
    )
    # load_stats_and_config may return no config; only keep the ones available.
    if config is not None:
        config_experiments[model_name] = config
    ids_per_experiment.append(set(output["id_loc_image"].values))
    outs.append(output)

# Compare every model on exactly the same images: keep ids present in all experiments.
ids_all = set.intersection(*ids_per_experiment)
print(f"There are {len(ids_all)} common ids")
outs = pd.concat(outs, ignore_index=True)
outs = outs[outs.id_loc_image.isin(ids_all)].copy()
outs.groupby(["model_name", "target"])[["id_loc_image"]].count()
# Default decision thresholds on the scene-level score. MBMP scores live on a different
# scale than the neural models (note the negative threshold), so it gets its own value.
threshold_mbmp = -.99
threshold_marss2l = 0.5
outs_merge = outs.drop(["location_name", "tile"], axis=1)
# Attach the image metadata (flux rate, case study, ...) of the test_2023 split.
outs_same_period_with_fluxrate = pd.merge(
    outs_merge,
    dataframe_data_traintest[dataframe_data_traintest.split_name == "test_2023"],
    on="id_loc_image",
)
# Use isplume from dataframe_data_traintest as ground truth instead of the cached target.
outs_same_period_with_fluxrate = outs_same_period_with_fluxrate.drop("target", axis=1)
outs_same_period_with_fluxrate["isplumenum"] = outs_same_period_with_fluxrate["isplume"].astype(int)
# Binarise the scene score with the model-specific threshold. Vectorised with np.where:
# a row-wise DataFrame.apply(axis=1) is slow and, on an empty frame, returns a DataFrame
# instead of a Series, which breaks the column assignment.
_is_mbmp = outs_same_period_with_fluxrate["model_name"].str.startswith("MBMP", na=False)
_scene_pred = outs_same_period_with_fluxrate["scene_pred"]
outs_same_period_with_fluxrate["isplumeprednum"] = np.where(
    _is_mbmp, _scene_pred > threshold_mbmp, _scene_pred > threshold_marss2l
).astype(int)
outs_same_period_with_fluxrate["scenepredcontinuous"] = outs_same_period_with_fluxrate["scene_pred"]
outs_same_period_with_fluxrate.id_loc_image.nunique()
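Combine the onshore and offshore predictions of the MARS-S2L model: onshore images use the MARS-S2L model and offshore images use the MARS-S2L (offshore) model.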
# Build "MARS-S2L (combined)": onshore images scored by the general MARS-S2L model and
# offshore images scored by the offshore-specific model.
_is_offshore = outs_same_period_with_fluxrate.offshore
_model_names = outs_same_period_with_fluxrate.model_name
marss2l_onshore = outs_same_period_with_fluxrate[
    (_model_names == "MARS-S2L") & ~_is_offshore
].assign(model_name="MARS-S2L (combined)")
marss2l_offshore = outs_same_period_with_fluxrate[
    (_model_names == "MARS-S2L (offshore)") & _is_offshore
].assign(model_name="MARS-S2L (combined)")
outs_same_period_with_fluxrate = pd.concat(
    [outs_same_period_with_fluxrate, marss2l_onshore, marss2l_offshore], ignore_index=True
)
outs_same_period_with_fluxrate.id_loc_image.nunique()
# Keep only the combined model and give it the plain "MARS-S2L" name.
_replaced_models = ["MARS-S2L (offshore)", "MARS-S2L"]
outs_same_period_with_fluxrate = outs_same_period_with_fluxrate.loc[
    ~outs_same_period_with_fluxrate.model_name.isin(_replaced_models)
].copy()
outs_same_period_with_fluxrate["model_name"] = outs_same_period_with_fluxrate["model_name"].replace(
    {"MARS-S2L (combined)": "MARS-S2L"}
)
outs_same_period_with_fluxrate.groupby(["model_name", "isplume"])[["id_loc_image"]].count()
# Sanity check: all models are evaluated in the same number of images
outs_same_period_with_fluxrate.groupby("model_name")["tile_date"].agg(["min", "max", "count"])
# Sanity check: all models are evaluated in the same number of images and same number of plumes/noplumes
outs_same_period_with_fluxrate.groupby(["isplume", "model_name"])[["id_loc_image"]].count()
outs_same_period_with_fluxrate.model_name.unique()
PR curves and general metrics
Overall metrics
Final model is MARS-S2L (U326v309)
from marss2l.metrics import get_scenelevel_metrics, get_pixellevel_metrics

# Scene-level and pixel-level metrics of every model at three operating points.
# Each pair is (threshold for the MARS-S2L/CH4Net models, threshold for MBMP); the
# first pair equals the default thresholds defined above.
_threshold_pairs = list(zip([0.5, 0.9, 0.98], [threshold_mbmp, -0.9, -0.85]))
_models_overall = ['MARS-S2L', 'MARS-S2L (no sim)', 'CH4Net (sim)', 'CH4Net', "MBMP"]
mets = []
for idx_threshold, (threshold_marss2l_iter, threshold_mbmp_iter) in enumerate(_threshold_pairs):
    for model in _models_overall:
        dg = outs_same_period_with_fluxrate.loc[outs_same_period_with_fluxrate.model_name == model]
        if model.startswith("MBMP"):
            threshold = threshold_mbmp_iter
        else:
            threshold = threshold_marss2l_iter
        mets_iter = get_scenelevel_metrics(dg.scenepredcontinuous, dg.isplumenum,
                                           threshold=threshold, as_percentage=True)
        mets_seg = get_pixellevel_metrics(TP=dg.TP, TN=dg.TN, FP=dg.FP, FN=dg.FN,
                                          as_percentage=True)
        mets_iter.update(mets_seg)
        mets_iter.update(
            nsamples=dg.shape[0],
            nlocs=dg.location_name.nunique(),
            nplumes=dg.isplumenum.sum(),
            threshold=threshold,
            idx_threshold=idx_threshold,
            nnoplume=(1 - dg.isplumenum).sum(),
            model_name=model,
        )
        mets.append(mets_iter)
mets = pd.DataFrame(mets)
_leading_cols = ["model_name", "threshold"]
overall_mets = mets[_leading_cols + [c for c in mets.columns if c not in _leading_cols]].copy()
overall_mets

# Table: metrics of the main models at the default operating point (idx_threshold == 0).
cols_metrics = ['model_name', 'average_precision', 'precision', 'recall', 'accuracy', 'fpr']
cols_metrics_segmentation = ['segmentation_precision', 'segmentation_recall',
                             'segmentation_accuracy', 'segmentation_fpr', 'iou']
models_plot_recall = ["MBMP", "CH4Net", "MARS-S2L"]
_default_point = overall_mets.idx_threshold == 0
_main_models = overall_mets.model_name.isin(models_plot_recall)
overall_mets_table_7 = overall_mets[_default_point & _main_models].copy()
_table_7 = overall_mets_table_7[cols_metrics + cols_metrics_segmentation]
print(_table_7.to_latex(index=False, float_format="%.2f"))
_table_7

# Table: precision / recall / FPR of the main models at every threshold.
models_plot_recall = ["MBMP", "CH4Net", "MARS-S2L"]
col_metrics_table = ["model_name", "threshold", "precision", "recall", "fpr"]
overall_mets_table_thresholds = overall_mets.loc[overall_mets.model_name.isin(models_plot_recall)]
_thresholds_table = overall_mets_table_thresholds[col_metrics_table]
print(_thresholds_table.to_latex(index=False, float_format="%.2f"))
_thresholds_table
col_metrics_table
# Figure: PR curve and ROC curve of each model on the test set, with "x" markers at
# the thresholded operating points stored in overall_mets.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), tight_layout=True, sharey=False)
iaxes_roc = 1
iaxes_pr = 0
overall_mets_indexed = overall_mets.set_index("model_name")
colors = [C0, C4, C2, C3, C1]
models_plot_prcurve = ["MBMP", "CH4Net", "MARS-S2L", "MARS-S2L (no sim)", "CH4Net (sim)"]
df_show = outs_same_period_with_fluxrate[
    outs_same_period_with_fluxrate.model_name.isin(models_plot_prcurve)].copy()
for _i_prev, model_name in enumerate(models_plot_prcurve):
    dg = df_show[df_show.model_name == model_name]
    color = colors[_i_prev]
    # The display objects are not bound to a name so IPython's `display` is not shadowed.
    # The chance level is drawn only once (with the first model).
    PrecisionRecallDisplay.from_predictions(dg.isplumenum, dg.scenepredcontinuous,
                                            plot_chance_level=_i_prev == 0,
                                            color=color,
                                            ax=ax[iaxes_pr], name=model_name)
    RocCurveDisplay.from_predictions(dg.isplumenum, dg.scenepredcontinuous,
                                     plot_chance_level=_i_prev == 0,
                                     color=color,
                                     ax=ax[iaxes_roc], name=model_name)
    # overall_mets has one row per evaluated threshold, so this selects every
    # operating point of the model and a marker is drawn for each of them.
    model_mets = overall_mets_indexed.loc[[model_name]]
    # `color=` (not `c=`): with `c=` a single RGB tuple is value-mapped through a
    # colormap when the number of points equals the tuple length.
    ax[iaxes_pr].scatter(model_mets["recall"] / 100, model_mets["precision"] / 100,
                         color=color, s=100, marker="x")
    ax[iaxes_roc].scatter(model_mets["fpr"] / 100, model_mets["recall"] / 100,
                          color=color, s=100, marker="x")
ax[iaxes_pr].get_legend().remove()
ax[iaxes_roc].grid()
ax[iaxes_pr].grid()
ax[iaxes_pr].set_xticks(np.arange(0, 1.05, .1))
ax[iaxes_pr].set_yticks(np.arange(0, 1.05, .1))
ax[iaxes_pr].set_title("PR-Curve")
ax[iaxes_roc].set_title("ROC curve")
ax[iaxes_roc].set_xticks(np.arange(0, 1.05, .1))
ax[iaxes_roc].set_xlabel("FPR")
ax[iaxes_roc].set_ylabel("Recall (TPR)")
ax[iaxes_pr].set_ylabel("Precision")
ax[iaxes_pr].set_xlabel("Recall (TPR)")
_ = ax[iaxes_roc].set_yticks(np.arange(0, 1.05, .1))
fig.savefig("figures/pr_and_roc_curves.pdf")
Compute metrics stratified by type of location
There are two types of locations: (1) locations that were used for training ("some samples") and (2) locations that were not used for training ("no samples").
# Label each location by whether it contributed samples to the training set of the
# reference model (list taken from that experiment's training config).
model_ref = "MARS-S2L"
locs_train = set(config_experiments[model_ref]['all_locs_train'])
# No location is treated as FiLM-conditioned here, so the "FiLM" label never appears.
locs_film = set()
print(len(locs_train), len(locs_film))


def _loc_type_train(location_name):
    """Return "FiLM", "some samples" (location seen in training) or "no samples"."""
    if location_name in locs_film:
        return "FiLM"
    if location_name in locs_train:
        return "some samples"
    return "no samples"


outs_same_period_with_fluxrate["loc_type_train"] = outs_same_period_with_fluxrate.location_name.apply(_loc_type_train)

# Number of locations and images per location type.
_by_loc_type = outs_same_period_with_fluxrate.groupby("loc_type_train")
aggs_sanity = _by_loc_type["location_name"].agg(["nunique"])
aggs_sanity["nimages"] = _by_loc_type["id_loc_image"].nunique()
aggs_sanity.rename(columns={"nunique": "nlocs"})
# Sanity check: same table plus the number of plumes, counted on the reference model
# only so that each image contributes once.
_by_loc_type = outs_same_period_with_fluxrate.groupby("loc_type_train")
aggs_sanity = _by_loc_type["location_name"].agg(["nunique"])
aggs_sanity["nimages"] = _by_loc_type["id_loc_image"].nunique()
_ref_model_rows = outs_same_period_with_fluxrate[outs_same_period_with_fluxrate.model_name == model_ref]
aggs_sanity["nplumes"] = _ref_model_rows.groupby("loc_type_train")["isplumenum"].sum()
aggs_sanity.rename(columns={"nunique": "nlocs"})
from itertools import product

# Metrics of every model separately for each location type, at the default thresholds.
mets = []
for (model, loc_type), dg in outs_same_period_with_fluxrate.groupby(["model_name", "loc_type_train"]):
    threshold = threshold_mbmp if model.startswith("MBMP") else threshold_marss2l
    mets_iter = get_scenelevel_metrics(dg.scenepredcontinuous, dg.isplumenum,
                                       as_percentage=True, threshold=threshold)
    mets_iter.update(get_pixellevel_metrics(TP=dg.TP, TN=dg.TN, FP=dg.FP, FN=dg.FN,
                                            as_percentage=True))
    mets_iter.update(
        nsamples=dg.shape[0],
        nlocs=dg.location_name.nunique(),
        nplumes=dg.isplumenum.sum(),
        nnoplume=(1 - dg.isplumenum).sum(),
        loc_type=loc_type,
        model_name=model,
    )
    mets.append(mets_iter)
mets = pd.DataFrame(mets).sort_values(["loc_type", "balanced_accuracy"], ascending=False)
overall_mets_strat_type_of_loc = mets[["model_name"] + mets.columns.drop("model_name").tolist()].copy()
models_show_by_type_of_loc = ["MBMP", "CH4Net", "MARS-S2L", "MARS-S2L (no sim)", "CH4Net (sim)"]
_shown_rows = overall_mets_strat_type_of_loc.model_name.isin(models_show_by_type_of_loc)
_table_by_loc_type = overall_mets_strat_type_of_loc.loc[_shown_rows, ['loc_type'] + cols_metrics].set_index("loc_type")
_table_by_loc_type
print(_table_by_loc_type.to_latex(float_format="%.2f"))
# PR curves stratified by location type ("some samples" = location seen in training).
locs_type = ["some samples", "no samples"]
# squeeze=False keeps `ax` indexable even when a single location type is plotted.
fig, _axes_grid = plt.subplots(1, len(locs_type), figsize=(len(locs_type) * 5, 5), squeeze=False)
ax = _axes_grid[0]
# One fixed colour per model, the same as in the overall PR/ROC figure, rather than a
# colour that depends on the model's position among all models in a groupby.
model_colors = {"MBMP": C0, "CH4Net": C4, "MARS-S2L": C2}
models_show = ["MBMP", "CH4Net", "MARS-S2L"]
for i, loc_type_show in enumerate(locs_type):
    outs_same_period_with_fluxrate_loc = outs_same_period_with_fluxrate[
        outs_same_period_with_fluxrate.loc_type_train == loc_type_show]
    n_curves = 0
    for model_name in models_show:
        out_same_period = outs_same_period_with_fluxrate_loc[
            outs_same_period_with_fluxrate_loc.model_name == model_name]
        if out_same_period.empty:
            continue
        # The chance level depends only on the class balance of the images: draw it
        # once per subplot, with the first curve actually plotted.
        PrecisionRecallDisplay.from_predictions(out_same_period.isplumenum,
                                                out_same_period.scenepredcontinuous,
                                                plot_chance_level=n_curves == 0,
                                                color=model_colors[model_name],
                                                ax=ax[i], name=model_name)
        n_curves += 1
    ax[i].set_title(f"Loc type: {loc_type_show}")
    ax[i].set_xticks(np.arange(0, 1.05, .1))
    ax[i].set_yticks(np.arange(0, 1.05, .1))
    ax[i].legend(loc="upper right")
    ax[i].grid()
Plume metrics by flux rate
# Sanity check: number of distinct images per flux-rate interval, split by plume / no plume
outs_same_period_with_fluxrate.groupby(["isplume", "interval_ch4_fluxrate_str"])["id_loc_image"].nunique().to_frame()
Figure: predicted probability vs. emission rate
from marss2l import plot

# Predicted probability vs. emission rate for CH4Net and MARS-S2L.
model_names_plot = ["CH4Net", "MARS-S2L"]
_is_plotted_model = outs_same_period_with_fluxrate.model_name.isin(model_names_plot)
df_plot = outs_same_period_with_fluxrate.loc[_is_plotted_model]
fig, ax = plot.plot_prob_vs_emission_rate(df_plot)

from marss2l import plot
from marss2l.plot import prob_vs_emission_rate
from importlib import reload

# Re-execute the plotting module so local edits are picked up without a kernel restart.
reload(prob_vs_emission_rate)
# Same figure for MARS-S2L only, with a wider layout.
model_names_plot = ["MARS-S2L"]
_is_plotted_model = outs_same_period_with_fluxrate.model_name.isin(model_names_plot)
df_plot = outs_same_period_with_fluxrate.loc[_is_plotted_model]
fig, ax = prob_vs_emission_rate.plot_prob_vs_emission_rate(df_plot, figsize=(9, 4))
Recall vs Fluxrate
from marss2l.plot import recall_fluxrate_plot
from marss2l.plot import plot_recall_fpr_fluxrate
import numpy as np
from importlib import reload

# Re-execute the plotting module so local edits are picked up without a kernel restart.
reload(recall_fluxrate_plot)

# Figure 2 (global results): cumulative recall / FPR as a function of the flux rate.
models_plot_recall = ["MBMP", "CH4Net", "MARS-S2L"]
plt.figure(figsize=(10, 3), layout="constrained")
fig, axs = recall_fluxrate_plot.plot_recall_fpr_fluxrate(
    outs_same_period_with_fluxrate,
    order_models=models_plot_recall,
    loc_legend="center right",
    add_legend=False,
)
axs[1].set_title("Global results")
axs[0].set_yticks(np.arange(0, .55, .10))
plt.savefig("figures/fig2_overall.pdf")
Non-cumulative
# Same global recall / FPR figure, but with per-interval (non-cumulative) values.
plt.figure(figsize=(10, 3), layout="constrained")
fig, axs = recall_fluxrate_plot.plot_recall_fpr_fluxrate(
    outs_same_period_with_fluxrate,
    order_models=models_plot_recall,
    cummulative=False,
    loc_legend="center right",
    add_legend=False,
)
axs[1].set_title("Global results")
axs[0].set_yticks(np.arange(0, .55, .10))
plt.savefig("figures/fig2_overall_noncum.pdf")
Results in case studies
pd.options.display.max_rows = 200

# Metrics of every model in each case study, at the default thresholds.
mets = []
for (model, case_study), dg in outs_same_period_with_fluxrate.groupby(["model_name", "case_study"]):
    n_locs_few_samples = dg.loc[dg.loc_type_train == "some samples", "location_name"].nunique()
    n_locs_no_samples = dg.loc[dg.loc_type_train == "no samples", "location_name"].nunique()
    threshold = threshold_mbmp if model.startswith("MBMP") else threshold_marss2l
    mets_iter = get_scenelevel_metrics(dg.scenepredcontinuous, dg.isplumenum,
                                       as_percentage=True, threshold=threshold)
    mets_iter.update(get_pixellevel_metrics(TP=dg.TP, TN=dg.TN, FP=dg.FP, FN=dg.FN,
                                            as_percentage=True))
    mets_iter.update({
        "case_study": case_study,
        "model_name": model,
        "nimages": dg.shape[0],
        "nlocs": dg.location_name.nunique(),
        "nplumes": dg.isplumenum.sum(),
        "nlocs samples train": n_locs_few_samples,
        "nlocs no samples at train time": n_locs_no_samples,
        "nnoplume": (1 - dg.isplumenum).sum(),
    })
    mets.append(mets_iter)
mets = pd.DataFrame(mets).sort_values(["case_study", "model_name"])
_case_study_cols = [
    'case_study', 'model_name', 'precision', 'recall', 'fpr', "balanced_accuracy",
    'average_precision', 'nimages', 'nplumes', 'nnoplume', 'nlocs',
    'nlocs samples train', 'nlocs no samples at train time',
]
mets_case_studies = mets[_case_study_cols].sort_values(["nimages", "balanced_accuracy"], ascending=False).copy()
mets_case_studies
_latex_rows = mets_case_studies.model_name != "MARS-S2L (first submission)"
_latex_cols = ['case_study', 'model_name', 'average_precision', 'precision', 'recall', "fpr"]
print(mets_case_studies.loc[_latex_rows, _latex_cols].to_latex(index=False, float_format="%.2f"))
Fig 2. Results by case study
import matplotlib.gridspec as gridspec

# Figure 2 (case studies): one (narrow axis, wide axis) pair per case study, laid out in
# two columns. Case studies at even positions fill the left column, odd positions the right.
case_studies = loaders.ORDER_CASE_STUDIES[:-1]
case_studies_ordered = case_studies[::2] + case_studies[1::2]
# Grid size derived from the number of case studies (6 rows for 12 case studies)
# instead of being hard-coded.
n_rows = len(case_studies[::2])
fig = plt.figure(figsize=(14, 2 * n_rows), layout="constrained")
gs = gridspec.GridSpec(n_rows, 4, width_ratios=[1, 5, 1, 5], figure=fig)
model_names_plot = models_plot_recall
# First axes pair created on each side (grid column offset 0 = left, 2 = right); the
# following rows of that side share their x axis with it.
shared_axes = {0: (None, None), 2: (None, None)}
# Index of the bottom plot of each column: the only plots that keep visible x axes.
bottom_indices = {n_rows - 1, len(case_studies_ordered) - 1}
for i, case_study in enumerate(case_studies_ordered):
    row = i % n_rows
    col_offset = (i // n_rows) * 2
    sharex_narrow, sharex_wide = shared_axes[col_offset]
    ax1 = fig.add_subplot(gs[row, col_offset], sharex=sharex_narrow)
    ax = fig.add_subplot(gs[row, col_offset + 1], sharex=sharex_wide)
    if sharex_narrow is None:
        shared_axes[col_offset] = (ax1, ax)

    df_plot = outs_same_period_with_fluxrate[outs_same_period_with_fluxrate.case_study == case_study]
    loc_legend = "center left" if i == 0 else "upper left"
    fig, axs = recall_fluxrate_plot.plot_recall_fpr_fluxrate(
        df_plot,
        order_models=model_names_plot,
        loc_legend=loc_legend,
        yticks_recall=np.arange(0, 1.1, 0.2),
        add_legend=i in (0, n_rows),  # legend only on the top plot of each column
        fig=fig,
        axs=(ax1, ax),
    )
    ax1.set_ylim(0, 1)
    axs[1].set_title(case_study)
    if i not in bottom_indices:
        axs[0].xaxis.set_visible(False)
        axs[1].xaxis.set_visible(False)
plt.savefig("figures/fig2_case_studies_new.pdf")
Non-cumulative
import matplotlib.gridspec as gridspec

# Non-cumulative version of the case-study figure: one (narrow axis, wide axis) pair per
# case study in two columns (even positions left, odd positions right).
case_studies = loaders.ORDER_CASE_STUDIES[:-1]
case_studies_ordered = case_studies[::2] + case_studies[1::2]
# Grid size derived from the number of case studies (6 rows for 12 case studies)
# instead of being hard-coded.
n_rows = len(case_studies[::2])
fig = plt.figure(figsize=(14, 2 * n_rows), layout="constrained")
gs = gridspec.GridSpec(n_rows, 4, width_ratios=[1, 5, 1, 5], figure=fig)
model_names_plot = models_plot_recall
# First axes pair created on each side (grid column offset 0 = left, 2 = right); the
# following rows of that side share their x axis with it.
shared_axes = {0: (None, None), 2: (None, None)}
# Index of the bottom plot of each column: the only plots that keep visible x axes.
bottom_indices = {n_rows - 1, len(case_studies_ordered) - 1}
for i, case_study in enumerate(case_studies_ordered):
    row = i % n_rows
    col_offset = (i // n_rows) * 2
    sharex_narrow, sharex_wide = shared_axes[col_offset]
    ax1 = fig.add_subplot(gs[row, col_offset], sharex=sharex_narrow)
    ax = fig.add_subplot(gs[row, col_offset + 1], sharex=sharex_wide)
    if sharex_narrow is None:
        shared_axes[col_offset] = (ax1, ax)

    df_plot = outs_same_period_with_fluxrate[outs_same_period_with_fluxrate.case_study == case_study]
    loc_legend = "center left" if i == 0 else "upper left"
    fig, axs = recall_fluxrate_plot.plot_recall_fpr_fluxrate(
        df_plot,
        order_models=model_names_plot,
        loc_legend=loc_legend,
        yticks_recall=np.arange(0, 1.1, 0.2),
        cummulative=False,
        add_legend=i in (0, n_rows),  # legend only on the top plot of each column
        fig=fig,
        axs=(ax1, ax),
    )
    ax1.set_ylim(0, 1)
    axs[1].set_title(case_study)
    if i not in bottom_indices:
        axs[0].xaxis.set_visible(False)
        axs[1].xaxis.set_visible(False)
plt.savefig("figures/fig2_case_studies_new_noncummulative.pdf")
Predicted probability vs. flux rate in the case studies
# Predicted probability vs. emission rate for MARS-S2L and CH4Net, one figure per case study.
model_names_plot = ["MARS-S2L", "CH4Net"]
for case_study, dg in outs_same_period_with_fluxrate.groupby("case_study"):
    df_plot = dg.loc[dg.model_name.isin(model_names_plot)]
    fig, ax = plot.plot_prob_vs_emission_rate(df_plot)
    ax[1].set_title(case_study)
    # plt.show() takes no figure argument: with standard backends passing `fig`
    # positionally raises TypeError (the inline backend silently reads it as `close`).
    plt.show()
    plt.close(fig)
Histograms of the predicted probability in the case studies
from marss2l import plot
from importlib import reload
reload(plot)

# Histograms of the MARS-S2L predicted probability, one figure per case study.
model = "MARS-S2L"
# The cached "target" column was dropped after the merge; restore it from the ground
# truth (presumably plot.plot_row reads it — confirm).
outs_same_period_with_fluxrate["target"] = outs_same_period_with_fluxrate["isplumenum"]
for case_study, dg_case_study in outs_same_period_with_fluxrate.groupby("case_study"):
    # NOTE(review): "None" appears to label images outside any case study — confirm.
    if case_study == "None":
        continue
    # Filter the group produced by groupby instead of re-scanning the whole table.
    dg = dg_case_study[dg_case_study.model_name == model]
    if dg.empty:
        continue
    fig, axs = plt.subplots(2, 1, figsize=(6, 8), sharex=True, sharey='row')
    plot.plot_row(dg, model_name=model, axs=axs)
    plt.suptitle(case_study)
    plt.show()
    plt.close(fig)