Python 梯度下降法

系统 1921 0

接上篇博客

题目描述:
自定义一个可微并且存在最小值的一元函数,用梯度下降法求其最小值。并绘制出学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数的关系曲线,根据该曲线给出简单的分析。

代码:

            
              
                # -*- coding: utf-8 -*-
              
              
                """
Created on Tue Jun  4 10:19:03 2019

@author: Administrator
"""
              
              
                import
              
               numpy 
              
                as
              
               np

              
                import
              
               matplotlib
              
                .
              
              pyplot 
              
                as
              
               plt
plot_x
              
                =
              
              np
              
                .
              
              linspace
              
                (
              
              
                -
              
              
                1
              
              
                ,
              
              
                6
              
              
                ,
              
              
                150
              
              
                )
              
              
                #在-1到6之间等距的生成150个数
              
              
plot_y
              
                =
              
              
                (
              
              plot_x
              
                -
              
              
                2.5
              
              
                )
              
              
                **
              
              
                2
              
              
                +
              
              
                3
              
              
                # 同时根据plot_x来生成plot_y(y=(x-2.5)²+3)
              
              

plt
              
                .
              
              plot
              
                (
              
              plot_x
              
                ,
              
              plot_y
              
                )
              
              
plt
              
                .
              
              show
              
                (
              
              
                )
              
              
                ###定义一个求二次函数导数的函数dJ
              
              
                def
              
              
                dJ
              
              
                (
              
              x
              
                )
              
              
                :
              
              
                return
              
              
                2
              
              
                *
              
              
                (
              
              x
              
                -
              
              
                2.5
              
              
                )
              
              
                ###定义一个求函数值的函数J
              
              
                def
              
              
                J
              
              
                (
              
              x
              
                )
              
              
                :
              
              
                try
              
              
                :
              
              
                return
              
              
                (
              
              x
              
                -
              
              
                2.5
              
              
                )
              
              
                **
              
              
                2
              
              
                +
              
              
                3
              
              
                except
              
              
                :
              
              
                return
              
              
                float
              
              
                (
              
              
                'inf'
              
              
                )
              
              

x
              
                =
              
              
                0.0
              
              
                #随机选取一个起始点
              
              
eta
              
                =
              
              
                0.1
              
              
                #eta是学习率,用来控制步长的大小
              
              
epsilon
              
                =
              
              
                1e
              
              
                -
              
              
                8
              
              
                #用来判断是否到达二次函数的最小值点的条件
              
              
history_x
              
                =
              
              
                [
              
              x
              
                ]
              
              
                #用来记录使用梯度下降法走过的点的X坐标
              
              
count
              
                =
              
              
                0
              
              
                min
              
              
                =
              
              
                0
              
              
                while
              
              
                True
              
              
                :
              
              
    gradient
              
                =
              
              dJ
              
                (
              
              x
              
                )
              
              
                #梯度(导数)
              
              
    last_x
              
                =
              
              x
    x
              
                =
              
              x
              
                -
              
              eta
              
                *
              
              gradient
    history_x
              
                .
              
              append
              
                (
              
              x
              
                )
              
              
    count
              
                =
              
              count
              
                +
              
              
                1
              
              
                if
              
              
                (
              
              
                abs
              
              
                (
              
              J
              
                (
              
              last_x
              
                )
              
              
                -
              
              J
              
                (
              
              x
              
                )
              
              
                )
              
              
                <
              
              epsilon
              
                )
              
              
                :
              
              
                #用来判断是否逼近最低点
              
              
                min
              
              
                =
              
              x
        
              
                break
              
              
    
plt
              
                .
              
              plot
              
                (
              
              plot_x
              
                ,
              
              plot_y
              
                )
              
                   
