机器学习-李宏毅| 回归演示 | python_李宏毅线性回归python-程序员宅基地

技术标签: python  机器学习  logistic regression  

回归的定义

Regression就是指找到一个函数 f u n c t i o n function function,通过输入特征x,输出一个数值 S c a l a r Scalar Scalar

看了李宏毅老师的机器学习课程视频,其中的Regression demo部分,关于预测宝可梦的CP值的应用代码,在jupyter notebook中实现。
现在假设有10个x_data和y_data,x和y之间的关系是y_data=b+w*x_data。b,w都是参数,是需要学习出来的。现在我们来练习用梯度下降找到b和w。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
plt.rcParams['font.sans-serif'] = ['Simhei']  # 显示中文
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

x_data= [ 338.,  333., 328., 207., 226., 25., 179., 60., 208., 606.]
y_data= [ 640.,  633., 619., 393., 428., 27., 193., 66., 226., 1591.] 
# ydata =b + w * xdata

x = np.arange(-200, -100, 1) #bias
y = np.arange(-5,5,0.1) #weight
Z = np.zeros((len(x), len(y)))
X, Y = np.meshgrid(x, y)
for i in range(len(x)):
    for j in range(len(y)):
        b = x[i]
        w = y[j]
        Z[j][i] = 0
        for n in range(len(x_data)):
            Z[j][i] = Z[j][i]  + (y_data[n] - b - w*x_data[n]) **2
        Z[j][i] =   Z[j][i] /len(x_data)

# ydata = b + w * xdata
b = -120 # initial b
w = -4 #intial w
lr =0.0000001 
iteration = 100000 
# Store initial values for plotting.
b_history = [b]
w_history = [w]

#lr_b = 0 #客制化b的learning rate 的初始值
#lr_w = 0 #客制化w的learning rate 的初始值

# Iterations
for i in range(iteration):
    
    b_grad = 0.0
    w_grad = 0.0
    for n in range(len(x_data)):
        b_grad = b_grad - 2.0*(y_data[n] - b - w*x_data[n]) *1.0
        w_grad = w_grad - 2.0*(y_data[n] - b - w*x_data[n])*x_data[n]
        
   # lr_b = lr_b + b_grad ** 2 #客制化b的learning rate
   # lr_w = lr_w + w_grad ** 2 #客制化w的learning rate
        
    # Update parameters.
    b = b - lr * b_grad
    w = w - lr * w_grad
    
    # Store parameters for plotting
    b_history.append(b)
    w_history.append(w)

# plot the figure
plt.contourf(x, y, Z, 50, alpha = 0.5, cmap=plt.get_cmap('jet'))
plt.plot([-188.4], [2.67], 'x', ms = 12, markeredgewidth = 3, color='orange')
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5,5)
plt.xlabel(r'$b$', fontsize=16)
plt.ylabel(r'$w$', fontsize=16)
plt.title("线性回归")
plt.show()
        

输出结果图:
在这里插入图片描述
横坐标是b,纵坐标是w,标记×位最优解,显然,在图中我们并没有运行得到最优解,最优解十分的遥远。那么我们就调大learning rate,lr = 0.000001(调大10倍),得到结果如下图。
在这里插入图片描述
我们再调大learning rate,lr = 0.00001(调大10倍),得到结果如下图。
在这里插入图片描述
结果发现learning rate太大了,结果很不好。
所以我们给b和w特制化两种learning rate
修改后代码如下:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
plt.rcParams['font.sans-serif'] = ['Simhei']  # 显示中文
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

x_data= [ 338.,  333., 328., 207., 226., 25., 179., 60., 208., 606.]
y_data= [ 640.,  633., 619., 393., 428., 27., 193., 66., 226., 1591.] 
# ydata =b + w * xdata

x = np.arange(-200, -100, 1) #bias
y = np.arange(-5,5,0.1) #weight
Z = np.zeros((len(x), len(y)))
X, Y = np.meshgrid(x, y)
for i in range(len(x)):
    for j in range(len(y)):
        b = x[i]
        w = y[j]
        Z[j][i] = 0
        for n in range(len(x_data)):
            Z[j][i] = Z[j][i]  + (y_data[n] - b - w*x_data[n]) **2
        Z[j][i] =   Z[j][i] /len(x_data)

