역쇄법칙을 사용하여 미분계수 구하기.(연습)
밑바닥 153p 참고
class Function1:
def forward(self, x):
z = x - 2
return z
def backward(self, dy_dz):
self.dy_dx = 1
self.dy_dx *= dy_dz # chain rule
return self.dy_dx
class Function2:
def forward(self, z):
self.y = 2*(z**2)
return self.y
def backward(self):
return self.y * 4
class FunctionUnion:
def __init__(self):
print(f'f(x) = 2*(x - 2)**2\n')
self.f1 = Function1()
self.f2 = Function2()
def __call__(self, x):
self.x = x
self.forward(self.x)
self.backward()
self.print()
def forward(self, x):
z = self.f1.forward(x)
self.y = self.f2.forward(z)
return self.y
def backward(self):
dy_dz = self.f2.backward()
self.dy_dx = self.f1.backward(dy_dz)
return self.dy_dx
def print(self):
print(f'---- x = {self.x} ----')
print(f'순전파: f({self.x}) = {self.y}')
print(f"역전파:f'({self.x}) = {self.dy_dx}\n")
f = FunctionUnion()
print(f'순전파 : {f.forward(4)}')
print(f'역전파 : {f.backward()}')
X = [i for i in range(-1, 4)]
print(f'x = {X}')
for x in range(len(X)):
f(x)
f(x) = 2*(x - 2)**2
순전파 : 8
역전파 : 32
x = [-1, 0, 1, 2, 3]
---- x = 0 ----
순전파: f(0) = 8
역전파:f'(0) = 32
---- x = 1 ----
순전파: f(1) = 2
역전파:f'(1) = 8
---- x = 2 ----
순전파: f(2) = 0
역전파:f'(2) = 0
---- x = 3 ----
순전파: f(3) = 2
역전파:f'(3) = 8
---- x = 4 ----
순전파: f(4) = 8
역전파:f'(4) = 32
'개발 > AI' 카테고리의 다른 글
[Python] Affine, Sigmoid, BCELoss 구현 (0) | 2023.11.14 |
---|---|
[Python] GBL : Gradient-based Learning (1) | 2023.11.13 |
[Python] Logic Gate 클래스 만들기 (0) | 2023.11.07 |
[Python] Logic Gate 함수 구현 연습 (1) | 2023.11.06 |
[AutoViz] EDA를 도와줄 시각화 툴 (1) | 2023.10.31 |