plt
              
                .
              
              plot
              
                (
              
              np
              
                .
              
              array
              
                (
              
              history_x
              
                )
              
              
                ,
              
              J
              
                (
              
              np
              
                .
              
              array
              
                (
              
              history_x
              
                )
              
              
                )
              
              
                ,
              
              color
              
                =
              
              
                'r'
              
              
                ,
              
              marker
              
                =
              
              
                '*'
              
              
                )
              
              
                #绘制x的轨迹
              
              
plt
              
                .
              
              show
              
                (
              
              
                )
              
              
                print
              
              
                'min_x ='
              
              
                ,
              
              
                (
              
              
                min
              
              
                )
              
              
                print
              
              
                'min_y ='
              
              
                ,
              
              
                (
              
              J
              
                (
              
              
                min
              
              
                )
              
              
                )
              
              
                #打印到达最低点时y的值
              
              
                print
              
              
                'count ='
              
              
                ,
              
              
                (
              
              count
              
                )
              
              

sum_eta
              
                =
              
              
                [
              
              
                ]
              
              
result
              
                =
              
              
                [
              
              
                ]
              
              
                for
              
               i 
              
                in
              
              
                range
              
              
                (
              
              
                1
              
              
                ,
              
              
                10
              
              
                ,
              
              
                1
              
              
                )
              
              
                :
              
              
    x
              
                =
              
              
                0.0
              
              
                #随机选取一个起始点
              
              
    eta
              
                =
              
              i
              
                *
              
              
                0.1
              
              
    sum_eta
              
                .
              
              append
              
                (
              
              eta
              
                )
              
              
    epsilon
              
                =
              
              
                1e
              
              
                -
              
              
                8
              
              
                #用来判断是否到达二次函数的最小值点的条件
              
              
    num
              
                =
              
              
                0
              
              
                min
              
              
                =
              
              
                0
              
              
                while
              
              
                True
              
              
                :
              
              
        gradient
              
                =
              
              dJ
              
                (
              
              x
              
                )
              
              
                #梯度(导数)
              
              
        last_x
              
                =
              
              x
        x
              
                =
              
              x
              
                -
              
              eta
              
                *
              
              gradient
        num
              
                =
              
              num
              
                +
              
              
                1
              
              
                if
              
              
                (
              
              
                abs
              
              
                (
              
              J
              
                (
              
              last_x
              
                )
              
              
                -
              
              J
              
                (
              
              x
              
                )
              
              
                )
              
              
                <
              
              epsilon
              
                )
              
              
                :
              
              
                #用来判断是否逼近最低点
              
              
                min
              
              
                =
              
              x
            
              
                break
              
              
    
    result
              
                .
              
              append
              
                (
              
              num
              
                )
              
              
                #记录学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数
              
              

plt
              
                .
              
              scatter
              
                (
              
              sum_eta
              
                ,
              
              result
              
                ,
              
              c
              
                =
              
              
                'r'
              
              
                )
              
              
plt
              
                .
              
              plot
              
                (
              
              sum_eta
              
                ,
              
              result
              
                ,
              
              c
              
                =
              
              
                'r'
              
              
                )
              
              
plt
              
                .
              
              title
              
                (
              
              
                "relation"
              
              
                )
              
              
plt
              
                .
              
              xlabel
              
                (
              
              
                "eta"
              
              
                )
              
              
plt
              
                .
              
              ylabel
              
                (
              
              
                "count"
              
              
                )
              
              
plt
              
                .
              
              show

              
                print
              
              
                (
              
              result
              
                )
              
            
          

运行结果:
Python 梯度下降法_第1张图片
Python 梯度下降法_第2张图片
结果分析:
函数y=(x-2.5)²+3从学习率和迭代次数的关系图上我们可以知道当学习率较低时迭代次数较多,随着学习率的增大,迭代次数开始逐渐减少,当学习率为0.5时迭代次数最少,之后随着学习率的增加,迭代次数开始增加,当学习率为0.9时迭代次数和0.1时相等。关于0.5成对称分布。


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论