# Copyright (c) Carole Sudre
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Overall process - :mod:`MetricsReloaded.processes.overall_process`
====================================================================
This module provides the class used to perform the :ref:`overall evaluation process <processeval>`.

.. _processeval:

.. autoclass:: ProcessEvaluation
    :members:

The different categories of task considered are:
* ImLC - Image Level Classification
* SemS - Semantic Segmentation
* ObD - Object detection
* InS - Instance segmentation
For each of these tasks only certain measures are available and suitable. If the chosen task and the chosen
evaluation measures are not compatible, a warning is issued for each unsuitable measure and the processing is not carried out.
Evaluation measures are classified into the following categories:
* Per class counting measures - measures_pcc
* Multi class counting measures - measures_mcc
* Overlap measures - measures_overlap
* Boundary measures - measures_boundary
* Multi threshold measures - measures_mt
* Calibration measures - measures_cal
* Combined detection and segmentation metrics - measures_detseg
The available measures per task are:
* ImLC:
* multi threshold measures:
* auroc - Area under the Receiver Operating Characteristic curve
* ap - Average Precision
* sens@spec - Sensitivity at Specificity
* spec@sens - Specificity at Sensitivity
* ppv@sens - Positive Predictive value at sensitivity
* per class counting measures:
* fbeta - FBeta score
* lr+ - positive likelihood ratio
* accuracy
* ba - balanced accuracy
* ec - expected cost
* nb - net benefit
* numb_ref - number in reference
* numb_pred - number in prediction
* numb_tp - number of true positives
* numb_fp - number of false positives
* numb_fn - number of false negatives
* cohens_kappa
* multi class counting measures:
* mcc - matthews correlation coefficient
* wck - weighted cohen's kappa
* ec - expected cost
* calibration measures:
* ls - logarithmic score
* bs - Brier Score
* cwece - Class-wise expected calibration error
* nll - Negative log-likelihood
* rbs - Root Brier Score
* ece_kde - Expected Calibration Error with kernel density estimation
* kce - Kernel Calibration Error
* ece - Expected Calibration Error
* Object Detection - ObD:
* per class counting measures:
* fbeta - FBeta score
* numb_pred - number of predicted elements
* numb_tp - number of true positives
* numb_fp - number of false positives
* numb_fn - number of false negatives
* numb_ref - number of reference elements
* sensitivity - sensitivity
* multi-threshold measures:
* sens@spec - sensitivity at specificity
* spec@sens - specificity at sensitivity
* sens@ppv - sensitivity at positive predictive value
* ppv@sens - positive predictive value at sensitivity
* sens@fppi - sensitivity at false positive per image
* fppi@sens - false positive per image at sensitivity
* ap - average precision
* froc - free-response receiver operating characteristic curve
* Semantic segmentation - SemS:
* per class measures of overlap:
* dsc - dice similarity coefficient
* fbeta - FBeta score
* cldice - centreline dice
* iou - intersection over union
* measures of boundary quality:
* assd - average symmetric surface distance
* masd - mean average surface distance
* hd - hausdorff distance
* hd_perc - percentile of hausdorff distance
* nsd - normalised surface dice
* boundary_iou - boundary intersection over union
* per class counting :
* numb_ref - number of reference elements
* numb_pred - number of predicted elements
* numb_tp - number of true positives
* numb_fp - number of false positives
* numb_fn - number of false negatives
* Instance segmentation - InS:
* combined measures of detection and segmentation
* pq - panoptic quality
* per class counting measures:
* fbeta - FBeta score
* numb_ref - number of reference instances
* numb_pred - number of prediction instances
* numb_tp - number of true positives
* numb_fp - number of false positives
* numb_fn - number of false negatives
* multi-threshold measures:
* sens@spec - sensitivity at specificity
* spec@sens - specificity at sensitivity
* sens@ppv - sensitivity at positive predictive value
* ppv@sens - positive predictive value at sensitivity
* fppi@sens - false positive per image at sensitivity
* sens@fppi - sensitivity at false positive per image
* ap - average precision
* froc - free-response receiver operating characteristic curve
* measures of overlap:
* dsc - dice similarity coefficient
* fbeta - fbeta score
* cldice - centreline dice similarity coefficient
* iou - intersection over union
* measures of boundary quality:
* hd - hausdorff distance
* boundary_iou - boundary intersection over union
* masd - mean average surface distance
* assd - average symmetric surface distance
* nsd - normalised surface dice
* hd_perc - percentile of hausdorff distance
"""
from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures
from MetricsReloaded.processes.mixed_measures_processes import *
import warnings
from MetricsReloaded.utility.utils import combine_df, merge_list_df
import pandas as pd
import numpy as np
# Public API of this module: only the evaluation driver class is exported.
__all__ = ["ProcessEvaluation"]
# Measures accepted for each task category. ProcessEvaluation.check_valid_measures_cat
# rejects (with a warning) any requested measure that is not listed under the chosen category.
dict_valid = {
    # Image level classification
    'ImLC': [
        'auroc', 'ap', 'sens@spec', 'spec@sens',
        'ppv@sens', 'fbeta', 'accuracy', 'ba',
        'ec', 'nb', 'mcc',
        'wck', 'lr+', 'bs', 'cwece',
        'nll', 'rbs', 'ece_kde', 'kce', 'ece', 'numb_ref', 'ls',
        'numb_pred',
        'numb_tp',
        'numb_fp',
        'numb_fn',
        # Listed in the module documentation as an ImLC per class counting
        # measure but previously missing here, so it was always rejected.
        'cohens_kappa',
    ],
    # Object detection
    'ObD': [
        'fbeta', 'sens@spec', 'spec@sens', 'sens@ppv', 'ppv@sens',
        'sens@fppi', 'fppi@sens', 'sensitivity', 'ap', 'froc',
        'numb_ref',
        'numb_pred',
        'numb_tp',
        'numb_fp',
        'numb_fn',
    ],
    # Semantic segmentation
    'SemS': [
        'dsc', 'fbeta', 'cldice', 'iou', 'assd', 'masd', 'hd', 'hd_perc',
        'nsd', 'boundary_iou',
        'numb_ref',
        'numb_pred',
        'numb_tp',
        'numb_fp',
        'numb_fn',
    ],
    # Instance segmentation
    'InS': [
        'pq', 'fbeta', 'sens@spec', 'spec@sens', 'sens@ppv', 'ppv@sens',
        'fppi@sens', 'sens@fppi', 'ap', 'froc', 'dsc', 'cldice', 'iou',
        'hd', 'boundary_iou',
        'masd', 'assd', 'nsd', 'hd_perc',
        'numb_ref',
        'numb_pred',
        'numb_tp',
        'numb_fp',
        'numb_fn',
    ],
}
# Finite stand-in used as the substituted value for the measures mapped to MAX below.
MAX = 1000

# Value substituted for a measure when a case is missing
# (read by ProcessEvaluation.complete_missing_cases).
WORSE = {
    # multi-threshold measures
    "ap": 0,
    "auroc": 0,
    "froc": 0,
    "sens@spec": 0,
    "sens@ppv": 0,
    "spec@sens": 0,
    "fppi@sens": MAX,
    "ppv@sens": 0,
    "sens@fppi": 0,
    # counting measures
    "fbeta": 0,
    "ec": 1,
    "accuracy": 0,
    "ba": 0,
    "lr+": 0,
    "youden_ind": -1,
    # NOTE(review): mcc is mapped to 0 while wck / cohens_kappa / youden_ind use -1;
    # confirm whether 0 is the intended value for mcc.
    "mcc": 0,
    "wck": -1,
    "cohens_kappa": -1,
    # overlap measures
    "iou": 0,
    "dsc": 0,
    "cldice": 0,
    # boundary measures
    "masd": MAX,
    "assd": MAX,
    "hd_perc": MAX,
    "hd": MAX,
    "boundary_iou": 0,
    "nsd": 0,
}
class ProcessEvaluation(object):
    """
    Performs the evaluation of the data according to all the measures, categories and choices of processing

    :param data: dictionary containing all the data to be used for the comparison; keys read by this class are
        "pred_loc", "ref_loc", "pred_prob", "ref_class", "pred_class", "list_values", "file" (InS only) and,
        optionally, "names", "ref_missing" and "weights_labels"
    :param category: task to be considered choice among ImLC, ObD, SemS, InS
    :param measures_pcc: list of per class counting measures (these need to be adequate for the chosen task category)
    :param measures_mcc: list of multi class counting measures
    :param measures_boundary: list of measures to assess boundary quality
    :param measures_overlap: list of measures to assess overlap quality
    :param measures_mt: list of multi-threshold measures
    :param measures_detseg: list of measures assessing jointly detection and segmentation performance
    :param measures_cal: list of calibration measures (only available for image level classification class)
    :param localization: choice for localization strategy (used in Instance segmentation and Object detection tasks)
    :param assignment: choice for the assignment strategy (used in Instance segmentation and Object detection tasks)
    :param pixdim: pixel dimensions as list
    :param flag_map: indication whether nifti images indicating true positive elements for the reference, the prediction and errors should be created (done only for instance segmentation)
    :param file: name of files (accepted for compatibility; not stored by the constructor, the InS processing reads ``data["file"]`` instead)
    :param thresh_ass: threshold chosen for the assignment (default 0.5)
    :param case: indication of the handling of cases separately (True) or jointly (False)
    :param flag_fp_in: indicates that false positive should be accounted for
    :param ignore_missing: if True, the cases listed in ``data["ref_missing"]`` are left out of the assessment (a warning is issued);
        if False, one row of worst-case values is added per missing case
    """

    def __init__(
        self,
        data,
        category,
        measures_pcc=[],
        measures_mcc=[],
        measures_boundary=[],
        measures_overlap=[],
        measures_mt=[],
        measures_detseg=[],
        measures_cal=[],
        localization='mask_iou',
        assignment='greedy_matching',
        pixdim=[],
        flag_map=False,
        file=[],
        thresh_ass=0.5,
        case=True,
        flag_fp_in=True,
        ignore_missing=False,
    ):
        self.data = data
        self.category = category
        self.assignment = assignment
        self.localization = localization
        # Copy the measure lists so that the (shared) mutable defaults of the
        # signature are never aliased by an instance.
        self.measures_overlap = self._copy_measures(measures_overlap)
        self.measures_boundary = self._copy_measures(measures_boundary)
        self.measures_mt = self._copy_measures(measures_mt)
        self.measures_mcc = self._copy_measures(measures_mcc)
        self.measures_pcc = self._copy_measures(measures_pcc)
        self.measures_detseg = self._copy_measures(measures_detseg)
        self.measures_cal = self._copy_measures(measures_cal)
        self.flag_map = flag_map
        self.thresh_ass = thresh_ass
        self.case = case
        self.flag_fp_in = flag_fp_in
        self.flag_ignore_missing = ignore_missing
        self.pixdim = pixdim

        # Result holders exist even when the configuration is rejected, so that
        # reading them after an invalid set up yields None instead of AttributeError.
        self.resdet = None
        self.resseg = None
        self.resmt = None
        self.resmcc = None
        self.rescal = None
        self.grouped_lab = None
        self.stats_lab = None
        self.stats_all = None

        self.flag_valid = self.check_valid_measures_cat()
        if not self.flag_valid:
            # Warnings have already been issued; nothing is processed.
            return

        self.process_data()
        if 'ref_missing' in self.data:
            self.complete_missing_cases()
        if 'weights_labels' in self.data:
            self.weights_labels = self.data['weights_labels']
        else:
            # Default: every label gets the same weight.
            self.weights_labels = {v: 1 for v in self.data['list_values']}
        self.grouped_lab = self.label_aggregation()
        if self.case:
            self.get_stats_res()

    @staticmethod
    def _copy_measures(measures):
        """Return an independent list copy of *measures* (``None`` gives an empty list)."""
        return [] if measures is None else list(measures)

    @staticmethod
    def _worst_value(measure):
        """
        Return the value substituted for *measure* when a case is missing.

        Measures without an entry in ``WORSE`` get NaN (and a warning) instead of
        raising ``KeyError``.
        """
        if measure in WORSE:
            return WORSE[measure]
        warnings.warn('No worst value defined for %s, missing cases are filled with NaN' % measure)
        return np.nan

    def check_valid_measures_cat(self):
        """
        Function checking whether the category and the combination of measures suggested are suitable for continuing the process

        A warning is issued for an unknown category and for every measure that is not
        listed in ``dict_valid`` for the chosen category.

        :return: flag_valid - True if the category is known and all requested measures are valid for it, False otherwise
        """
        if self.category not in ['ImLC', 'SemS', 'InS', 'ObD']:
            warnings.warn('No appropriate category chosen, please choose between ImLC, SemS, InS and ObD')
            return False
        all_measures = (
            self.measures_boundary
            + self.measures_cal
            + self.measures_detseg
            + self.measures_mcc
            + self.measures_mt
            + self.measures_overlap
            + self.measures_pcc
        )
        flag_valid = True
        valid_for_category = dict_valid[self.category]
        for k in all_measures:
            # Keep looping after a failure so that every unsuitable measure is reported.
            if k not in valid_for_category:
                warnings.warn('%s is not a suitable metric for %s' % (k, self.category))
                flag_valid = False
        return flag_valid

    def process_data(self):
        """
        Performs the processing of the data according to the details given in the setting up of the process

        Contributes to the attribution of one dataframe per type of measures :

        * resdet - detection results
        * resseg - segmentation results
        * resmt - multi-threshold results
        * resmcc - multi class counting results
        * rescal - calibration results

        All these dataframes are initialised as None and replaced according to the chosen task.
        The tasks set the following outputs (the others stay None):

        * ImLC:
            * resdet
            * resmt
            * resmcc
            * rescal
        * SemS:
            * resseg
            * resmt
            * resmcc
        * ObD:
            * resdet
            * resmt
        * InS:
            * resseg
            * resdet
            * resmt
        """
        data = self.data
        df_resdet = None
        df_resseg = None
        df_resmt = None
        df_resmcc = None
        df_rescal = None
        if self.category == "InS":
            MLLS = MultiLabelLocSegPairwiseMeasure(
                pred_loc=data["pred_loc"],
                ref_loc=data["ref_loc"],
                pred_prob=data["pred_prob"],
                ref_class=data["ref_class"],
                pred_class=data["pred_class"],
                file=data["file"],
                flag_map=self.flag_map,
                assignment=self.assignment,
                localization=self.localization,
                measures_mt=self.measures_mt,
                measures_pcc=self.measures_pcc,
                measures_overlap=self.measures_overlap,
                measures_boundary=self.measures_boundary,
                measures_detseg=self.measures_detseg,
                thresh=self.thresh_ass,
                list_values=data["list_values"],
                per_case=self.case,
                flag_fp_in=self.flag_fp_in,
                pixdim=self.pixdim,
            )
            df_resseg, df_resdet, df_resmt = MLLS.per_label_dict()
        elif self.category == "ObD":
            MLDT = MultiLabelLocMeasures(
                pred_loc=data["pred_loc"],
                ref_loc=data["ref_loc"],
                pred_prob=data["pred_prob"],
                ref_class=data["ref_class"],
                pred_class=data["pred_class"],
                list_values=data["list_values"],
                localization=self.localization,
                assignment=self.assignment,
                thresh=self.thresh_ass,
                measures_pcc=self.measures_pcc,
                measures_mt=self.measures_mt,
                per_case=self.case,
                flag_fp_in=self.flag_fp_in,
                pixdim=self.pixdim,
            )
            df_resdet, df_resmt = MLDT.per_label_dict()
        elif self.category in ["ImLC", "SemS"]:
            # "names" is optional in the data dictionary.
            list_names = data['names'] if 'names' in data else []
            MLPM = MultiLabelPairwiseMeasures(
                data["pred_class"],
                data["ref_class"],
                data["pred_prob"],
                measures_pcc=self.measures_pcc,
                measures_overlap=self.measures_overlap,
                measures_boundary=self.measures_boundary,
                measures_mcc=self.measures_mcc,
                measures_mt=self.measures_mt,
                measures_calibration=self.measures_cal,
                list_values=data["list_values"],
                names=list_names,
                per_case=self.case,
                pixdim=self.pixdim,
            )
            df_bin, df_mt = MLPM.per_label_dict()
            df_mcc, df_cal = MLPM.multi_label_res()
            df_resmt = df_mt
            df_resmcc = df_mcc
            if self.category == "ImLC":
                # Per-label results are stored as detection results; calibration is ImLC only.
                df_resdet = df_bin
                df_rescal = df_cal
            else:
                # SemS: per-label results are stored as segmentation results.
                df_resseg = df_bin
        self.resdet = df_resdet
        self.resseg = df_resseg
        self.resmt = df_resmt
        self.resmcc = df_resmcc
        self.rescal = df_rescal
        return

    def complete_missing_cases(self):
        """
        Appends worst-case rows to the result dataframes for the cases listed in ``data["ref_missing"]``

        Nothing is done if there is no missing case, or (with a warning) if missing cases are to be ignored.
        Otherwise, when cases are handled separately, one row per missing case (resmcc) and one row
        per missing case and label (resdet, resseg, resmt) is created with the value given by ``WORSE``
        for each requested measure, and combined with the existing results through ``combine_df``.
        Missing cases are numbered after the existing ones, starting at ``len(data["ref_class"])``.
        """
        missing = self.data['ref_missing']
        if len(missing) == 0:
            return
        if self.flag_ignore_missing:
            warnings.warn("The set up currently ignores any missing case / dataset")
            return

        list_missing_det = []
        list_missing_seg = []
        list_missing_mt = []
        list_missing_mcc = []
        numb_valid = len(self.data['ref_class'])
        worst = self._worst_value
        # Measures whose worst value is written in the segmentation rows.
        # NOTE(review): detseg measures are written in the segmentation rows, but their
        # presence gates the creation of the detection rows below - confirm which one is intended.
        seg_measures = self.measures_boundary + self.measures_overlap + self.measures_detseg
        has_seg = len(self.measures_boundary) + len(self.measures_overlap) > 0
        has_det = len(self.measures_pcc) + len(self.measures_detseg) > 0
        has_mt = len(self.measures_mt) > 0

        if self.case:
            for i in range(len(missing)):
                case_id = i + numb_valid
                dict_mcc = {'case': case_id}
                for m in self.measures_mcc:
                    dict_mcc[m] = worst(m)
                list_missing_mcc.append(dict_mcc)
                for label in self.data['list_values']:
                    if has_seg:
                        dict_seg = {m: worst(m) for m in seg_measures}
                        dict_seg['case'] = case_id
                        dict_seg["label"] = label
                        list_missing_seg.append(dict_seg)
                    if has_det:
                        dict_det = {m: worst(m) for m in self.measures_pcc}
                        dict_det['case'] = case_id
                        dict_det["label"] = label
                        list_missing_det.append(dict_det)
                    if has_mt:
                        dict_mt = {m: worst(m) for m in self.measures_mt}
                        dict_mt['case'] = case_id
                        dict_mt["label"] = label
                        list_missing_mt.append(dict_mt)

        # NOTE(review): the nesting of this final step relative to `if self.case` could not be
        # recovered from the source; it is run in both situations (with empty frames when the
        # cases are handled jointly).
        df_miss_det = pd.DataFrame(list_missing_det)
        df_miss_seg = pd.DataFrame(list_missing_seg)
        df_miss_mcc = pd.DataFrame(list_missing_mcc)
        df_miss_mt = pd.DataFrame(list_missing_mt)
        self.resdet = combine_df(self.resdet, df_miss_det)
        self.resseg = combine_df(self.resseg, df_miss_seg)
        self.resmt = combine_df(self.resmt, df_miss_mt)
        self.resmcc = combine_df(self.resmcc, df_miss_mcc)

    def label_aggregation(self, option='average', dict_args={}):
        """
        Aggregates the per label results into one row per case

        With a single label, all result dataframes are simply merged. Otherwise, each measure is
        averaged over the labels of a case in three ways (NaN values being masked out):

        * ``<measure>_prevalence`` - weighted by the number of occurrences of the label in ``data["ref_class"]`` for that case
        * ``<measure>_weights`` - weighted by ``self.weights_labels``
        * ``<measure>_average`` - unweighted

        and the outcome is merged on "case" with the multi class counting and calibration results.

        :param option: currently unused
        :param dict_args: currently unused
        :return: dataframe of results aggregated across labels
        """
        if len(self.data['list_values']) == 1:
            return merge_list_df([self.resdet, self.resseg, self.resmt, self.resmcc, self.rescal])

        df_all_labels = merge_list_df([self.resdet, self.resseg, self.resmt], on=['label', 'case'])
        df_all_labels['weights_labels'] = 1
        df_all_labels['prevalence_labels'] = 1
        for k in self.weights_labels.keys():
            df_all_labels['weights_labels'] = np.where(
                df_all_labels['label'] == k,
                self.weights_labels[k],
                df_all_labels['weights_labels'],
            )
        # Prevalence of a label in a case = number of times it appears in that case's reference.
        for (c, rc) in enumerate(self.data['ref_class']):
            values, counts = np.unique(rc, return_counts=True)
            for (v, co) in zip(values, counts):
                df_all_labels['prevalence_labels'] = np.where(
                    np.logical_and(df_all_labels['case'] == c, df_all_labels['label'] == v),
                    co,
                    df_all_labels['prevalence_labels'],
                )

        def _masked(x):
            # NaN entries are excluded from the averages.
            return np.ma.masked_array(x, np.isnan(x))

        def wm(x):
            return np.ma.average(_masked(x), weights=df_all_labels.loc[x.index, "prevalence_labels"])

        def wm2(x):
            return np.ma.average(_masked(x), weights=df_all_labels.loc[x.index, "weights_labels"])

        def wm3(x):
            return np.ma.average(_masked(x))

        list_measures = (
            self.measures_boundary
            + self.measures_overlap
            + self.measures_detseg
            + self.measures_pcc
            + self.measures_mt
        )
        dict_measures = {k: [('prevalence', wm), ('weights', wm2), ('average', wm3)] for k in list_measures}
        df_grouped_lab = df_all_labels.groupby('case', as_index=False).agg(dict_measures).reset_index()
        # Flatten the (measure, aggregation) column pairs into "<measure>_<aggregation>".
        df_grouped_lab.columns = ['_'.join(col).rstrip('_') for col in df_grouped_lab.columns.values]
        # NOTE(review): the index is reset a second time here, which looks redundant; kept as-is
        # because the resulting column layout may be relied upon downstream.
        df_grouped_all = merge_list_df([df_grouped_lab.reset_index(), self.resmcc, self.rescal], on=['case'])
        return df_grouped_all

    def get_stats_res(self):
        """
        Computes descriptive statistics of the results

        Sets ``stats_all`` (``describe()`` of the label-aggregated results) and ``stats_lab``
        (``describe()`` of the per label results grouped by label).
        """
        df_stats_all = self.grouped_lab.describe()
        df_all_labels = merge_list_df([self.resdet, self.resseg, self.resmt], on=['label', 'case'])
        df_stats_lab = df_all_labels.groupby('label').describe()
        self.stats_lab = df_stats_lab
        self.stats_all = df_stats_all
        return