# ydata = b + w * xdata
b = -120 # initial b
w = -4 #intial w
lr =1 #learning rate设为1
iteration = 100000 
# Store initial values for plotting.
b_history = [b]
w_history = [w]

lr_b = 0 #客制化b的learning rate 的初始值
lr_w = 0 #客制化w的learning rate 的初始值

# Iterations
for i in range(iteration):
    
    b_grad = 0.0
    w_grad = 0.0
    for n in range(len(x_data)):
        b_grad = b_grad - 2.0*(y_data[n] - b - w*x_data[n]) *1.0
        w_grad = w_grad - 2.0*(y_data[n] - b - w*x_data[n])*x_data[n]
        
    lr_b = lr_b + b_grad ** 2 #客制化b的learning rate
    lr_w = lr_w + w_grad ** 2 #客制化w的learning rate
        
    # Update parameters.
    b = b - lr/np.sqrt(lr_b ) * b_grad
    w = w - lr/np.sqrt(lr_w ) * w_grad
    
    # Store parameters for plotting
    b_history.append(b)
    w_history.append(w)

# plot the figure
plt.contourf(x, y, Z, 50, alpha = 0.5, cmap=plt.get_cmap('jet'))
plt.plot([-188.4], [2.67], 'x', ms = 12, markeredgewidth = 3, color='orange')
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5,5)
plt.xlabel(r'$b$', fontsize=16)
plt.ylabel(r'$w$', fontsize=16)
plt.title("线性回归")
plt.show()

这样有了新的特制化两种learning rate就可以在10w次迭代之内到达最优点了。
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/baidu_36415362/article/details/104818591

智能推荐

泰勒图(Taylor diagram)绘制方法大汇总-程序员宅基地

文章浏览阅读2.6k次,点赞6次,收藏18次。泰勒图(Taylor diagram)的基本介绍 R 绘制泰勒图(Taylor diagram) Python 绘制泰勒图(Taylor diagram) 泰勒图(Taylor diagram)的基本介绍泰勒图(Taylor diagram)可以简单的理解为一种的可同时表示标准差、均方根误差和相关系数三个指标的可视化图表。样例图如下(来源于网络):泰勒图(Taylor diagram)样例通常,泰勒图中的散点代表不同模型,横纵轴代表标准差,辐射线代表相关系数,虚..._泰勒图

√ Unity3D - 怎么添加视频_在unity3d 添加视频-程序员宅基地

文章浏览阅读2.1w次,点赞7次,收藏39次。1、在Hierarchy面板中新建一个VideoPlayer组件,然后将视频资源赋给VideoPlayer组件的VideoClip属性。2、在Project面板中新建一个RenderTexture对象,然后将其赋给VideoPlayer组件的TargetTexture属性。3、在Hierarchy面板中新建一个RawImage组件,最后将RenderTexture对象再赋给RawImage组件的Texture属性即可。4、RenderTexture对象的Size属性用于调整视频的分辨率。..._在unity3d 添加视频

IDEA pom.xml显示灰色并被划线_pom.xml中划线-程序员宅基地

文章浏览阅读559次。在使用 IDEA 进行开发的过程中,有时候会遇到。设置保存后,可以看到。_pom.xml中划线

Python 任意字典生成 SQL(insert 语句)_python dict 转 sql-程序员宅基地

