How To Multiply Two Matrices More Efficiently?


Matrices can look intimidating, right? Well, not anymore! In this article we will discuss a very popular way of multiplying matrices, Strassen’s algorithm, which is asymptotically faster than naïve multiplication.

We all know how tedious it is to multiply two matrices by hand. Even for a program it takes a while, and writing the code for it is even worse. But don’t worry: here we will walk through the whole process.

Introduction

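First, let us recall what matrix multiplication looks like. For n×n matrices A and B, each entry of the product C = AB is the dot product of row i of A and column j of B: $c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj}$.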
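For example, here is a small 2×2 product worked out by hand. The numbers are arbitrary and chosen only for illustration:

$$
\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}
=
\begin{pmatrix} 1\cdot 5 + 2\cdot 7 & 1\cdot 6 + 2\cdot 8 \\ 3\cdot 5 + 4\cdot 7 & 3\cdot 6 + 4\cdot 8 \end{pmatrix}
=
\begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}
$$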

Naïve Approach For Multiplying Matrices

The most naïve way to do this is to use three nested loops and compute the value of each cell of the result matrix directly.

Code Implementation



CPP
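#include <iostream>
using namespace std;

// Naive O(n^3) multiplication: C = A * B for n x n matrices.
void matrixmultiply(int **A, int **B, int **C, int n)
{
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            C[i][j] = 0;
            // C[i][j] is the dot product of row i of A and column j of B
            for (int k = 0; k < n; k++)
            {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

int main() {
    int n;
    cin >> n;
    // Allocate and read A, then B (each given as n rows of n integers)
    int **A = new int*[n];
    for (int i = 0; i < n; i++) {
        A[i] = new int[n];
        for (int j = 0; j < n; j++)
            cin >> A[i][j];
    }
    int **B = new int*[n];
    for (int i = 0; i < n; i++) {
        B[i] = new int[n];
        for (int j = 0; j < n; j++)
            cin >> B[i][j];
    }
    int **C = new int*[n];
    for (int i = 0; i < n; i++)
        C[i] = new int[n];

    matrixmultiply(A, B, C, n);

    // Print C, one row per line
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++)
            cout << C[i][j] << " ";
        cout << endl;
    }

    // Release the memory allocated with new[]
    for (int i = 0; i < n; i++) {
        delete[] A[i];
        delete[] B[i];
        delete[] C[i];
    }
    delete[] A;
    delete[] B;
    delete[] C;
    return 0;
}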



JAVA
import java.util.Scanner;

class Main
{
    // Naive O(n^3) multiplication: C = A * B for n x n matrices.
    public static void matrixmultiply(int[][] A, int[][] B, int[][] C, int n)
    {
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                C[i][j] = 0;
                // C[i][j] is the dot product of row i of A and column j of B
                for (int k = 0; k < n; k++)
                {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
    }

    public static void main(String[] args)
    {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        // Read A, then B (each given as n rows of n integers)
        int[][] A = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                A[i][j] = sc.nextInt();
        int[][] B = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                B[i][j] = sc.nextInt();
        sc.close();

        int[][] C = new int[n][n];
        matrixmultiply(A, B, C, n);

        // Print C one row per line (print, not println, keeps a row together)
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                System.out.print(C[i][j] + " ");
            System.out.println();
        }
    }
}


Time Complexity: $O(n^3)$, where the matrices are of size n×n.
Space Complexity: $O(n^2)$ for the result matrix C.
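To see where the cubic bound comes from, count the work done by the three loops. C has $n^2$ entries, and each entry needs $n$ multiplications:

$$
\underbrace{n \cdot n}_{\text{entries of } C} \times \underbrace{n}_{\text{multiplications per entry}} = n^3
$$

For $n = 1000$, for example, the innermost statement runs $10^9$ times.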

As you can see, this method takes cubic time, so a different method is needed to reduce the cost of matrix multiplication.

One such method is Strassen’s Matrix Multiplication. Let’s look at it in detail.

Recursive Matrix Multiplication – Strassen has modified this a bit 🙂

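The idea behind this method is that the product of two 2×2 matrices can be computed in constant time. Multiplying A and B and putting the result in C:

$$
\begin{aligned}
c_{11} &= a_{11}b_{11} + a_{12}b_{21}, & c_{12} &= a_{11}b_{12} + a_{12}b_{22},\\
c_{21} &= a_{21}b_{11} + a_{22}b_{21}, & c_{22} &= a_{21}b_{12} + a_{22}b_{22}.
\end{aligned}
$$

These four formulas always take constant time (eight multiplications and four additions). This is a Divide and Conquer method, which means the bigger matrices are broken into smaller subproblems. An n×n matrix (n a power of two) is split into four n/2 × n/2 quadrants, such that

$$
\begin{pmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{pmatrix}
=
\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}
$$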

Can you see that this is again a 2×2 matrix product, except that each entry is now an n/2 × n/2 block? Each quadrant product can be treated as a subproblem of half the size, so we have reduced the bigger problem to smaller subproblems. Now let’s see how we can actually multiply them:


Algorithm For Matrices Multiplication

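Multiply(A, B, n) {        // A and B are n x n, n a power of two
	if (n <= 2) {
		return the product computed with the 2x2 formulas above
		(a single multiplication when n == 1)
	}
	else {  // each Aij, Bij is one of the n/2 x n/2 quadrants shown above
		mid = n/2
		// "+" here is matrix addition, not normal integer addition
		C11 = Multiply(A11, B11, mid) + Multiply(A12, B21, mid)
		C12 = Multiply(A11, B12, mid) + Multiply(A12, B22, mid)
		C21 = Multiply(A21, B11, mid) + Multiply(A22, B21, mid)
		C22 = Multiply(A21, B12, mid) + Multiply(A22, B22, mid)
		return C assembled from the quadrants C11, C12, C21, C22
	}
}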

Strassen’s Matrix Multiplication

The recursive function above still takes $O(n^3)$ time. It makes eight calls on half-size matrices plus $O(n^2)$ work for the matrix additions, so $T(n) = 8T(n/2) + O(n^2)$, which solves to $O(n^{\log_2 8}) = O(n^3)$. The overall time complexity is therefore driven by these eight recursive calls. This is where Strassen modified the algorithm: he reduced the eight calls to seven by using a different set of formulas for the same product.

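Strassen’s seven products are:

$$
\begin{aligned}
P_1 &= (A_{11} + A_{22})(B_{11} + B_{22}), & P_2 &= (A_{21} + A_{22})\,B_{11},\\
P_3 &= A_{11}\,(B_{12} - B_{22}), & P_4 &= A_{22}\,(B_{21} - B_{11}),\\
P_5 &= (A_{11} + A_{12})\,B_{22}, & P_6 &= (A_{21} - A_{11})(B_{11} + B_{12}),\\
P_7 &= (A_{12} - A_{22})(B_{21} + B_{22}).
\end{aligned}
$$

The four quadrants of C are then recovered from them using only additions and subtractions:

$$
\begin{aligned}
C_{11} &= P_1 + P_4 - P_5 + P_7, & C_{12} &= P_3 + P_5,\\
C_{21} &= P_2 + P_4, & C_{22} &= P_1 - P_2 + P_3 + P_6.
\end{aligned}
$$

These 7 products replace the 8 quadrant multiplications in the previous 4 equations. Each level of recursion therefore needs 7 recursive calls instead of 8.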

Code Implementation:



CPP
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
using namespace std;

// At or below this size the recursion falls back to the naive multiply.
// (Named leafSize, not size, so it does not clash with std::size in C++17.)
int leafSize = 16;

void strassen(vector< vector<int> > &A,
              vector< vector<int> > &B,
              vector< vector<int> > &C, unsigned int t);
unsigned int nextPowerOfTwo(int n);
void strassenR(vector< vector<int> > &A,
              vector< vector<int> > &B,
              vector< vector<int> > &C,
              int t);
void sum(vector< vector<int> > &A,
         vector< vector<int> > &B,
         vector< vector<int> > &C, int t);
void subtract(vector< vector<int> > &A,
              vector< vector<int> > &B,
              vector< vector<int> > &C, int t);

void printMatrix(const vector< vector<int> > &matrix, int n);
void read(string filename, vector< vector<int> > &A, vector< vector<int> > &B);

// Naive multiply used as the base case. It accumulates into C, so C must
// start zero-filled (every caller passes a freshly constructed matrix).
// The i-k-j loop order walks B row by row, which is cache-friendlier.
void multiply(const vector< vector<int> > &A,
              const vector< vector<int> > &B,
              vector< vector<int> > &C, int n) {
    for (int i = 0; i < n; i++) {
        for (int k = 0; k < n; k++) {
            for (int j = 0; j < n; j++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

// One Strassen step on t x t matrices (t is a power of two).
void strassenR(vector< vector<int> > &A,
              vector< vector<int> > &B,
              vector< vector<int> > &C, int t) {
    if (t <= leafSize) {
        multiply(A, B, C, t);
        return;
    }
    else {
        int newT = t/2;
        vector<int> inner (newT);
        vector< vector<int> >
            a11(newT,inner), a12(newT,inner), a21(newT,inner), a22(newT,inner),
            b11(newT,inner), b12(newT,inner), b21(newT,inner), b22(newT,inner),
            c11(newT,inner), c12(newT,inner), c21(newT,inner), c22(newT,inner),
            p1(newT,inner), p2(newT,inner), p3(newT,inner), p4(newT,inner),
            p5(newT,inner), p6(newT,inner), p7(newT,inner),
            aResult(newT,inner), bResult(newT,inner);

        int i, j;

        // Split A and B into their four quadrants
        for (i = 0; i < newT; i++) {
            for (j = 0; j < newT; j++) {
                a11[i][j] = A[i][j];
                a12[i][j] = A[i][j + newT];
                a21[i][j] = A[i + newT][j];
                a22[i][j] = A[i + newT][j + newT];

                b11[i][j] = B[i][j];
                b12[i][j] = B[i][j + newT];
                b21[i][j] = B[i + newT][j];
                b22[i][j] = B[i + newT][j + newT];
            }
        }

        // The seven recursive products
        sum(a11, a22, aResult, newT); // a11 + a22
        sum(b11, b22, bResult, newT); // b11 + b22
        strassenR(aResult, bResult, p1, newT); // p1 = (a11+a22) * (b11+b22)

        sum(a21, a22, aResult, newT); // a21 + a22
        strassenR(aResult, b11, p2, newT); // p2 = (a21+a22) * (b11)

        subtract(b12, b22, bResult, newT); // b12 - b22
        strassenR(a11, bResult, p3, newT); // p3 = (a11) * (b12 - b22)

        subtract(b21, b11, bResult, newT); // b21 - b11
        strassenR(a22, bResult, p4, newT); // p4 = (a22) * (b21 - b11)

        sum(a11, a12, aResult, newT); // a11 + a12
        strassenR(aResult, b22, p5, newT); // p5 = (a11+a12) * (b22)

        subtract(a21, a11, aResult, newT); // a21 - a11
        sum(b11, b12, bResult, newT); // b11 + b12
        strassenR(aResult, bResult, p6, newT); // p6 = (a21-a11) * (b11+b12)

        subtract(a12, a22, aResult, newT); // a12 - a22
        sum(b21, b22, bResult, newT); // b21 + b22
        strassenR(aResult, bResult, p7, newT); // p7 = (a12-a22) * (b21+b22)

        // Calculating c11, c12, c21 and c22:

        sum(p3, p5, c12, newT); // c12 = p3 + p5
        sum(p2, p4, c21, newT); // c21 = p2 + p4

        sum(p1, p4, aResult, newT); // p1 + p4
        sum(aResult, p7, bResult, newT); // p1 + p4 + p7
        subtract(bResult, p5, c11, newT); // c11 = p1 + p4 - p5 + p7

        sum(p1, p3, aResult, newT); // p1 + p3
        sum(aResult, p6, bResult, newT); // p1 + p3 + p6
        subtract(bResult, p2, c22, newT); // c22 = p1 + p3 - p2 + p6

        // Write the four quadrants back into C
        for (i = 0; i < newT ; i++) {
            for (j = 0 ; j < newT ; j++) {
                C[i][j] = c11[i][j];
                C[i][j + newT] = c12[i][j];
                C[i + newT][j] = c21[i][j];
                C[i + newT][j + newT] = c22[i][j];
            }
        }
    }
}

// Smallest power of two >= n, computed with integers (no floating-point log).
unsigned int nextPowerOfTwo(int n) {
    unsigned int m = 1;
    while (m < (unsigned int)n) {
        m <<= 1;
    }
    return m;
}

// Pads A and B with zeros up to the next power of two, runs the recursion,
// and copies the top-left n x n block of the result back into C.
void strassen(vector< vector<int> > &A,
              vector< vector<int> > &B,
              vector< vector<int> > &C, unsigned int n) {
    unsigned int m = nextPowerOfTwo(n);
    vector<int> inner(m);
    vector< vector<int> > APrep(m, inner), BPrep(m, inner), CPrep(m, inner);

    for(unsigned int i=0; i<n; i++) {
        for (unsigned int j=0; j<n; j++) {
            APrep[i][j] = A[i][j];
            BPrep[i][j] = B[i][j];
        }
    }

    strassenR(APrep, BPrep, CPrep, m);
    for(unsigned int i=0; i<n; i++) {
        for (unsigned int j=0; j<n; j++) {
            C[i][j] = CPrep[i][j];
        }
    }
}

void sum(vector< vector<int> > &A,
         vector< vector<int> > &B,
         vector< vector<int> > &C, int t) {
    int i, j;

    for (i = 0; i < t; i++) {
        for (j = 0; j < t; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
}

void subtract(vector< vector<int> > &A,
              vector< vector<int> > &B,
              vector< vector<int> > &C, int t) {
    int i, j;

    for (i = 0; i < t; i++) {
        for (j = 0; j < t; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
}

// Entries are tab-separated, so n = (number of tabs in the first row) + 1.
// Returns 0 if the file cannot be read.
int getMatrixSize(string filename) {
    string line;
    ifstream infile(filename.c_str());
    if (!infile || !getline(infile, line)) {
        return 0;
    }
    return count(line.begin(), line.end(), '\t') + 1;
}

// File layout: the rows of A, an empty line, then the rows of B.
void read(string filename, vector< vector<int> > &A, vector< vector<int> > &B) {
    ifstream infile(filename.c_str());
    if (!infile) {
        cerr << "Could not read file " << filename << endl;
        return;
    }

    string line;
    int n = A.size();
    int i = 0, j, a;
    while (getline(infile, line) && !line.empty()) {
        istringstream iss(line);
        j = 0;
        while (iss >> a) {
            if (i < n && j < n) A[i][j] = a; // ignore anything beyond n x n
            j++;
        }
        i++;
    }

    i = 0;
    while (getline(infile, line)) {
        istringstream iss(line);
        j = 0;
        while (iss >> a) {
            if (i < n && j < n) B[i][j] = a;
            j++;
        }
        i++;
    }
}

void printMatrix(const vector< vector<int> > &matrix, int n) {
    for (int i=0; i < n; i++) {
        for (int j=0; j < n; j++) {
            if (j != 0) {
                cout << "\t";
            }
            cout << matrix[i][j];
        }
        cout << endl;
    }
}

// argv[2] (if given) is the input file and argv[4] (if given) the leaf size;
// argv[1] and argv[3] are not inspected.
int main (int argc, char* argv[]) {
    string filename;
    if (argc < 3) {
        filename = "2000.in";
    } else {
        filename = argv[2];
    }

    if (argc < 5) {
        leafSize = 16;
    } else {
        leafSize = atoi(argv[4]);
    }
    if (leafSize < 1) {
        leafSize = 1; // a cutoff of 0 would recurse into empty blocks
    }

    int n = getMatrixSize(filename);
    if (n == 0) {
        cerr << "Could not read file " << filename << endl;
        return 1;
    }
    vector<int> inner (n);
    vector< vector<int> > A(n, inner), B(n, inner), C(n, inner);
    read (filename, A, B);
    strassen(A, B, C, n);
    printMatrix(C, n);
    return 0;
}


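Each level makes seven recursive calls on N/2 × N/2 matrices, plus $O(N^2)$ work for the matrix additions and subtractions. The running time therefore satisfies

$$T(N) = 7\,T(N/2) + O(N^2).$$

From the Master Theorem, the time complexity of this method is $O(N^{\log_2 7})$, which is approximately $O(N^{2.8074})$.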
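Here is the Master Theorem step spelled out. This is the standard "case 1" analysis applied to the recurrence:

$$
a = 7,\quad b = 2,\quad f(N) = O(N^2),\qquad N^{\log_b a} = N^{\log_2 7} \approx N^{2.8074}
$$

$$
f(N) = O\!\left(N^{\log_2 7 - \varepsilon}\right) \ \text{for any } 0 < \varepsilon \le \log_2 7 - 2
\;\Longrightarrow\; T(N) = \Theta\!\left(N^{\log_2 7}\right)
$$

The same calculation with $a = 8$, the eight-call version, gives $N^{\log_2 8} = N^3$. That is why saving a single recursive call per level lowers the exponent.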

Frequently Asked Questions

What are other ways to multiply two matrices?

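Well, there are many. A related and interesting problem is matrix chain multiplication. Note that it does not change how two matrices are multiplied. Instead, it finds the cheapest order in which to multiply a whole sequence of matrices.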

How much faster is Strassen’s algorithm?

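Each level of recursion replaces 8 sub-multiplications with 7, which saves 12.5% of the multiplications at that level. The saving compounds across all $\log_2 n$ levels, so the gap grows with n: $O(n^{2.81})$ versus $O(n^3)$. In practice, Strassen’s extra additions and memory traffic mean it only pays off above some crossover size. This is why the implementation above falls back to the naive multiply for small blocks.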

Can this be applied to matrices of different size?

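Yes, as long as the two matrices are valid for multiplication, i.e. the number of columns of A equals the number of rows of B. Non-square or non-power-of-two matrices can be padded with zeros up to a square power-of-two size, just as the strassen() function above pads its input up to the next power of two. The extra rows and columns are then discarded from the result.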

This was quite a long article, but I hope it is helpful to aspiring developers and programmers.

By Aniruddha Guin