Aumento de gradiente
Originalmente desarrollado por Friedman et al. [36] Gradient Boosting es un método de conjunto de árboles que se basa en el entrenamiento de una secuencia de alumnos débiles (generalmente árboles de regresión), cada uno ajustado en los residuos del modelo anterior. El modelo final se obtiene simplemente combinando todas las predicciones de cada clasificador individual. Dado que este procedimiento es propenso al sobreajuste, todos los marcos de Gradient Boosting ofrecen una variedad de opciones de regularización, como tasas de aprendizaje para modular la influencia de un alumno individual en la predicción final, muestreo de muestras y variables de entrenamiento, regularización L1 y otras opciones. [31, 32]
Una diferencia clave entre Gradient Boosting y Random Forest está en la forma en que se optimizan los árboles individuales. Un clasificador Gradient Boosting utiliza árboles de regresión, donde las divisiones individuales se optimizan de acuerdo con el gradiente y el hessiano de alguna función de pérdida (es decir, entropía cruzada), y convierte la suma de las predicciones en una probabilidad mediante la aplicación de la función sigmoidea. [31]. En cambio, Random Forest usa árboles de decisión, donde las divisiones individuales se optimizan utilizando criterios como la impureza de Gini o la entropía de Shannon. [37]. Esta distinción permite la implementación de funciones de pérdida personalizadas de manera sencilla en cualquier marco de Gradient Boosting. [38]
Hay varios paquetes de python disponibles para entrenar modelos de Gradient Boosting, siendo el más popular XGBoost [31]CatBoost [39] y LightGBM [32]. En este estudio, desarrollamos todos los modelos utilizando la versión Python de LightGBM 3.3.2.
Funciones de pérdida
La función de pérdida predeterminada para muchos clasificadores basados en gradientes, incluido LightGBM, cuando se trata de una clasificación desequilibrada es la entropía cruzada ponderada (WCE) [22, 23], que mide qué tan cerca están las probabilidades de clase predichas por el clasificador de coincidir con las etiquetas de clase verdaderas. Se define de la siguiente manera:
$$begin_=-sum_^__mathrm( widehat)+left(1-_right)mathrmleft(1-widehat derecha)end$$
(1)
dónde (metro) es el número total de muestras, (_) son las etiquetas de destino, (widehat) son las predicciones, () es un parámetro ajustable para tener en cuenta el desequilibrio de clases. Al manejar conjuntos de datos desequilibrados, los clasificadores tienden a ignorar el primer término, que corresponde a los errores de la clase minoritaria, y solo se enfocan en minimizar el segundo término, que corresponde a los errores de la clase mayoritaria, lo que lleva a un modelo subóptimo. [22, 23]. Esto se puede abordar estableciendo () igual a la relación de compuestos inactivos frente a compuestos activos.
Pérdida focal
La pérdida focal modifica la formulación de entropía cruzada binaria al reducir la influencia de muestras bien clasificadas en la pérdida general [24, 38]. La formulación es la siguiente:
$$begin_=-sum_^_^mathrmleft(widehatright)+left(1-_right )^mathrmleft(1-widehatright)end$$
(2)
dónde (gama) es un parámetro ajustable que afecta la forma de la función de pérdida. Para valores altos de (gama), la contribución de las muestras bien clasificadas a la pérdida general se aproxima a 0, lo que permite que el gradiente se centre más en la clase minoritaria. Si (gama) se establece en 0, la pérdida focal coincide con la pérdida de entropía cruzada estándar.
Pérdida ajustada por logit
En lugar de modular la influencia de la muestra durante el proceso de entrenamiento, como la entropía cruzada ponderada o la pérdida focal, la pérdida ajustada por logit escala los logit sin procesar del clasificador de acuerdo con las probabilidades a priori de las clases. [27]como se muestra en la Fórmula 3
$$begin_=-sum_^_mathrmleft(sigma left( _+tau *_right)right)+left(1-_right)mathrmleft(1 -sigma left(_+tau *_right)right)end$$
dónde (sigma) es la función sigmoidea, () es la predicción logit sin procesar, (_) y (_) son las probabilidades previas para las clases mayoritaria y minoritaria y (tau) es un factor de suavizado que modula la influencia de los ajustes logit en el proceso de aprendizaje. Una diferencia clave de la pérdida ajustada por Logit en comparación con otros enfoques es que garantiza la consistencia de Fisher para el estimador por diseño, a través de una solución óptima de Bayes para el error balanceado. [27]
Pérdida de margen consciente de la distribución de etiquetas
De manera similar a la pérdida ajustada por logit, la pérdida de LDAM aplica una compensación a los logit sin procesar del modelo, pero las compensaciones óptimas se obtienen minimizando un límite de generalización basado en el margen. [25]. Una limitación clave de los enfoques basados en el margen, como las máquinas de vectores de soporte, es que dependen de la pérdida de bisagra. [40]que es problemático de optimizar para métodos basados en gradientes debido a su falta de suavidad [25]. Para abordar este problema, Cao et al. optó por usar una formulación inspirada en la entropía cruzada, como se muestra en la Fórmula 4:
$$begin_=-sum_^_mathrmleft(sigma left( _+frac{raíz cuadrada[4]right)right)+left(1-_right)mathrmleft(1-sigma left( _ +fracright)right)end$$
(4)
Donde C es un hiperparámetro a ajustar y () y () son el número de muestras en la clase minoritaria y mayoritaria respectivamente.
Pérdida de ecualización
Otra forma de tener en cuenta el desequilibrio de clase es operar a nivel de gradiente, por ejemplo, al aumentar la ponderación de los gradientes de las muestras minoritarias y al disminuir la ponderación de las muestras mayoritarias de acuerdo con la relación de gradiente entre las clases. Este enfoque tiene la ventaja teórica de ponderar la clase minoritaria no solo de acuerdo con el desequilibrio de clases, sino también de acuerdo con la dificultad intrínseca del problema de clasificación, lo que podría generar mejores ponderaciones en comparación con las estadísticas simples de conteo de clases. [26]. Otra ventaja es que este enfoque es independiente de la función, en el sentido de que puede implementarse para ajustar cualquier función de pérdida preexistente, es decir, la entropía cruzada.
Para obtener los coeficientes de ponderación de los gradientes de las clases mayoritaria y minoritaria, Equalization loss emplea la siguiente fórmula:
$$begin^=1+alpha left(1-fleft(^ right)right)end$$
(5)
$$begin^=fleft(^right)end$ ps
(6)
dónde (^) es la relación de gradientes acumulados entre las clases mayoritaria y minoritaria en la iteración (t), (alfa) es un hiperparámetro que permite aumentar el peso para la clase minoritaria y (F) es una función de mapeo:
$$beginfleft(xright)=fracend$$
(7)
Con hiperparámetros (gama) y (mu).
Para implementar este enfoque, dado que Gradient Boosting no se entrena con mini lotes, consideramos la adición de un árbol individual como una iteración, recortamos los gradientes para lograr estabilidad numérica y usamos la entropía cruzada binaria como la función de pérdida subyacente.
conjuntos de datos
Para evaluar nuestro enfoque propuesto, recopilamos seis conjuntos de datos de fuentes patentadas y disponibles públicamente. De MoleculeNet [15] seleccionamos Tox21, HIV y MUV, de MolData [20] elegimos Phosphatase y NTPase y finalmente agregamos un conjunto de datos de detección de alto rendimiento (HTS) de Merck KGaA, lo que resultó en aproximadamente 2 millones de compuestos y 42 tareas. Esta selección cubre un amplio rango de desequilibrio y tamaño del conjunto de datos, para garantizar que nuestros hallazgos no estén sesgados por condiciones específicas del conjunto de datos.
Para acceder a los datos disponibles públicamente, descargamos los conjuntos de datos limpios de MoleculeNet de Jiang et al. [13] y los de MolData de Arshadi y colaboradores. [20]
Los conjuntos de datos se resumen en la Tabla 1 e informan el número promedio de compuestos y las proporciones de desequilibrio entre las tareas. Los valores individuales correspondientes a cada punto final se pueden encontrar en Archivo adicional 1: Tabla S1. Dado que el punto de referencia HTS es un conjunto de datos patentado de Merck KGaA, la cantidad exacta de compuestos es confidencial.
Métrica
Un paso crítico en el desarrollo de clasificadores para la clasificación desequilibrada es la elección de la métrica para medir el rendimiento. [41, 42]. Por ejemplo, evaluar los modelos de aprendizaje automático de acuerdo con la precisión cuando se trata de datos desequilibrados puede llevar a conclusiones engañosas, ya que no tiene en cuenta adecuadamente el rendimiento de la clase minoritaria. [5, 41, 42]. Para permitir comparaciones con los resultados informados anteriormente en la literatura para estos puntos de referencia, optamos por evaluar todos los conjuntos de datos utilizando todas las métricas utilizadas por Arshadi et al. [20] y Jiang y compañeros de trabajo [13], con la adición de precisión equilibrada, puntuación F1 y el coeficiente de correlación de Matthews (MCC). Por lo tanto, para cada receptor de referencia se midieron el área característica operativa bajo la curva (ROC-AUC), el área bajo la curva de recuperación de precisión (PR-AUC), la precisión, la precisión equilibrada, la recuperación, la precisión, la puntuación F1 y el MCC. Se puede encontrar una discusión más detallada sobre la elección de métricas y su definición en: Secc. 1 de la. Dada la información del archivo adicional número 1 de clasificadores y métricas involucradas en nuestro estudio, para mayor concisión mostramos en el texto principal solo las métricas reportadas por los autores de los respectivos puntos de referencia. Las tablas de rendimiento con todas las métricas empleadas en este estudio se pueden encontrar en: Secc. 3, 4 y 5 de la información del archivo adicional 1
Procedimiento de evaluación comparativa
Después de descargar los conjuntos de datos de los repositorios respectivos, todos los compuestos se desinfectaron con RDKIT (versión 2022.03.01) como se describe en los documentos originales y se caracterizaron con huellas dactilares de conectividad extendida (ECFP) con tamaño de bit 1024 y radio 2.
Para desarrollar los modelos, seguimos dos procedimientos de evaluación comparativa diferentes según la fuente del conjunto de datos. De esta forma, los resultados obtenidos en este estudio son directamente comparables con el desempeño de otros clasificadores reportados en los respectivos trabajos. Esto nos permite poner en perspectiva las mejoras que proporciona nuestro enfoque sobre la implementación predeterminada de LightGBM en un estudio de comparación de clasificadores más convencional.
Para Tox21, HIV y MUV, optimizamos cada clasificador en validación cruzada mediante divisiones aleatorias, con una proporción de 80:10:10 para el conjunto de entrenamiento, validación y prueba. Cada modelo utilizó la detención anticipada ante la pérdida del conjunto de validación, mientras que el conjunto de prueba se utilizó para evaluar el rendimiento del modelo. Para optimizar los modelos usamos Hyperopt (versión 0.2.7) [43] para 20 iteraciones. Una vez finalizada la optimización, ejecutamos el modelo con hiperparámetros óptimos en 50 divisiones aleatorias, con una relación de 80:10:10 para el conjunto de entrenamiento, validación y prueba. Al igual que en la fase de optimización, utilizamos el conjunto de validación para la detención anticipada y el conjunto de prueba para la evaluación del rendimiento. Con respecto a la elección de las métricas, al comparar nuestro enfoque con los resultados de la literatura, seguimos las pautas de Wu et al. [15]: Tox21 y HIV se evaluaron según ROC-AUC, mientras que MUV con PR-AUC.
Para los conjuntos de datos de fosfatasa y NTPasa, empleamos las divisiones de andamios proporcionadas por Arshadi et al. [20] Para cada tarea, optimizamos cada modelo en el conjunto de validación e informamos el rendimiento en el conjunto de prueba. En todos los casos, utilizamos la detención anticipada en el conjunto de validación para determinar el número óptimo de árboles. Todos los clasificadores fueron optimizados usando Hyperopt [43] durante 20 iteraciones y luego se evaluó 5 veces usando diferentes semillas aleatorias. Para las comparaciones con otros algoritmos de aprendizaje automático, informamos las métricas empleadas por Arshadi et al. (exactitud, ROC-AUC, precisión, recuperación) con la suma de la puntuación F1, para estimar el compromiso entre alta precisión y alta recuperación.
Para el conjunto de datos Merck KGaA HTS, empleamos el procedimiento de evaluación para los puntos de referencia de MolData. Creamos conjuntos de entrenamiento, validación y prueba utilizando la división de andamios con una proporción de 80:10:10. Luego, optimizamos todos los clasificadores con Hyperopt para 20 iteraciones en el conjunto de validación utilizando la detención anticipada. Finalmente, volvimos a entrenar cada modelo con parámetros óptimos 5 veces y medimos todas las métricas en el conjunto de prueba.
Para evaluar la eficacia de las funciones de pérdida personalizada, utilizamos como referencia en todos nuestros puntos de referencia el rendimiento de la entropía cruzada ponderada y evaluamos si la mejora es significativa con Welch de 1 cola. t-pruebas con corrección de Bonferroni. Además, para contextualizar el rendimiento de LightGBM con funciones de pérdida personalizadas, comparamos el modelo de mejor rendimiento de nuestro estudio con los modelos informados por Jiang et al. para MoleculeNet y por Arshadi et al. para MolData. Todos los modelos de estos documentos emplearon entropía cruzada ponderada o esquemas de equilibrio de clases para modelar el desequilibrio de la actividad, según el algoritmo de clasificación subyacente.
En el primer estudio, se investigaron cuatro métodos de aprendizaje automático basados en descriptores y cuatro redes neuronales basadas en gráficos. Los modelos basados en descriptores fueron Random Forest (RF), Support Vector Machine (SVM), XGBoost (XGB) y una red neuronal con capas densas (DNN), utilizando una combinación de descriptores 1D y 2D, así como dos conjuntos de huellas dactilares. [13]. Para los modelos basados en gráficos, consideraron una red convolucional de gráficos (GCN), una red de atención de gráficos (GAT), una red neuronal de paso de mensajes (MPNN) y huellas dactilares atentas (AFP) [13]. Para ser concisos, para cada conjunto de datos de MoleculeNet informamos el rendimiento del mejor modelo basado en descriptores y el modelo basado en gráficos y los comparamos con el modelo LightGBM de mejor rendimiento utilizando Welch de 2 colas. t-pruebas con corrección de Bonferroni.
En el segundo estudio, los autores desarrollaron un DNN multitarea en huellas dactilares ECFP con tamaño de bit 1024 y radio 2 y un GCN multitarea. Para estas líneas de base, omitimos las pruebas estadísticas ya que los autores no informaron las desviaciones estándar de sus resultados.
Los detalles de evaluación comparativa para todos los conjuntos de datos se resumen en la Tabla 2.