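The discrepancy arises because statsmodels computes R-squared with a different formula depending on whether the model includes an intercept. In both cases R-squared is 1 minus the ratio of the residual sum of squares to a total sum of squares. With an intercept, statsmodels uses the centered total sum of squares (squared deviations of y from its mean). Without an intercept, it uses the uncentered total sum of squares (the plain sum of y squared). In other words, statsmodels computes R-squared as follows, as described in the documentation at https://www.statsmodels.org/dev/generated/statsmodels.regression.linear_model.RegressionResults.rsquared.html: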
```python
import numpy as np

def rsquared(y_true, y_pred, fit_intercept=True):
    '''
    Statsmodels R-squared, see https://www.statsmodels.org/dev/generated/statsmodels.regression.linear_model.RegressionResults.rsquared.html.
    '''
    if fit_intercept:
        # with an intercept: divide RSS by the centered total sum of squares
        return 1 - np.sum((y_true - y_pred) ** 2) / np.sum((y_true - np.mean(y_true)) ** 2)
    else:
        # without an intercept: divide RSS by the uncentered total sum of squares
        return 1 - np.sum((y_true - y_pred) ** 2) / np.sum(y_true ** 2)
```
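Written out, the two formulas differ only in the denominator:

$$
R^2_{\text{centered}} = 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2},
\qquad
R^2_{\text{uncentered}} = 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i y_i^2}
$$

As a worked check, here is the arithmetic for the no-intercept fit on the dummy data used in the full script (y = 1, 3, 4, 5, 2, 3, 4 and x = 1, ..., 7), where everything can be done by hand. With $\sum_i x_i^2 = 140$, $\sum_i x_i y_i = 95$, $\sum_i y_i^2 = 80$ and $\bar{y} = 22/7$:

$$
\hat{\beta} = \frac{\sum_i x_i y_i}{\sum_i x_i^2} = \frac{95}{140},
\qquad
\text{RSS} = \sum_i y_i^2 - \frac{\left(\sum_i x_i y_i\right)^2}{\sum_i x_i^2} = 80 - \frac{9025}{140} = \frac{435}{28},
\qquad
\sum_i (y_i - \bar{y})^2 = 80 - \frac{484}{7} = \frac{76}{7}
$$

$$
R^2_{\text{uncentered}} = 1 - \frac{435/28}{80} = \frac{361}{448} \approx 0.8058,
\qquad
R^2_{\text{centered}} = 1 - \frac{435/28}{76/7} = -\frac{131}{304} \approx -0.4309
$$

These match the statsmodels and sklearn values printed for the no-intercept model.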
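sklearn, on the other hand, always uses the centered total sum of squares in the denominator, regardless of whether the model actually includes an intercept (i.e. regardless of whether `fit_intercept=True` or `fit_intercept=False`). See also this answer: https://stackoverflow.com/a/54618898/11989081.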
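To check the sklearn side in isolation, the sketch below compares `LinearRegression.score`, `sklearn.metrics.r2_score` and the centered formula for both settings of `fit_intercept`, on the same dummy data as the full script. The helper `centered_r2` is hypothetical, introduced only for this illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

def centered_r2(y_true, y_pred):
    # hypothetical helper: 1 - RSS / centered TSS
    rss = np.sum((y_true - y_pred) ** 2)
    tss_centered = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1 - rss / tss_centered

y = np.array([1, 3, 4, 5, 2, 3, 4])
X = np.arange(1, 8).reshape(-1, 1)  # single feature as a column

scores = {}
for fit_intercept in (True, False):
    lr = LinearRegression(fit_intercept=fit_intercept).fit(X, y)
    y_pred = lr.predict(X)
    # score(), r2_score() and the centered formula should agree in both cases
    scores[fit_intercept] = (lr.score(X, y), r2_score(y, y_pred), centered_r2(y, y_pred))
    print(fit_intercept, scores[fit_intercept])
```

All three numbers coincide for each setting, which is consistent with sklearn never switching to the uncentered denominator.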
```python
import numpy as np
import statsmodels.api as sm
from sklearn.linear_model import LinearRegression

def rsquared(y_true, y_pred, fit_intercept=True):
    '''
    Statsmodels R-squared, see https://www.statsmodels.org/dev/generated/statsmodels.regression.linear_model.RegressionResults.rsquared.html.
    '''
    if fit_intercept:
        return 1 - np.sum((y_true - y_pred) ** 2) / np.sum((y_true - np.mean(y_true)) ** 2)
    else:
        return 1 - np.sum((y_true - y_pred) ** 2) / np.sum(y_true ** 2)

# dummy data:
y = np.array([1, 3, 4, 5, 2, 3, 4])
X = np.array(range(1, 8)).reshape(-1, 1)  # reshape to column

# model with an intercept: the results are the same

# scikit-learn:
lr = LinearRegression(fit_intercept=True)
lr.fit(X, y)

print(lr.score(X, y))
# 0.16118421052631582

print(rsquared(y, lr.predict(X), fit_intercept=True))
# 0.16118421052631582

# statsmodels:
X_ = sm.add_constant(X)
model = sm.OLS(y, X_)
results = model.fit()

print(results.rsquared)
# 0.16118421052631582

print(rsquared(y, results.fittedvalues, fit_intercept=True))
# 0.16118421052631593

# model without an intercept (intercept fixed at zero): the results are different

# scikit-learn:
lr = LinearRegression(fit_intercept=False)
lr.fit(X, y)

print(lr.score(X, y))
# -0.4309210526315792

print(rsquared(y, lr.predict(X), fit_intercept=True))
# -0.4309210526315792

print(rsquared(y, lr.predict(X), fit_intercept=False))
# 0.8058035714285714

# statsmodels:
model = sm.OLS(y, X)
results = model.fit()

print(results.rsquared)
# 0.8058035714285714

print(rsquared(y, results.fittedvalues, fit_intercept=False))
# 0.8058035714285714
```