文章浏览阅读8.1k次,点赞8次,收藏16次。代码示例(复制用)dic = { 'url': 'URL', 'kw': 'KW', 'page': 'PAGE'}ls = list(dic)sentence = 'insert table_name(' + ','.join(ls) + ')values(' + ','.join(['&q_python dict 转 sql

java.sql.Timestamp与java.util.Date的相互转换_java.sql.timestamp 转java.util.date-程序员宅基地

文章浏览阅读5k次。public static void main(String[] args) throws ParseException { //测试java.sql.Timestamp取得当前的系统时间 Timestamp timestamp = new Timestamp(System.currentTimeMillis()); System.out.println("当前的系统时间java.sq_java.sql.timestamp 转java.util.date

ssm上传文件获取路径_又拍云:文件上传+下载+自定义下载路径(SSM+AJAX+JFileChooser)...-程序员宅基地

文章浏览阅读196次。这一篇是前几篇功能的总结http://blog..net/qq_36688143/article/details/79007120http://blog..net/qq_36688143/article/details/79007067http://blog..net/qq_36688143/article/details/78871406http://blog..net/qq_36688143/..._ssm设置下载路径

随便推点

嵌入式 Linux 内核驱动开发【The first day: 36093万字】_linux嵌入式内核及驱动开发 初级+高级+项目+物联网逆人行-程序员宅基地

文章浏览阅读2.2k次,点赞42次,收藏87次。嵌入式 Linux 内核驱动开发【1】第1章 Linux 内核裁剪和定制【1】Linux 内核开发简介【2】 Linux 源码阅读工具【1.2.1】Source Insight【1.2.2 Eclipse】【1.2.3】 vim+ctags+cscope【1.2.4】 LXR【3】Linux 内核源码【1.3.1 目录树概览】【1.3.2】 快速确定主板关联代码【4】 Linux 内核中的 Makefile 文件【1.4.1】 顶层 Makefi_linux嵌入式内核及驱动开发 初级+高级+项目+物联网逆人行

使用lupdate生成Qt的ts翻译文件-程序员宅基地

文章浏览阅读1.8k次。4、.ts文件翻译完成后,使用lrelease xxx.ts命令,生成.qm文件供程序加载,如果程序关联了多个动态库,可能需要加载多个翻译文件。3、使用linguist XXX.ts 打开qt翻译工具。_lupdate

苏州大学推出开源大模型OpenBA;阿里云开源通义千问14B模型;百川智能发布Baichuan2-53B 闭源大模丨每日大事件...-程序员宅基地

文章浏览阅读407次。‍大数据产业创新服务媒体——聚焦数据· 改变商业企业动态腾讯宣布启动“青云计划”9月24日,腾讯启动腾讯青云计划。在全球范围内招募一批顶尖技术学生,通过腾讯的平台培养属于中国的互联网科技人才。据悉,青云计划提供全面定制化的培养和极具竞争力的薪酬,在腾讯核心业务中深度参与最前瞻性的技术课题。台媒:英伟达追单AI芯片,台积电增购设备扩充CoWoS产能9月25日,台湾《经济日报》消息,台积电CoWoS..._百川14b

上海亚商投顾:沪指冲高回落 近期热门板块全线退潮-程序员宅基地

文章浏览阅读950次,点赞13次,收藏15次。民生证券表示,尽管当前是原油消费的淡季,但在OPEC+减产的支撑下,原油供需处于紧平衡状态,油价下跌空间有限,油价此前的溢价从23年的90美元/桶以上回落时已基本被消化,同时考虑到当前原油已对悲观的需求预期进行了定价,在地缘政治的扰动下,油价向上动力更加充足,在消费旺季有望实现较大幅度反弹。2、2月21日互动:公司作为国内胶粘剂行业的领军企业之一,将采用先进的生产技术,规模化生产高性能胶粘剂,应用于新能源、交通、绿色包装等领域,将对推动胶粘剂行业整体技术、工艺进步以及产业升级起到积极的作用。

Linux系统centos6安装Redis_centos6 redis-程序员宅基地

文章浏览阅读738次。Linux系统centos6安装Redis_centos6 redis

使用DOS重定位技术执行isqlw(SQL查询分析器)-程序员宅基地

文章浏览阅读2.2k次。作者:chenjieb520 笔者之前在一个项目里面需要调用SQL查询分析器,并且通过命令行的形式将执行结果返回。于是笔者就采用了 DOS重定位技术进行解决。现在简单说明一下如何用VC++来进行实现。命令行调用SQL询分析器isqlw 实用工具(SQL 查询分析器)使您得以输入 Transact-SQL 语句、系统存储过程和脚本文件。通过设置快捷方式或创建_isqlw

推荐文章

热门文章

相关标签