import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

f = lambda x :(x-3)**2+2.5*x-7.5
f2 = lambda x :-(x-3)**2+2.5*x-7.5

求解导数 导数为0 取最小值

x = np.linspace(-2,5,100)
y = f(x)
plt.plot(x,y)

在这里插入图片描述

梯度下降求最小值

#导数函数
d = lambda x:2*(x-3)*1+2.5

#学习率 需调节 每次改变数值的时候,改变多少
learning_rate = 0.1

min_value = np.random.randint(-3,5,size =1)[0]
print('-'*30,min_value)
#记录数据更新了,原来的值,上一步的值
min_value_last = min_value+0.1

#tollerence容忍度,误差,在万分之一,任务结束
tol = 0.0001
count = 0
while True:
    if np.abs(min_value - min_value_last)<tol:
        break
#梯度下降
    min_value_last = min_value
#更新值
    min_value = min_value - learning_rate*d(min_value)
    print("++++++++++%d"%count,min_value)
    count+=1
print("*"*30,min_value)

------------------------------ -2 ++++++++++0 -1.25 ++++++++++1 -0.6499999999999999 ++++++++++2 -0.16999999999999993 ++++++++++3 0.21400000000000008 ++++++++++4 0.5212000000000001 ++++++++++5 0.7669600000000001 ++++++++++6 0.9635680000000001 ++++++++++7 1.1208544 ++++++++++8 1.24668352 ++++++++++9 1.347346816 ++++++++++10 1.4278774528 ++++++++++11 1.49230196224 ++++++++++12 1.543841569792 ++++++++++13 1.5850732558336 ++++++++++14 1.6180586046668801 ++++++++++15 1.644446883733504 ++++++++++16 1.6655575069868032 ++++++++++17 1.6824460055894426 ++++++++++18 1.695956804471554 ++++++++++19 1.7067654435772432 ++++++++++20 1.7154123548617946 ++++++++++21 1.7223298838894356 ++++++++++22 1.7278639071115485 ++++++++++23 1.7322911256892388 ++++++++++24 1.735832900551391 ++++++++++25 1.7386663204411128 ++++++++++26 1.7409330563528902 ++++++++++27 1.7427464450823122 ++++++++++28 1.7441971560658498 ++++++++++29 1.74535772485268 ++++++++++30 1.7462861798821439 ++++++++++31 1.7470289439057152 ++++++++++32 1.7476231551245722 ++++++++++33 1.7480985240996578 ++++++++++34 1.7484788192797263 ++++++++++35 1.748783055423781 ++++++++++36 1.7490264443390249 ++++++++++37 1.7492211554712198 ++++++++++38 1.749376924376976 ++++++++++39 1.7495015395015807 ++++++++++40 1.7496012316012646 ****** 1.7496012316012646

更新值learning_rate*d(max_value) 最大/最小值导数为0
就可能满足np.abs(max_value - max_value_last)<precision:
d2 = lambda x:-2*(x-3)*1+2.5
#学习率 需调节 每次改变数值的时候,改变多少
learning_rate = 0.1
max_value = np.random.randint(-3,5,size =1)[0]
print('-'*30,min_value)
#记录数据更新了,原来的值,上一步的值
max_value_last = max_value+0.1
result =[]
#tollerence容忍度,误差,在万分之一,任务结束
#precision精确度, 误差,在万分之一,任务结束
precision = 0.0001
count = 0
while True:
    if count>3000:
#         避免梯度消失 rate =1
#        避免梯度爆炸 导数更新值有问题时  或 rate =10
        break
    if np.abs(max_value - max_value_last)<precision:
        break
#梯度下降
    max_value_last = max_value


#更新值learning_rate*d(max_value) 最大/最小值导数为0 
# 就可能满足np.abs(max_value - max_value_last)<precision:

    max_value = max_value + learning_rate*d2(max_value)
    result.append(max_value)
    print("++++++++++%d"%count,max_value)
    count+=1
print("*"*30,max_value)

------------------------------ 1.7496012316012646 ++++++++++0 0.050000000000000044 ++++++++++1 0.8900000000000001 ++++++++++2 1.5620000000000003 ++++++++++3 2.0996 ++++++++++4 2.52968 ++++++++++5 2.873744 ++++++++++6 3.1489952 ++++++++++7 3.36919616 ++++++++++8 3.545356928 ++++++++++9 3.6862855424 ++++++++++10 3.79902843392 ++++++++++11 3.889222747136 ++++++++++12 3.9613781977088 ++++++++++13 4.01910255816704 ++++++++++14 4.065282046533632 ++++++++++15 4.102225637226906 ++++++++++16 4.131780509781525 ++++++++++17 4.15542440782522 ++++++++++18 4.174339526260176 ++++++++++19 4.18947162100814 ++++++++++20 4.201577296806512 ++++++++++21 4.2112618374452095 ++++++++++22 4.219009469956168 ++++++++++23 4.225207575964935 ++++++++++24 4.230166060771948 ++++++++++25 4.234132848617558 ++++++++++26 4.237306278894047 ++++++++++27 4.239845023115238 ++++++++++28 4.24187601849219 ++++++++++29 4.2435008147937525 ++++++++++30 4.244800651835002 ++++++++++31 4.2458405214680015 ++++++++++32 4.246672417174401 ++++++++++33 4.247337933739521 ++++++++++34 4.247870346991617 ++++++++++35 4.248296277593293 ++++++++++36 4.248637022074634 ++++++++++37 4.248909617659708 ++++++++++38 4.2491276941277665 ++++++++++39 4.249302155302213 ++++++++++40 4.249441724241771 ++++++++++41 4.249553379393417 ++++++++++42 4.249642703514733 ****** 4.249642703514733

ret = ret= ret*step

x = np.linspace(0,6,100)
y = f2(x)
plt.plot(x,y)

result = np.asanyarray(result)
plt.plot(result,f2(result),'*')

在这里插入图片描述

版权声明:如无特殊说明,文章均为本站原创,转载请注明出处

本文链接:http://wakemeupnow.cn/article/tidu/