作为我的研究项目的一部分,我正在使用 matplotlib 对一些数据执行线性回归。不幸的是,我无法让我的线触及原点; matplotlib 似乎将其截断为我的数据集的最小值。我怎样才能解决这个问题并让我的线接触原点?作为引用,这是我的代码:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from statsmodels import api as sm
def file_analysis(csv_file, state):
"""
This method takes in a file object and the name of a state.
:param csv_file: Pass in a csv file object.
:param state: Name of the state as a string.
:return: None.
"""
data = pd.read_csv(csv_file)
data = data[["Total Cases", "Total Deaths"]]
y = data["Total Deaths"]
x = data["Total Cases"]
results = sm.OLS(y, x).fit()
plt.scatter(x, y)
yhat = results.params[0] * x
print(results.params)
plt.ylim(ymin=0)
plt.xlim(xmin=0)
plt.margins(0)
fig = plt.plot(x, yhat, lw=4, c="orange", label="regressionline")
plt.xlabel("Total Cases", fontsize=20)
plt.ylabel('Total Deaths', fontsize=20)
plt.title(state)
plt.savefig(state + "_scatterplot" + ".png")
plt.show()
with open(state + "_analysis.txt", "w") as file:
file.write(results.summary().as_text())
最佳答案
您应该只更改您希望在回归中包含 0 的 x 值。
yhat = results.params[0] * range(0, x.max())
fig = plt.plot(range(0, x.max()), yhat, lw=4, c="orange", label="regressionline")
关于python - Matplotlib 未绘制整条线,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62271236/