Sunday, January 17, 2021

Gradient Descent Algorithm for Beginners (Quadratic Function Example)

In this article, I explain how the gradient descent algorithm (GDA) works, using a quadratic function as an example:

f(x) = ax^2 + bx + c

For a function of a single variable, the gradient is simply its first derivative, f'(x). The first derivative of f(x) is:

f'(x) = 2ax + b
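Setting the derivative to zero gives the exact minimum (assuming a > 0, so the parabola opens upward). This is handy for checking the result that GDA finds:

```latex
f'(x^*) = 2a x^* + b = 0 \quad\Longrightarrow\quad x^* = -\frac{b}{2a}, \qquad f(x^*) = c - \frac{b^2}{4a}
```

With a = 1, b = -5 and c = 6 (the defaults in the JavaScript example), this gives x* = 2.5 and f(x*) = -0.25.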

GDA is an iterative algorithm. In each iteration, it calculates the gradient f'(x) at the current x and computes the next point by subtracting the gradient, scaled by a learning rate (alfa), from x. The learning rate controls how large each change in x is. So in each iteration, GDA calculates the next x value as:

x(t+1) = x(t) - alfa*f'(x(t))
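The update rule takes only a few lines of code. The sketch below is a standalone version that runs without a browser (for example in Node.js). The function name `gradientDescent` and its parameter list are my own choices for illustration; they are not part of the page script later in the article.

```javascript
// Minimal gradient descent on f(x) = a*x^2 + b*x + c (sketch, no DOM needed).
// gradientDescent is a hypothetical helper name used only for illustration.
function gradientDescent(a, b, c, x0, alfa, maxIteration) {
  const f = (x) => a * x * x + b * x + c; // the quadratic
  const df = (x) => 2 * a * x + b;        // its derivative f'(x)

  let x = x0;
  for (let t = 0; t < maxIteration; t++) {
    x = x - alfa * df(x);                 // x(t+1) = x(t) - alfa * f'(x(t))
  }
  return { x: x, y: f(x) };
}

// Same defaults as the page: start at x = 200, alfa = 0.01, 5000 iterations
const result = gradientDescent(1, -5, 6, 200, 0.01, 5000);
console.log(result); // close to { x: 2.5, y: -0.25 }
```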

So why does GDA use the gradient of the function?

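Imagine you are standing on a hill and want to get down. If the slope is steep, you take a big step; otherwise, you take a small step. As you approach the bottom of the hill, the slope becomes smaller until it reaches zero, so your steps shrink as well. The update rule does exactly this: the step alfa*f'(x) is proportional to the slope.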
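The hill analogy can be checked numerically. This small sketch (the helper name `stepSizes` is hypothetical) records the length of each step for the default quadratic, starting at x = 200 with alfa = 0.01. The steps come out as approximately 3.95, 3.871, 3.794, 3.718, 3.643: each one is 0.98 times the previous one, because the slope keeps flattening as x approaches the bottom.

```javascript
// Record |x(t+1) - x(t)| for each iteration of gradient descent on
// f(x) = a*x^2 + b*x + c. stepSizes is a hypothetical helper for illustration.
function stepSizes(a, b, x0, alfa, iterations) {
  const df = (x) => 2 * a * x + b; // slope f'(x)
  const steps = [];
  let x = x0;
  for (let t = 0; t < iterations; t++) {
    const step = alfa * df(x);     // steep slope -> big step, flat slope -> small step
    steps.push(Math.abs(step));
    x = x - step;
  }
  return steps;
}

const steps = stepSizes(1, -5, 200, 0.01, 5);
console.log(steps);
```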

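If you think of the gradient as a vector, it points in the direction in which the function increases fastest (toward the top of the hill), so its negative points toward the bottom of the hill. That is why GDA subtracts the gradient.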
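For a single variable, this "vector" is just a signed number: f'(x) is positive to the right of the minimum and negative to the left, so -f'(x) always points toward the bottom. A quick check, using a hypothetical helper `descentDirection`:

```javascript
// Sign of the descent direction -f'(x) for f(x) = a*x^2 + b*x + c:
// +1 means "move right", -1 means "move left", 0 means "already flat".
// descentDirection is a hypothetical helper used only for illustration.
function descentDirection(a, b, x) {
  const slope = 2 * a * x + b; // f'(x)
  if (slope > 0) return -1;    // function rises to the right, so step left
  if (slope < 0) return 1;     // function rises to the left, so step right
  return 0;                    // slope is zero: we are at the extremum
}

// f(x) = x^2 - 5x + 6 has its minimum at x* = 2.5
console.log(descentDirection(1, -5, 10));  // -1: right of 2.5, move left
console.log(descentDirection(1, -5, -10)); // 1: left of 2.5, move right
console.log(descentDirection(1, -5, 2.5)); // 0: slope is zero at the minimum
```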

Note that we can use GDA to find the minimum point of a convex function, and the gradient ascent algorithm to find the maximum point of a concave function. The quadratic f(x) is convex when a > 0 and concave when a < 0.
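Gradient ascent is the same loop with the sign flipped: x(t+1) = x(t) + alfa*f'(x(t)). For a concave quadratic (a < 0), it climbs to the maximum at x* = -b/(2a). A minimal sketch; the function name `gradientAscent` and the example coefficients are my own choices:

```javascript
// Gradient ascent on f(x) = a*x^2 + b*x + c; intended for a < 0 (concave).
// gradientAscent is a hypothetical helper used only for illustration.
function gradientAscent(a, b, c, x0, alfa, maxIteration) {
  const f = (x) => a * x * x + b * x + c;
  const df = (x) => 2 * a * x + b;
  let x = x0;
  for (let t = 0; t < maxIteration; t++) {
    x = x + alfa * df(x); // move WITH the gradient to climb the hill
  }
  return { x: x, y: f(x) };
}

// f(x) = -x^2 + 4x + 1 has its maximum at x = 2, where f(2) = 5
const peak = gradientAscent(-1, 4, 1, -50, 0.1, 1000);
console.log(peak); // close to { x: 2, y: 5 }
```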

I prepared a working JavaScript example. With this code, you can calculate the minimum point of a quadratic function and visualize both the function and the calculated points.

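In each iteration, the calculated point (x, y) is displayed as a red dot, and the curve of f(x) is drawn in green. The plot uses one SVG unit per pixel with the origin at the center of the 500x500 canvas, so only points with x and y between -250 and 250 are visible.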


The example has the following inputs:

->x (starting value of x)
->alfa (learning rate)
->maxIteration (maximum number of iterations)
->a (coefficient of x^2)
->b (coefficient of x)
->c (constant term)
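The choice of alfa matters. For this quadratic, the update rule can be rewritten in terms of the distance to the minimum x* = -b/(2a), using 2ax(t) + b = 2a(x(t) - x*) and writing α for alfa:

```latex
x(t+1) - x^* = x(t) - x^* - \alpha \cdot 2a \left( x(t) - x^* \right) = (1 - 2a\alpha)\left( x(t) - x^* \right)
```

So the distance to the minimum is multiplied by (1 - 2a·alfa) at every iteration. GDA converges when |1 - 2a·alfa| < 1, which for a > 0 means 0 < alfa < 1/a. With a = 1 and alfa = 0.01 the factor is 0.98, so convergence is steady but slow; with alfa > 1/a the iterates jump back and forth across the minimum and diverge.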

JavaScript code:

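<script>

    // Default values; setValues() replaces them with the form inputs when START is pressed
    var x = 200;              // starting value of x
    var alfa = 0.01;          // learning rate
    var maxIteration = 5000;  // maximum number of iterations

    // Coefficients of f(x) = ax^2 + bx + c
    var a = 1;
    var b = -5;
    var c = 6;

    // The quadratic function and its first derivative f'(x) = 2ax + b
    var fx = (x) => a * x * x + b * x + c;
    var dxfx = (x) => 2 * a * x + b;

    function start() {

        setValues();
        clearPlot();
        drawCurve();

        // SVG group that holds the axes, the curve and the points
        var canvas = document.getElementById("g");
        let y = fx(x);

        for (var i = 0; i < maxIteration; i++) {

            // Gradient descent step: x(t+1) = x(t) - alfa * f'(x(t))
            x = x - alfa * dxfx(x);
            y = fx(x);

            // Mark the current point (x, y) with a small red dot
            let circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
            circle.setAttributeNS(null, "cx", x);
            circle.setAttributeNS(null, "cy", y);
            circle.setAttributeNS(null, "r", 1);
            circle.setAttributeNS(null, "fill", "red");
            circle.setAttributeNS(null, "class", "plot");
            canvas.appendChild(circle);
        }

        let resultDiv = document.getElementById("resultDiv");
        resultDiv.textContent = "x:" + x + ", y:" + y;
    }

    // Read the parameters from the input fields
    function setValues() {

        x = parseFloat(document.getElementById("x").value);
        alfa = parseFloat(document.getElementById("alfa").value);
        maxIteration = parseInt(document.getElementById("maxIteration").value, 10);
        a = parseFloat(document.getElementById("a").value);
        b = parseFloat(document.getElementById("b").value);
        c = parseFloat(document.getElementById("c").value);
    }

    // Remove the curve and dots of a previous run, keeping the axes
    function clearPlot() {

        var canvas = document.getElementById("g");
        canvas.querySelectorAll(".plot").forEach((el) => el.remove());
    }

    // Draw f(x) as a chain of 1-unit green line segments. The origin sits at the
    // center of the 500x500 SVG, so only x values from -250 to 250 are visible.
    function drawCurve() {

        var canvas = document.getElementById("g");

        for (var i = 250; i > -250; i -= 1) {

            let line = document.createElementNS("http://www.w3.org/2000/svg", "line");
            line.setAttributeNS(null, "x1", i);
            line.setAttributeNS(null, "y1", fx(i));
            line.setAttributeNS(null, "x2", i - 1);
            line.setAttributeNS(null, "y2", fx(i - 1));
            line.setAttributeNS(null, "style", "stroke: rgb(0, 255, 0); stroke-width:1");
            line.setAttributeNS(null, "class", "plot");
            canvas.appendChild(line);
        }
    }

</script>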
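The page script always runs maxIteration iterations, even after x has stopped changing. A common variation is to stop as soon as the slope is almost zero. Below is a standalone sketch of that idea; the `tolerance` parameter and the function name `gradientDescentWithTolerance` are assumptions for illustration, not part of the original code. With the defaults and a tolerance of 1e-6, it stops after roughly a thousand iterations instead of 5000.

```javascript
// Gradient descent on f(x) = a*x^2 + b*x + c that stops early once |f'(x)| < tolerance.
// gradientDescentWithTolerance and tolerance are hypothetical names for illustration.
function gradientDescentWithTolerance(a, b, c, x0, alfa, maxIteration, tolerance) {
  const f = (x) => a * x * x + b * x + c;
  const df = (x) => 2 * a * x + b;
  let x = x0;
  let iterations = 0;
  // Keep stepping while the slope is still noticeably non-zero
  while (iterations < maxIteration && Math.abs(df(x)) >= tolerance) {
    x = x - alfa * df(x);
    iterations++;
  }
  return { x: x, y: f(x), iterations: iterations };
}

const run = gradientDescentWithTolerance(1, -5, 6, 200, 0.01, 5000, 1e-6);
console.log(run); // x near 2.5, with far fewer than 5000 iterations
```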

HTML code:

<div>
    <input id="x" type="text" value="200" />->x
</div>
<div>
    <input id="alfa" type="text" value="0.01" />->alfa
</div>
<div>
    <input id="maxIteration" type="text" value="5000" />->maxIteration
</div>
<div>
    <input id="a" type="text" value="1" />->a
</div>
<div>
    <input id="b" type="text" value="-5" />->b
</div>
<div>
    <input id="c" type="text" value="6" />->c
</div>

<div style="margin-top:20px" id="resultDiv"></div>
<div style="margin-top:20px">
    <button type="button" onclick="start()">START</button>
</div>

<div style="margin-top:50px">
    <svg id="canvas" width="500" height="500">
        <!-- Move the origin to the center and flip the y axis so that positive y points up -->
        <g id="g" transform="translate(250 250) scale(1,-1)">
            <!-- x and y axes -->
            <line x1="-1000" y1="0" x2="1000" y2="0" style="stroke: rgb(0, 0, 255); stroke-width:1"></line>
            <line x1="0" y1="-1000" x2="0" y2="1000" style="stroke: rgb(0, 0, 255); stroke-width:1"></line>
        </g>
    </svg>
</div>
