개발/AI

[Python] Backpropagation

jykim23 2023. 11. 13. 17:21

역쇄법칙을 사용하여 미분계수 구하기.(연습)

밑바닥 153p 참고

 

class Function1:
    def forward(self, x):
        z = x - 2
        return z
    
    def backward(self, dy_dz):
        self.dy_dx = 1
        self.dy_dx *= dy_dz # chain rule
        return self.dy_dx
    
class Function2:
    def forward(self, z):
        self.y = 2*(z**2)
        return self.y
    
    def backward(self):
        return self.y * 4

class FunctionUnion:
    def __init__(self):
        print(f'f(x) = 2*(x - 2)**2\n')
        self.f1 = Function1()
        self.f2 = Function2()

    def __call__(self, x):
        self.x = x
        self.forward(self.x)
        self.backward()
        self.print()
        
    def forward(self, x):
        z = self.f1.forward(x)
        self.y = self.f2.forward(z)
        return self.y
    
    def backward(self):
        dy_dz = self.f2.backward()
        self.dy_dx = self.f1.backward(dy_dz)
        return self.dy_dx
    
    def print(self):
        print(f'---- x = {self.x} ----')
        print(f'순전파: f({self.x}) = {self.y}')
        print(f"역전파:f'({self.x}) = {self.dy_dx}\n")        

f = FunctionUnion()

print(f'순전파 : {f.forward(4)}')
print(f'역전파 : {f.backward()}')


X = [i for i in range(-1, 4)]
print(f'x = {X}')
for x in range(len(X)):
    f(x)

f(x) = 2*(x - 2)**2

순전파 : 8
역전파 : 32
x = [-1, 0, 1, 2, 3]
---- x = 0 ----
순전파: f(0) = 8
역전파:f'(0) = 32

---- x = 1 ----
순전파: f(1) = 2
역전파:f'(1) = 8

---- x = 2 ----
순전파: f(2) = 0
역전파:f'(2) = 0

---- x = 3 ----
순전파: f(3) = 2
역전파:f'(3) = 8

---- x = 4 ----
순전파: f(4) = 8
역전파:f'(4) = 32