![PyTorch深度学习应用实战](https://wfqqreader-1252317822.image.myqcloud.com/cover/410/52842410/b_52842410.jpg)
2-2 万般皆自“回归”起
要探究神经网络优化的过程,要先了解简单线性回归求解,线性回归方程式如下:
y=wx+b
已知样本(x, y),要求解方程式中的参数权重(w)、偏差(b)。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P21_1735.jpg?sign=1739216748-WjEFxB7fj9wCVh7lmb1piXQEVn9udhhf-0-4b8f96f41163074ed21ac17d7320ca38)
图2.2 简单线性回归
一般求解方法有两种:
(1)最小平方法(Ordinary Least Square, OLS);
(2)最大似然估计法(Maximum Likelihood Estimation, MLE)。
以最小平方法为例,首先定义目标函数(Object Function)或称损失函数(Loss Function)为均方误差(MSE),即预测值与实际值差距的平方和,MSE当然越小越好,所以它是一个最小化的问题,我们可以利用偏微分推导出公式,过程如下。
(1)
其中ε:误差,即实际值(y)与预测值之差;
n:样本个数。
(2)MSE=SSE/n,n为常数,不影响求解,可忽略。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P21_25720.jpg?sign=1739216748-UzBXn8xVMFUTc7NQcemrMbXIfSH9N1si-0-0e52df5c8ada2a9c8a535bb170abc6aa)
(3)分别对w及b偏微分,并且令一阶导数=0,可以得到两个联立方程式,进而求得w及b。
(4)先对b偏微分,又因
f′(x)=g(x)g(x)=g′(x)g(x)+g(x)g′(x)=2g(x)g′(x)
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25730.jpg?sign=1739216748-zykcGvjI2i74hZb8RRnax0lmSKpp2mPb-0-838a49edb143ce9a9798ed1b6867d36b)
→两边同除以2
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25731.jpg?sign=1739216748-RnVbcvj2TzxxoNS9t314hu64VDIA6oUA-0-f6b0ed43d4179db76a921dfcfc297b3f)
→分解
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25732.jpg?sign=1739216748-FIFtyZoQd6vqt2UZzmBiBnbZ8SXKIuwb-0-ea4bfa1c746e7bca50e76a6a13bab4c3)
→除以n,为x、y的平均数
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25733.jpg?sign=1739216748-BUSwDYFwstR90rBIKGUgpyMBNdADEJtv-0-4b423e26c2f8668503d0eae61ecd7ad7)
→移项
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25734.jpg?sign=1739216748-k3EoUZJcghShU0bsC5IcyjwXK5HKXX1i-0-0a0cb3e02ed66ac62a04b1486ba3aa32)
(5)对w偏微分:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25735.jpg?sign=1739216748-Oy8mMmvkswTTCRC2QcvgCu66aQGKsbSN-0-f0ede6ba532768929e6e35ce88427581)
→两边同除以-2
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25737.jpg?sign=1739216748-uyxzTqazc0hCUE9vaAFuZfg8GMFfWriC-0-a2023da95b6047081184de1154645ebb)
→分解
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25738.jpg?sign=1739216748-rsu1T4jmrrBLWQX7T1HzbSm1iCDxp4Wu-0-5c08addea7bda3e3f14a4cbafeba685e)
→代入步骤(4)的计算结果
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25739.jpg?sign=1739216748-lVYXVvAisjvcESgMtOa2caV5K1M8hPoy-0-578ff9573265090dd5a7ec8017f4a780)
→化简
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25740.jpg?sign=1739216748-xkqYJqqNhzB0JkY9CIE7gVrT1qdtqsdF-0-3b494a394e7d7e7d8cdb826f0ffeec03)
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25741.jpg?sign=1739216748-p9CbeoG7IVOm90PesZLzOlE0n9vVpO5n-0-a75f96493808828981c9dcc457250300)
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25742.jpg?sign=1739216748-KiHpiyc8FRuAnXStRvMTd2Hbs2teMmoX-0-19bede217879308754404f79977d9f06)
结论:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P22_25743.jpg?sign=1739216748-WVgCL8aaBERoVssZzOWinWqvqd4q2R2e-0-8f33b465dd37e0ccf8904ffcac44d98a)
范例1.现有一个世界人口统计数据集,以年度(year)为x,人口数为y,按上述公式计算回归系数w、b。
下列程序代码请参考【02_01_线性回归.ipynb】。
(1)使用Pandas相关函数计算,程序如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_1947.jpg?sign=1739216748-6T7mu1IsYZN2sU8J6LAiw0ebR4YzLz7r-0-c4d29b27a59d87608c87966afc56f6eb)
执行结果:
w=0.061159358661557375,b=-116.35631056117687
(2)改用NumPy的现成函数polyfit验算:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_1954.jpg?sign=1739216748-Si9Gj4qLOcSsN91b2e4hVAz7H52FpQar-0-570d171dc4b50fbf3da4b6d4641e5c63)
执行结果:答案相差不大。
w=0.061159358661554586, b=-116.35631056117121
(3)上面公式,x只限一个,若以矩阵计算则更具通用性,多元回归亦可适用,即模型可以有多个特征(x),为简化模型,将b视为w的一环:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_25752.jpg?sign=1739216748-2zxC9hSolpNedpTlBMlwrIdQ9hAGYudj-0-dfdd572001185de2e55087200dfaa5d3)
一样对SSE偏微分,一阶导数=0有最小值,公式推导如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_25754.jpg?sign=1739216748-vmNNPqbgzxp7bZ4UxQjQlPfgwPQDUXPk-0-95b6e47e432db2c2437f23e126e8b01e)
→移项、整理
(xx′)w=xy
→移项
w=(xx′)−1xy
(4)使用NumPy相关函数计算,程序如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P23_1969.jpg?sign=1739216748-Yr8c3gqqd0lbJDWbhWOZVtzJ8z3YmbOo-0-91677277c4fb4886f1d595023cacba9b)
执行结果与上一段相同。
范例2.再以Scikit-Learn的房价数据集为例,求解线性回归,该数据集有多个特征(x)。
(1)以矩阵计算的方式,完全不变。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_1991.jpg?sign=1739216748-BPO3EsEecXyHoAcNy9uUSA9fWJVrEOjA-0-d37dd160a839e6e2e198e5659b733d83)
执行结果如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_1998.jpg?sign=1739216748-AAZV0VkZumnTsCylFNvpwCHH2Wr2JNBv-0-f2e71b041234d1cf681f71d18ec2fe11)
(2)以Scikit-Learn的线性回归类别验证答案。
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_2001.jpg?sign=1739216748-QFajnd4UTyynfHX85uZS1iP5LzPVTvJA-0-ab44550e22d0913bda4e9106ca893a7f)
执行结果与采用矩阵计算的结果完全相同。
(3)PyTorch自v1.9起提供线性代数函数库[1],可直接调用,程序改写如下:
![](https://epubservercos.yuewen.com/128DEE/31397898903670606/epubprivate/OEBPS/Images/Figure-P24_2008.jpg?sign=1739216748-dMZDKQeqFX1gAy213Kk6Zd9sYT1hP1T0-0-218ea9544455ed19d8787c36cc2797d0)
执行结果与NumPy计算完全相同。