ENH - Adds support for L1 + L2 regularization in SparseLogisticRegression (#278)

Co-authored-by: mathurinm <[email protected]>
AnavAgrawal and mathurinm authored Nov 5, 2024
1 parent 495333b commit 1225970
Showing 2 changed files with 31 additions and 3 deletions.
14 changes: 11 additions & 3 deletions skglm/estimators.py
@@ -967,6 +967,12 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
     alpha : float, default=1.0
         Regularization strength; must be a positive float.
+    l1_ratio : float, default=1.0
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
+        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
+        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
+        combination of L1 and L2.
     tol : float, optional
         Stopping criterion for the optimization.
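Note (not part of the diff): assuming skglm follows the standard elastic-net convention, the same one scikit-learn uses, the penalty this parameter controls is alpha * (l1_ratio * ||w||_1 + (1 - l1_ratio) / 2 * ||w||_2^2), so l1_ratio interpolates between pure L2 regularization at 0 and pure L1 at 1.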
@@ -1003,10 +1009,11 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
         Number of subproblems solved to reach the specified tolerance.
     """

-    def __init__(self, alpha=1.0, tol=1e-4, max_iter=20, max_epochs=1_000, verbose=0,
-                 fit_intercept=True, warm_start=False):
+    def __init__(self, alpha=1.0, l1_ratio=1.0, tol=1e-4, max_iter=20, max_epochs=1_000,
+                 verbose=0, fit_intercept=True, warm_start=False):
         super().__init__()
         self.alpha = alpha
+        self.l1_ratio = l1_ratio
         self.tol = tol
         self.max_iter = max_iter
         self.max_epochs = max_epochs
@@ -1035,7 +1042,8 @@ def fit(self, X, y):
             max_iter=self.max_iter, max_pn_iter=self.max_epochs, tol=self.tol,
             fit_intercept=self.fit_intercept, warm_start=self.warm_start,
             verbose=self.verbose)
-        return _glm_fit(X, y, self, Logistic(), L1(self.alpha), solver)
+        return _glm_fit(X, y, self, Logistic(), L1_plus_L2(self.alpha, self.l1_ratio),
+                        solver)

     def predict_proba(self, X):
         """Probability estimates.
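For readers skimming the commit, a minimal usage sketch of the new parameter (not part of the diff; the synthetic data and the alpha value are illustrative assumptions):

import numpy as np
from skglm import SparseLogisticRegression

# Illustrative synthetic binary classification data.
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 40))
y = np.sign(X @ rng.standard_normal(40))

# l1_ratio=1.0 (the default) keeps the previous pure-L1 behavior;
# l1_ratio=0.7 mixes 70% L1 with 30% L2; l1_ratio=0.0 is pure L2.
clf = SparseLogisticRegression(alpha=0.01, l1_ratio=0.7)
clf.fit(X, y)
print("nonzero coefficients:", np.sum(clf.coef_ != 0))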
20 changes: 20 additions & 0 deletions skglm/tests/test_estimators.py
@@ -600,5 +600,25 @@ def test_GroupLasso_estimator_sparse_vs_dense(positive):
     np.testing.assert_allclose(coef_sparse, coef_dense, atol=1e-7, rtol=1e-5)


+@pytest.mark.parametrize("X, l1_ratio", product([X, X_sparse], [1., 0.7, 0.]))
+def test_SparseLogReg_elasticnet(X, l1_ratio):
+
+    estimator_sk = clone(dict_estimators_sk['LogisticRegression'])
+    estimator_ours = clone(dict_estimators_ours['LogisticRegression'])
+    estimator_sk.set_params(fit_intercept=True, solver='saga',
+                            penalty='elasticnet', l1_ratio=l1_ratio, max_iter=10_000)
+    estimator_ours.set_params(fit_intercept=True, l1_ratio=l1_ratio, max_iter=10_000)
+
+    estimator_sk.fit(X, y)
+    estimator_ours.fit(X, y)
+    coef_sk = estimator_sk.coef_
+    coef_ours = estimator_ours.coef_
+
+    np.testing.assert_array_less(1e-5, norm(coef_ours))
+    np.testing.assert_allclose(coef_ours, coef_sk, atol=1e-6)
+    np.testing.assert_allclose(
+        estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4)
+
+
 if __name__ == "__main__":
     pass
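A hypothetical sanity check, not part of the commit's test suite: assuming skglm's L1_plus_L2 penalty follows the usual elastic-net convention and exposes the value() method that skglm penalties share, l1_ratio=1.0 should reduce exactly to the pure L1 penalty the estimator used before this change.

import numpy as np
from skglm.penalties import L1, L1_plus_L2

rng = np.random.default_rng(42)
w = rng.standard_normal(20)
alpha = 0.05

# Check the penalty value against the elastic-net formula at several mixes.
for l1_ratio in (0., 0.7, 1.):
    expected = alpha * (l1_ratio * np.abs(w).sum()
                        + (1 - l1_ratio) / 2 * (w ** 2).sum())
    np.testing.assert_allclose(L1_plus_L2(alpha, l1_ratio).value(w), expected)

# l1_ratio=1.0 coincides with the pure L1 penalty used before this commit.
np.testing.assert_allclose(
    L1_plus_L2(alpha, l1_ratio=1.).value(w), L1(alpha).value(w))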
