In statistical learning, multiclass classification is the problem of classifying samples into one of more than two possible categories. There are various real-world applications of multiclass classification. For example, we can conduct cancer diagnosis from gene microarrays (Zhu and Hastie, 2004) or distinguish car types from car images (Huttunen et al., 2016).
The penalized multinomial logistic regression has been studied by many authors for years, since there can be many noisy variables among the input variables. We can avoid unnecessary modeling biases by deleting the noisy input variables from the model, which often results in higher classification accuracy. For example, Krishnapuram et al. (2005) studied fast algorithms for the sparse multinomial logistic regression with the LASSO penalty.
In general, convex penalties such as the LASSO and elastic net are known to select more input variables than necessary unless a certain condition on the design matrix (Zhao and Yu, 2006) is satisfied. On the other hand, non-convex penalties have been proven to have the oracle property for a wide range of statistical models, including the generalized linear models (Fan and Peng, 2004; Kwon and Kim, 2012) and random effect models (Bondell et al., 2010).
In this paper, we introduce an efficient algorithm for the non-convex penalized multinomial logistic regression that can be uniformly applied to a class of non-convex penalties. The class includes most non-convex penalties such as the smoothly clipped absolute deviation (SCAD) (Fan and Li, 2001), minimax concave (MC) (Zhang, 2010), and bridge (Huang et al., 2008) penalties.
The rest of the paper is organized as follows. Section 2 introduces the non-convex penalized multinomial logistic regression. Section 3 presents details of the algorithm. Numerical studies and concluding remarks follow in Sections 4 and 5.
Let (x_i, y_i), i = 1, …, n, be observed samples, where x_i ∈ ℝ^p is the input vector and y_i ∈ {1, …, K} is the class label among K possible categories. The multinomial logistic regression models the conditional probability of each category as

P(y_i = k | x_i) = exp(x_i^T β_k) / Σ_{j=1}^{K} exp(x_i^T β_j), k = 1, …, K,

where β_k ∈ ℝ^p is the regression coefficient vector of the kth category and β = (β_1, …, β_K) collects all the coefficients. The negative log-likelihood becomes

ℓ(β) = −Σ_{i=1}^{n} log P(y_i | x_i).

Let ℓ_λ denote the penalized negative log-likelihood

ℓ_λ(β) = ℓ(β) + Σ_{k=1}^{K} Σ_{j=1}^{p} p_λ(|β_kj|),   (2.3)

where p_λ is a penalty function indexed by the tuning parameter λ > 0.
We consider a class of non-convex penalties that satisfy:

(C1) p_λ(t) is non-decreasing on [0, ∞) with p_λ(0) = 0;

(C2) D_λ(t) = p_λ(|t|) − λ|t| is concave and continuously differentiable on ℝ.
There are a number of non-convex penalties that satisfy (C1) and (C2). Examples include flat-tailed non-convex penalties such as the SCAD penalty (Fan and Lv, 2011),

p_λ(t) = λt for 0 ≤ t < λ,
p_λ(t) = {2aλt − t² − λ²}/{2(a − 1)} for λ ≤ t < aλ,
p_λ(t) = (a + 1)λ²/2 for t ≥ aλ,

with a > 2, and the MC penalty (Zhang, 2010),

p_λ(t) = λt − t²/(2a) for 0 ≤ t < aλ,
p_λ(t) = aλ²/2 for t ≥ aλ,

with a > 1. Non-flat-tailed examples include the moderately clipped LASSO, modified bridge, and h-likelihood penalties.
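For reference, the two flat-tailed penalties above can be evaluated numerically as follows. This is a sketch with our own function names, using the standard SCAD and MC closed forms; both behave like λ|t| near the origin and become constant in the tail:

```python
import numpy as np

def scad(t, lam, a=3.7):
    """SCAD penalty (Fan and Li, 2001), evaluated elementwise; requires a > 2."""
    t = np.abs(np.asarray(t, dtype=float))
    return np.where(
        t <= lam,
        lam * t,                                            # linear near zero
        np.where(
            t <= a * lam,
            (2 * a * lam * t - t**2 - lam**2) / (2 * (a - 1)),  # quadratic blend
            (a + 1) * lam**2 / 2,                           # flat tail
        ),
    )

def mc(t, lam, a=3.0):
    """Minimax concave (MC) penalty (Zhang, 2010); requires a > 1."""
    t = np.abs(np.asarray(t, dtype=float))
    return np.where(t <= a * lam, lam * t - t**2 / (2 * a), a * lam**2 / 2)
```

The flat tails are what remove the estimation bias of the LASSO for large coefficients.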
Note that the h-likelihood penalty has a more complex form than the one defined in this paper. However, it is sufficient to understand the h-likelihood penalty as a weighted sum of the log and LASSO penalties, as described in Kwon et al.
In this section, we introduce an efficient algorithm for minimizing the penalized negative log-likelihood in (2.3). Since the objective function is non-convex, we first introduce the CCCP (Yuille and Rangarajan, 2002) and then apply the local quadratic approximation (LQA) (Lee et al.) to the resulting convex subproblems.
The CCCP is one of the powerful optimization algorithms for minimizing a non-convex function that can be decomposed as a sum of convex and concave functions. Assume that the objective function can be written as C(θ) = C_vex(θ) + C_cav(θ), where C_vex is convex and C_cav is concave and differentiable. Given a current solution θ̃, the CCCP minimizes the tight convex upper bound C_vex(θ) + ∇C_cav(θ̃)^T θ instead of C itself, where ∇C_cav(θ̃) denotes the gradient of C_cav evaluated at θ̃.
From (C2), we can see that

ℓ_λ(β) = ℓ(β) + λ Σ_{k,j} |β_kj| + Σ_{k,j} D_λ(β_kj),

where the first two terms are convex and the third term is concave and continuously differentiable. Hence the upper tight convex function, ignoring the constant, to be minimized becomes

Q(β | β̃) = ℓ(β) + λ Σ_{k,j} |β_kj| + Σ_{k,j} ∇D_λ(β̃_kj) β_kj,   (3.1)

given an initial solution β̃.
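The key property behind (3.1) is that linearizing the concave part yields a convex function that lies above the penalty everywhere and touches it at the current solution. Below is a small numerical check of this majorization for the SCAD penalty; the values of λ, a, and the anchor point are our own choices:

```python
import numpy as np

lam, a = 1.0, 3.7

def scad(t):
    """Scalar SCAD penalty with the standard closed form."""
    t = abs(t)
    if t <= lam:
        return lam * t
    if t <= a * lam:
        return (2 * a * lam * t - t * t - lam * lam) / (2 * (a - 1))
    return (a + 1) * lam**2 / 2

def D(t):
    """Concave part of the (C2) decomposition: D(t) = p(|t|) - lam*|t|."""
    return scad(t) - lam * abs(t)

def D_grad(t, h=1e-6):
    """Numerical derivative of D (central difference)."""
    return (D(t + h) - D(t - h)) / (2 * h)

def upper(t, t0):
    """CCCP convex upper bound of the penalty, built at the current point t0."""
    return lam * abs(t) + D(t0) + D_grad(t0) * (t - t0)

t0 = 2.0
for t in np.linspace(-5, 5, 101):
    assert upper(t, t0) >= scad(t) - 1e-8   # majorization holds everywhere
assert abs(upper(t0, t0) - scad(t0)) < 1e-6  # and the bound is tight at t0
```

Minimizing the upper bound therefore never increases the original objective, which is what makes the CCCP iteration a descent method.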
Note that the algorithm in (3.1) iteratively solves LASSO penalized convex objective functions, which is an important advantage of the CCCP.
The algorithm in (3.1) includes minimizing a LASSO penalized convex objective function. For this, we apply the LQA, replacing the negative log-likelihood with its second-order Taylor expansion at the current solution β̃:

ℓ(β) ≈ ℓ(β̃) + ∇ℓ(β̃)^T (β − β̃) + (1/2)(β − β̃)^T ∇²ℓ(β̃)(β − β̃),   (3.2)

where ∇ℓ(β̃) and ∇²ℓ(β̃) denote the gradient and Hessian of ℓ evaluated at β̃. Note that the objective function obtained by plugging (3.2) into (3.1) is the sum of a quadratic function and the LASSO penalty, which can be minimized efficiently by, for example, the coordinate descent (CD) algorithm.
The computational cost of the algorithm in (3.2) can be significant since we repeatedly calculate the Hessian ∇²ℓ(β̃) at every iteration. To avoid this, we replace the Hessian with a uniform upper bound that does not depend on β̃ (Böhning, 1992):

∇²ℓ(β) ⪯ H = (1/2){I_K − 1_K 1_K^T / K} ⊗ X^T X for all β,

where I_K is the K × K identity matrix, 1_K is the K-vector of ones, X is the n × p design matrix, ⊗ is the Kronecker product, and A ⪯ B means that B − A is non-negative definite. Plugging H into (3.2) in place of ∇²ℓ(β̃) yields the uniform bound quadratic approximation (UBQA)

ℓ(β) ≈ ℓ(β̃) + ∇ℓ(β̃)^T (β − β̃) + (1/2)(β − β̃)^T H (β − β̃).   (3.3)
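The per-observation Hessian block of the multinomial log-likelihood is diag(π) − ππ^T, where π is the vector of class probabilities, and the uniform bound replaces it with (1/2)(I − 11^T/K). Below is a quick numerical sanity check of this domination on random probability vectors; K = 4 and the random seed are arbitrary choices of ours:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4
# uniform bound for the per-observation Hessian block: 0.5 * (I - 11'/K)
Bnd = 0.5 * (np.eye(K) - np.ones((K, K)) / K)

ok = True
for _ in range(500):
    eta = rng.normal(scale=3.0, size=K)      # random linear predictors
    p = np.exp(eta - eta.max())
    p /= p.sum()                             # class probabilities
    W = np.diag(p) - np.outer(p, p)          # exact Hessian block
    # Bnd - W must be positive semidefinite for the bound to hold
    ok = ok and bool(np.linalg.eigvalsh(Bnd - W).min() > -1e-10)
```

Because the bound holds for every probability vector, the quadratic in (3.3) majorizes ℓ globally, and its Hessian never needs to be recomputed.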
The two core algorithms in (3.1) and (3.3) for minimizing ℓ_λ can be summarized as follows:

(CCCP) Set an initial solution β̂⁰ and repeatedly update β̂^{t+1} as the minimizer of the convex objective in (3.1) constructed at β̂^t, for t ≥ 0, until convergence.

(UBQA) Set an initial solution β̃⁰ and repeatedly minimize the objective in (3.1) with ℓ replaced by its quadratic approximation (3.3) at β̃^s to obtain β̃^{s+1}, for s ≥ 0, until convergence.
We finish the section by giving the solution of each coordinate-wise minimization in the CD algorithm as a function of the other coordinates. Since the objective in each UBQA step is a quadratic function plus the LASSO penalty, each coordinate is updated in closed form by soft-thresholding:

β̂_kj = S(z_kj, λ) / h_kj,

where S(z, λ) = sign(z)(|z| − λ)_+ is the soft-thresholding operator, h_kj is the diagonal element of H corresponding to β_kj, and z_kj collects the remaining linear terms of the quadratic objective evaluated at the current values of the other coordinates.
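A minimal coordinate descent solver for a quadratic-plus-LASSO objective, built from the soft-thresholding update, can be sketched as follows; the function names and the toy problem are our own:

```python
import numpy as np

def soft(z, lam):
    """Soft-thresholding operator S(z, lam) = sign(z) * max(|z| - lam, 0)."""
    return np.sign(z) * np.maximum(np.abs(z) - lam, 0.0)

def cd_lasso_quadratic(H, g, lam, n_iter=200):
    """Coordinate descent for min_b 0.5 b'Hb - g'b + lam * ||b||_1 with H
    positive definite. Each coordinate has the closed-form update
    b_j = S(g_j - sum_{k != j} H_jk b_k, lam) / H_jj."""
    p = len(g)
    b = np.zeros(p)
    for _ in range(n_iter):
        for j in range(p):
            r = g[j] - H[j] @ b + H[j, j] * b[j]   # partial residual
            b[j] = soft(r, lam) / H[j, j]
    return b

# sanity check with a diagonal H, where the solution is S(g_j, lam) / H_jj
H = np.diag([2.0, 2.0])
g = np.array([3.0, 0.5])
b = cd_lasso_quadratic(H, g, lam=1.0)
```

With a diagonal H the coordinates decouple, so the first coordinate shrinks to (3 − 1)/2 and the second is thresholded to exactly zero.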
Now, the CCCP-UBQA algorithm applied with the CD algorithm becomes as follows:

CCCP-UBQA-CD algorithm for minimizing ℓ_λ:

(CCCP) Set an initial solution and repeat the following until convergence.

(UBQA) Set an initial solution and repeat minimizing the quadratic upper bound in (3.3) plus the penalty terms in (3.1) until convergence.

(CD) Set an initial solution and cyclically update each coordinate by the soft-thresholding rule until convergence.

Note that immediate and reasonable initial solutions for the UBQA and CD steps are the solutions obtained in the previous iterations.
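Putting the three loops together, a self-contained sketch on synthetic data is given below. Everything concrete here — the data sizes, the values of λ and a, the iteration counts, the SCAD penalty, and the Kronecker-type uniform bound — is our illustrative choice under the assumptions stated in the text, not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, K = 300, 5, 3                       # illustrative sizes (our choice)
X = rng.normal(size=(n, p))
B_true = np.zeros((p, K))
B_true[0, 0], B_true[1, 1] = 2.0, -2.0    # two signal variables
P = np.exp(X @ B_true)
P /= P.sum(1, keepdims=True)
y = (P.cumsum(1) > rng.random((n, 1))).argmax(1)
Y = np.eye(K)[y]                          # one-hot labels

lam, a = 2.0, 3.7                         # tuning values (our choice)
XtX = X.T @ X
hd = 0.5 * (1 - 1 / K) * np.diag(XtX)     # diagonal of the uniform bound

def soft(z, l):
    return np.sign(z) * np.maximum(np.abs(z) - l, 0.0)

def scad(t):                              # SCAD penalty, elementwise
    u = np.abs(t)
    return np.where(u <= lam, lam * u,
           np.where(u <= a * lam,
                    (2 * a * lam * u - u**2 - lam**2) / (2 * (a - 1)),
                    (a + 1) * lam**2 / 2))

def dD(t):                                # gradient of D(t) = scad(t) - lam*|t|
    u, s = np.abs(t), np.sign(t)
    return s * np.where(u <= lam, 0.0,
               np.where(u <= a * lam, -(u - lam) / (a - 1), -lam))

def Hv(M):                                # uniform bound applied to a p x K matrix
    return 0.5 * XtX @ (M - M.mean(1, keepdims=True))

def penalized_nll(B):
    E = X @ B
    E = E - E.max(1, keepdims=True)
    Q = np.exp(E)
    Q /= Q.sum(1, keepdims=True)
    return -np.sum(Y * np.log(Q)) + scad(B).sum()

B = np.zeros((p, K))
obj0 = penalized_nll(B)
for _ in range(20):                       # CCCP loop
    G = dD(B)                             # linearize the concave part
    for _ in range(3):                    # UBQA loop
        E = X @ B
        E = E - E.max(1, keepdims=True)
        Pr = np.exp(E)
        Pr /= Pr.sum(1, keepdims=True)
        grad = X.T @ (Pr - Y)             # gradient of the neg. log-likelihood
        B0 = B.copy()
        for _ in range(10):               # CD sweeps on the quadratic bound
            for j in range(p):
                for k in range(K):
                    g = grad[j, k] + Hv(B - B0)[j, k] + G[j, k]
                    B[j, k] = soft(hd[j] * B[j, k] - g, lam) / hd[j]
obj1 = penalized_nll(B)
Pr = np.exp(X @ B)
Pr /= Pr.sum(1, keepdims=True)
train_acc = (Pr.argmax(1) == y).mean()
```

Since each inner step minimizes a function that majorizes ℓ_λ and touches it at the current solution, the penalized negative log-likelihood cannot increase across outer iterations.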
In this section, we present results from numerical studies via simulations and a data analysis, investigating the finite sample performance of the penalized multinomial logistic regression. We compare the SCAD, moderately clipped LASSO, and modified bridge penalties with the LASSO penalty, which are denoted by lasso, scad, classo, and mbridge in the tables. The non-convex penalized estimators are obtained by the CCCP-UBQA-CD algorithm.
We generate training samples of size n ∈ {200, 400, 800, 1600} from the multinomial logistic regression model with K ∈ {3, 5} classes and p ∈ {10, 100} input variables, where only a small number of signal variables have non-zero regression coefficients.
We repeat the simulation 100 times and present the averages of the measures in Tables 1 and 2. For comparison, we also consider the oracle estimator obtained by using the signal variables only, as well as the ordinary non-penalized estimator, which is available only when the sample size is large enough relative to the number of input variables.
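The selection measures reported in Table 1 can be reproduced from the true and estimated supports. We assume the usual definitions — sensitivity as the fraction of signal variables selected, specificity as the fraction of noise variables excluded, and accuracy as exact support recovery — since the paper's precise definitions are not shown in this excerpt:

```python
import numpy as np

def selection_measures(beta_true, beta_hat):
    """Support-recovery measures under our assumed definitions:
    sensitivity  = selected signal variables / true signal variables,
    specificity  = excluded noise variables / true noise variables,
    accuracy     = 1 if the estimated support matches exactly, else 0."""
    s_true = beta_true != 0
    s_hat = beta_hat != 0
    sens = (s_hat & s_true).sum() / s_true.sum()
    spec = (~s_hat & ~s_true).sum() / (~s_true).sum()
    acc = float((s_hat == s_true).all())
    return sens, spec, acc

b_true = np.array([1.5, -2.0, 0.0, 0.0])
b_hat = np.array([1.2, 0.0, 0.3, 0.0])
sens, spec, acc = selection_measures(b_true, b_hat)
```

In this toy case one of two signals is found and one of two noise variables is excluded, so both sensitivity and specificity are 0.5 while accuracy is 0.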
We apply the penalized multinomial regression to the 'zoo' sample that is available from the UCI machine learning repository. The sample includes 101 animals classified into 7 types, each described by 16 attributes such as feathers, milk, airborne, and fins.
The estimated regression coefficients are listed in Table 4, where the variables with zero coefficients for all methods are deleted. All the penalized estimators select the same variables with non-zero regression coefficients for each class, but the effect sizes differ. We calculate the leave-one-out errors and summarize the results in Table 5. The ordinary non-penalized estimator performs the best and the SCAD the worst. However, we note that the number of variables used for the classification is only 4, which can be an advantage of penalized estimation.
We introduced the CCCP-UBQA algorithm for the non-convex penalized multinomial logistic regression, which covers most non-convex penalties. The algorithm is easy to implement and computes the penalized estimators efficiently, as the computation times in Table 3 demonstrate.
This paper was written as part of Konkuk University's research support program for its faculty on sabbatical leave in 2018 and was supported by a Chungnam National University fund.
Simulation results for the selection
Sensitivity |  |  |  |  |  |  |  |  |
---|---|---|---|---|---|---|---|---|
K | p | n | oracle | ordinary | lasso | scad | classo | mbridge |
3 | 10 | 200 | 1 | 1 | 0.96 | 0.758 | 0.718 | 0.754 |
400 | 1 | 1 | 0.99 | 0.896 | 0.904 | 0.902 | ||
800 | 1 | 1 | 1 | 0.996 | 0.99 | 0.996 | ||
1600 | 1 | 1 | 1 | 1 | 1 | 1 | ||
100 | 200 | 1 | 0 | 0.552 | 0.522 | 0.566 | 0.568 | |
400 | 1 | 1 | 0.936 | 0.79 | 0.832 | 0.85 | ||
800 | 1 | 1 | 0.996 | 0.932 | 0.972 | 0.932 | ||
1600 | 1 | 1 | 1 | 0.992 | 1 | 1 | ||
5 | 10 | 200 | 1 | 1 | 0.828 | 0.613 | 0.583 | 0.617 |
400 | 1 | 1 | 0.962 | 0.76 | 0.711 | 0.788 | ||
800 | 1 | 1 | 0.998 | 0.951 | 0.926 | 0.943 | ||
1600 | 1 | 1 | 1 | 1 | 0.998 | 0.997 | ||
100 | 200 | 1 | 0 | 0.002 | 0.278 | 0.274 | 0.305 | |
400 | 1 | 0 | 0.063 | 0.452 | 0.479 | 0.499 | ||
800 | 1 | 1 | 0.878 | 0.723 | 0.764 | 0.766 | ||
1600 | 1 | 1 | 0.983 | 0.864 | 0.941 | 0.903 | ||
Specificity |  |  |  |  |  |  |  |  |
K | p | n | oracle | ordinary | lasso | scad | classo | mbridge |
3 | 10 | 200 | 1 | 0 | 0.726 | 0.928 | 0.962 | 0.946 |
400 | 1 | 0 | 0.726 | 0.928 | 0.952 | 0.95 | ||
800 | 1 | 0 | 0.816 | 0.968 | 0.984 | 0.974 | ||
1600 | 1 | 0 | 0.9 | 0.988 | 0.996 | 0.99 | ||
100 | 200 | 1 | 1 | 0.996 | 0.99 | 0.992 | 0.989 | |
400 | 1 | 0 | 0.974 | 0.969 | 0.975 | 0.968 | ||
800 | 1 | 0 | 0.966 | 0.975 | 0.986 | 0.981 | ||
1600 | 1 | 0 | 0.962 | 0.992 | 0.996 | 0.992 | ||
5 | 10 | 200 | 1 | 0 | 0.755 | 0.913 | 0.935 | 0.92 |
400 | 1 | 0 | 0.705 | 0.921 | 0.947 | 0.921 | ||
800 | 1 | 0 | 0.718 | 0.906 | 0.948 | 0.926 | ||
1600 | 1 | 0 | 0.757 | 0.94 | 0.963 | 0.951 | ||
100 | 200 | 1 | 1 | 1 | 0.994 | 0.995 | 0.995 | |
400 | 1 | 1 | 1 | 0.995 | 0.995 | 0.995 | ||
800 | 1 | 0 | 0.983 | 0.984 | 0.983 | 0.983 | ||
1600 | 1 | 0 | 0.98 | 0.981 | 0.985 | 0.979 | ||
Accuracy |  |  |  |  |  |  |  |  |
K | p | n | oracle | ordinary | lasso | scad | classo | mbridge |
3 | 10 | 200 | 1 | 0 | 0 | 0.02 | 0.04 | 0.06 |
400 | 1 | 0 | 0.02 | 0.12 | 0.16 | 0.18 | ||
800 | 1 | 0 | 0.16 | 0.72 | 0.8 | 0.78 | ||
1600 | 1 | 0 | 0.42 | 0.88 | 0.96 | 0.90 | ||
100 | 200 | 1 | 0 | 0 | 0 | 0 | 0 | |
400 | 1 | 0 | 0 | 0 | 0 | 0 | ||
800 | 1 | 0 | 0 | 0.02 | 0.06 | 0.04 | ||
1600 | 1 | 0 | 0 | 0.30 | 0.48 | 0.26 | ||
5 | 10 | 200 | 1 | 0 | 0 | 0 | 0 | 0 |
400 | 1 | 0 | 0 | 0 | 0 | 0 | ||
800 | 1 | 0 | 0 | 0.06 | 0.16 | 0.10 | ||
1600 | 1 | 0 | 0 | 0.32 | 0.52 | 0.40 | ||
100 | 200 | 1 | 0 | 0 | 0 | 0 | 0 | |
400 | 1 | 0 | 0 | 0 | 0 | 0 | ||
800 | 1 | 0 | 0 | 0 | 0 | 0 | ||
1600 | 1 | 0 | 0 | 0 | 0.02 | 0 |
Simulation results for the prediction
Prediction error |  |  |  |  |  |  |  |  |
---|---|---|---|---|---|---|---|---|
K | p | n | oracle | ordinary | lasso | scad | classo | mbridge |
3 | 10 | 200 | 0.369 | 0.376 | 0.376 | 0.383 | 0.385 | 0.382 |
400 | 0.365 | 0.370 | 0.368 | 0.369 | 0.370 | 0.368 | ||
800 | 0.362 | 0.363 | 0.364 | 0.362 | 0.363 | 0.362 | ||
1600 | 0.362 | 0.363 | 0.365 | 0.362 | 0.362 | 0.362 | ||
100 | 200 | 0.366 | 0.717 | 0.441 | 0.406 | 0.397 | 0.399 | |
400 | 0.367 | 0.491 | 0.381 | 0.382 | 0.377 | 0.378 | ||
800 | 0.361 | 0.390 | 0.367 | 0.366 | 0.364 | 0.364 | ||
1600 | 0.364 | 0.380 | 0.367 | 0.364 | 0.363 | 0.363 | ||
5 | 10 | 200 | 0.523 | 0.529 | 0.535 | 0.539 | 0.535 | 0.537 |
400 | 0.518 | 0.522 | 0.527 | 0.529 | 0.53 | 0.528 | ||
800 | 0.514 | 0.516 | 0.522 | 0.516 | 0.517 | 0.516 | ||
1600 | 0.515 | 0.516 | 0.52 | 0.515 | 0.515 | 0.516 | ||
100 | 200 | 0.520 | 0.845 | 0.615 | 0.571 | 0.567 | 0.565 | |
400 | 0.520 | 0.847 | 0.615 | 0.549 | 0.542 | 0.544 | ||
800 | 0.514 | 0.550 | 0.551 | 0.528 | 0.524 | 0.523 | ||
1600 | 0.517 | 0.531 | 0.539 | 0.518 | 0.517 | 0.520 |
Simulation results for the computation time in seconds
K | p | n | scad | classo | mbridge | K | p | n | scad | classo | mbridge |
---|---|---|---|---|---|---|---|---|---|---|---|
3 | 10 | 200 | 0.290 | 0.228 | 0.292 | 5 | 10 | 200 | 2.567 | 2.248 | 2.474 |
400 | 1.221 | 1.059 | 1.329 | 400 | 9.904 | 8.927 | 9.968 | ||||
800 | 4.321 | 3.909 | 5.107 | 800 | 35.505 | 32.863 | 38.414 | ||||
1600 | 15.500 | 14.219 | 19.486 | 1600 | 112.720 | 116.600 | 119.740 | ||||
100 | 200 | 0.656 | 0.559 | 0.574 | 100 | 200 | 2.204 | 2.201 | 1.938 | ||
400 | 2.682 | 2.362 | 2.320 | 400 | 8.601 | 8.359 | 7.988 | ||||
800 | 9.814 | 8.873 | 9.265 | 800 | 33.290 | 34.680 | 32.099 | ||||
1600 | 26.414 | 25.002 | 29.977 | 1600 | 135.88 | 187.800 | 142.320 |
Estimated coefficients of zoo sample
 | Class1 |  |  |  | Class2 |  |  |  |
---|---|---|---|---|---|---|---|---|
 | lasso | scad | classo | mbridge | lasso | scad | classo | mbridge |
intercept | −1.0925 | −15.4251 | −1.0926 | −5.2049 | −1.4835 | −14.4535 | −1.4835 | −5.3949 |
feathers | 0 | 0 | 0 | 0 | 5.0542 | 31.5816 | 5.0542 | 15.7573 |
milk | 4.9497 | 37.6148 | 4.9498 | 15.336 | 0 | 0 | 0 | 0 |
airborne | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
fins | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
 | Class3 |  |  |  | Class4 |  |  |  |
 | lasso | scad | classo | mbridge | lasso | scad | classo | mbridge |
intercept | −0.6931 | −0.6931 | −0.6931 | −0.6931 | −1.5134 | −7.5558 | −1.5134 | −5.2502 |
feathers | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
milk | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
airborne | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
fins | 0 | 0 | 0 | 0 | 3.9357 | 16.6094 | 3.9357 | 10.8944 |
 | Class5 |  |  |  | Class6 |  |  |  |
 | lasso | scad | classo | mbridge | lasso | scad | classo | mbridge |
intercept | −0.9163 | −0.9163 | −0.9163 | −0.9163 | −0.6472 | −0.2231 | −0.6472 | −1.5649 |
feathers | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
milk | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
airborne | 0 | 0 | 0 | 0 | 1.4255 | 0 | 1.4255 | 6.0378 |
fins | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
Number of wrong classifications of zoo sample
ordinary | lasso | scad | classo | mbridge |
---|---|---|---|---|
5 | 11 | 18 | 11 | 11 |