将线性回归线延伸至整个图的宽度。

23 浏览
0 Comments

将线性回归线延伸至整个图的宽度。

我不知道如何让线性回归线(也称为最佳拟合线)跨越整个图的宽度。它似乎只能达到最左边和最右边的数据点,没有更远。我该如何解决这个问题?

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import *
import MySQLdb
# 连接MySQL数据库
def mysql_select_all():
    conn = MySQLdb.connect(host='localhost',
                           user='root',
                           passwd='XXXXX',
                           db='world')
    cursor = conn.cursor()
    sql = """
        SELECT
            GNP, Population
        FROM
            country
        WHERE
            Name LIKE 'United States'
                OR Name LIKE 'Canada'
                OR Name LIKE 'United Kingdom'
                OR Name LIKE 'Russia'
                OR Name LIKE 'Germany'
                OR Name LIKE 'Poland'
                OR Name LIKE 'Italy'
                OR Name LIKE 'China'
                OR Name LIKE 'India'
                OR Name LIKE 'Japan'
                OR Name LIKE 'Brazil';
    """
    cursor.execute(sql)
    result = cursor.fetchall()
    list_x = []
    list_y = []
    for row in result:
        list_x.append(('%r' % (row[0],)))
    for row in result:
        list_y.append(('%r' % (row[1],)))
    list_x = list(map(float, list_x))
    list_y = list(map(float, list_y))
    fig = plt.figure()
    ax1 = plt.subplot2grid((1,1), (0,0))
    p1 = np.polyfit(list_x, list_y, 1)          # 这行是指回归线
    ax1.xaxis.labelpad = 50
    ax1.yaxis.labelpad = 50
    plt.plot(list_x, np.polyval(p1,list_x),'r-') # 这是指回归线
    plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)
    plt.xlabel("GNP (US dollars)", fontsize=30)
    plt.ylabel("Population(in billions)", fontsize=30)
    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 
                7000000, 8000000, 9000000],  rotation=45, fontsize=14)
    plt.yticks(fontsize=14)
    plt.show()
    cursor.close()
mysql_select_all()

0