이번에는 새터를 자동으로 구해주는 다른 알고리즘인 Stochastic Gradient descent를 알아보도록 하겠습니다.


앞서 배운 Gradient descent에 Stochastic의 접두사를 붙혀서 SGD라고 불립니다.

GD는 "batch" Gradient descent 라고도 불리며 이것은 매 이터레이션 마다 모든 트래이닝 데이터를 구해주기 때문에 계산의 코스트가 높다는 단점이 있습니다.


SGD는 위 단점을 해결하고자 Sum을 하지 않고 자동 새터를 구해주는 알고리즘 입니다.





위와 같이 SGD는 GD와 거의 동일한데 몇가지 다른 점이 있습니다.


1. 먼저 데이터를 셔플 해준다. (가장 큰 특징)

- 이유는 첫번째 부터 접근을 안하고 임의의 수 부터 접근하여 새터값을 빨리 줄이려는 이유인 것 같음. 

(적정한 수가 랜덤으로 선택되어 일정 새터값으로 수렴 되기 때문에)

2. 매 이터레이션 마다 모든 데이터를 더하는 sum of all 생략

3. m으로 나누지 않는다.



위와 같이 진행되며

매번 마다 sum을 안하기 때문에 코스트가 적게 듭니다.


코드는


function [theta, J_history] = stochasticGradientDescent(x, y, theta, alpha, iterations)


    m = length(y); % number of rows

    n = size(x, 2); % number of features


    h = zeros(n, 1);

    J_history = zeros(iterations, 1);

    

    % data shuffle

    myperm = randperm(m);

    Xshuffle = x(myperm , :);

    Yshuffle = y(myperm);


    for iter=1:iterations

        for i=1:m

            for j=1:n

                h(j) = (theta' * Xshuffle(i, :)' -  Yshuffle(i)) * Xshuffle(i, j);

            end

            for j=1:n

                theta(j) = theta(j) - alpha * h(j);

            end

        end

        % add history

        J_history(iter) = computeCostMulti(Xshuffle, Yshuffle, theta);

    end

end


위와 같이 데이터를 처음 한번 셔플해주고(row의 순서가 바뀜)

sum이 없고 m으로 나눠주는 곳도 없습니다.


위와 같이 구한것에 단항의 데이터를 이용하여 새터와 코스트 펑션을 구해보면


>>  data = load('ex1data1.txt');

>>  n = size(data, 2);

>>  x = data(:, 1:n-1);

>>  y = data(:, n);

>>  m = length(x);

>>  x = [ones(m, 1), x]; % add x0 1

>>  theta = zeros(n, 1);

>>  alpha = 0.001;

>>  iterations = 1000;

>>  [theta, J_history] = stochasticGradientDescent(x, y, theta, alpha, iterations);

>> theta

theta =


  -3.9598

   1.2511


>> computeCostMulti(x, y, theta)

ans =  4.5857




와 앞에서 구한 GD와 비슷한 것을 볼 수 있습니다.


여러번 진행해 보면 진행할 때 마다 새터값이 변경됩니다.

(랜덤으로 셔플하기 때문에 특정 구간으로 수렴하는 것 같은데 정확히 잘 모르겠네요)


일단 여기서 SGD 설명을 마치며 아직 의문이 드는 사항이 몇개 있는데 이것은 추후 공부하고 추가 하도록 하겠습니다.

(현재는 알파값은 GD보다 더 작게, iteration은 더 크게 해야 좋은 수가 나오는 것 같습니다.)




'ML > octave구현 - w1' 카테고리의 다른 글

[octave] Normal equation  (0) 2016.03.13
[octave] Multiple Variable Gradient Descent  (0) 2016.03.11
[octave] feature mean normalization  (0) 2016.03.10
[octave] multiple variable cost function  (0) 2016.03.10
[octave] 임의 값 예측  (0) 2016.03.07

+ Recent posts