Basics of DP with Trees

Saksham Gupta
Last Updated: May 13, 2022


Dynamic Programming (DP) and trees are among the most frequently asked topics in tech interviews, and problems that combine the two are especially interesting. DP is a problem-solving technique for problems that break down into overlapping sub-problems and have optimal substructure. We've all heard of DP problems like subset sum, knapsack, and coin change, but DP also applies to other kinds of problems, where it can drastically reduce the time complexity. This blog shows how we can apply DP on trees.

But what kind of problems, and how do we identify them?
We'll work through this step by step, but first, let's look at a classic tree problem you are probably already familiar with: the diameter of a tree.

Example Problem

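The diameter (or width) of a tree is the length of the longest path between any two nodes. In this article, that length is counted in edges, which is what the code below computes; some sources count the nodes on the path instead, which gives a value one larger. Let’s understand this with the following example.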

Nodes in green represent the longest path between two leaf nodes of the given tree. The diameter of the above tree is 9.

There are two cases: the longest path in a subtree either passes through the subtree's root or it does not.

If it passes through the root, the diameter of that subtree is the sum of the heights of the root's left and right subtrees.

Otherwise, the longest path lies entirely within the left subtree or entirely within the right subtree, so the diameter equals the diameter of one of those subtrees.

Thus, to find the diameter of the tree, we need the following:

  • Diameter of left subtree
  • Diameter of right subtree
  • Height of left subtree
  • Height of right subtree

For the height, we can write a recursive height() function that takes the root of the tree as a parameter and computes the height from the heights of its left and right subtrees: the height of a tree is the maximum of the heights of its left and right subtrees, plus 1 (an empty tree has height 0).

The final answer (the diameter) is then the maximum of the diameter of the left subtree, the diameter of the right subtree, and the sum of the heights of the left and right subtrees.
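Written as a recurrence, the two rules above look like this. The symbols are introduced here only for illustration: h(T) is the height of a subtree T (counted in nodes), d(T) is its diameter, and L and R are the left and right subtrees of its root.

```latex
\begin{aligned}
h(\varnothing) &= 0, \qquad d(\varnothing) = 0,\\
h(T) &= \max\bigl(h(L),\, h(R)\bigr) + 1,\\
d(T) &= \max\bigl(d(L),\, d(R),\, h(L) + h(R)\bigr).
\end{aligned}
```

Because heights count nodes, h(L) + h(R) is the number of edges on the longest path through the root, so this recurrence measures the diameter in edges.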

Now you might be thinking: we already know this, so why are we discussing it again?

The point is, have you ever thought about its time complexity?

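In this approach, we visit every node of the tree and, at each one, call height() on its left and right subtrees, and each of those calls visits every node below it. On a skewed tree, the node at depth d has N - 1 - d nodes below it, so the total work is 0 + 1 + ... + (N - 1) = N(N - 1)/2. Thus our time complexity comes to O(N*N), where ‘N’ is the number of nodes in the tree.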

We keep recalculating the heights of the same subtrees again and again. Do we actually need to recalculate the height at each step, or can we simply reuse a previously stored result? The first thing that comes to mind on reading "previously stored" is Dynamic Programming, and you are right. Instead of repeating the same calculations, we can apply DP to our tree, i.e., DP on trees.

Thus, instead of calling the height() function (used to calculate the height of a binary tree) every time, the primary idea behind this solution is to compute the height of each subtree in the same recursive call that computes its diameter.

Let's create a recursive function getDiameter(BinaryTreeNode<int> *root, int &height) that returns the diameter of the subtree rooted at the "root" node. We also pass the variable 'height' by reference; the function sets it to the height of that subtree.

Let's look at the algorithm for better understanding.


  • If the root node is NULL, set 'HEIGHT' = 0 and return 0, because both the height and the diameter of an empty tree are zero.
  • Create two variables, 'LEFTHEIGHT' and 'RIGHTHEIGHT', initialized to 0, which will hold the heights of the left subtree and right subtree, respectively.
  • Store the diameter of the left subtree in 'LEFTDIAMETER' and that of the right subtree in 'RIGHTDIAMETER' by calling getDiameter(root->left, LEFTHEIGHT) and getDiameter(root->right, RIGHTHEIGHT), respectively. These calls also fill in 'LEFTHEIGHT' and 'RIGHTHEIGHT'.
  • Update the height of the tree: 'HEIGHT' = max('LEFTHEIGHT', 'RIGHTHEIGHT') + 1.
  • Return the maximum of the terms above, i.e., max('LEFTDIAMETER', 'RIGHTDIAMETER', 'LEFTHEIGHT' + 'RIGHTHEIGHT').

Let’s also see the code for better understanding.


#include <iostream>
#include <queue>
#include <algorithm>
using namespace std;

// Binary Tree Node Class.
template <typename T>
class BinaryTreeNode
{
public:
    T data;
    BinaryTreeNode<T> *left;
    BinaryTreeNode<T> *right;

    BinaryTreeNode(T data)
    {
        this->data = data;
        left = NULL;
        right = NULL;
    }
};

// Function to calculate the diameter of binary tree rooted at node, 'ROOT'.
// Returns the diameter and writes the height of the subtree into 'HEIGHT'.
int getDiameter(BinaryTreeNode<int> *root, int &height)
{
    if (root == NULL)
    {
        // The 'HEIGHT' and 'DIAMETER' of an empty tree will be 0.
        height = 0;
        return 0;
    }

    // To store the height of left and right subtrees.
    int leftHeight = 0;
    int rightHeight = 0;

    // Recur for left subtree and get the height as well as diameter.
    int leftDiameter = getDiameter(root->left, leftHeight);

    // Recur for right subtree and get the height as well as diameter.
    int rightDiameter = getDiameter(root->right, rightHeight);

    // Update the height of the given binary tree.
    height = max(leftHeight, rightHeight) + 1;

    // Diameter of given binary tree.
    int diameter = max(leftDiameter, max(rightDiameter, leftHeight + rightHeight));

    return diameter;
}

int diameterOfBinaryTree(BinaryTreeNode<int> *root)
{
    // Initialize a variable to store the height of the binary tree.
    int height = 0;

    // Recursive function to find diameter.
    return getDiameter(root, height);
}

// For taking level order input (-1 denotes a missing child).
BinaryTreeNode<int> *takeInput()
{
    int rootData;
    cin >> rootData;
    if (rootData == -1)
    {
        return NULL;
    }
    BinaryTreeNode<int> *root = new BinaryTreeNode<int>(rootData);

    // Nodes whose children have not been read yet, in level order.
    queue<BinaryTreeNode<int> *> q;
    q.push(root);
    while (!q.empty())
    {
        BinaryTreeNode<int> *currentNode = q.front();
        q.pop();
        int leftChild, rightChild;

        cin >> leftChild;
        if (leftChild != -1)
        {
            BinaryTreeNode<int> *leftNode = new BinaryTreeNode<int>(leftChild);
            currentNode->left = leftNode;
            q.push(leftNode);
        }

        cin >> rightChild;
        if (rightChild != -1)
        {
            BinaryTreeNode<int> *rightNode = new BinaryTreeNode<int>(rightChild);
            currentNode->right = rightNode;
            q.push(rightNode);
        }
    }
    return root;
}

int main()
{
    BinaryTreeNode<int> *root = takeInput();
    cout << diameterOfBinaryTree(root) << endl;
    return 0;
}


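Sample input (level order, where -1 marks a missing child):
1 2 3 4 7 -1 -1 -1 -1 -1 -1

The input tree will look something like this: the root 1 has children 2 and 3, and node 2 has children 4 and 7. The longest path, such as 4 → 2 → 1 → 3, has 3 edges, so the program prints 3.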



Time complexity

O(N), where ‘N’ is the number of nodes in the given binary tree.

Since we visit every node of the tree exactly once and do a constant amount of work per visit, the overall time complexity is O(N).
Thus, by computing the height of every subtree alongside its diameter and reusing it, we have reduced the time complexity from O(N*N) to O(N).

Space complexity

O(N), where ‘N’ is the number of nodes in the given binary tree. The recursion uses call-stack space proportional to the height of the tree, and in the worst case (a skewed tree) the height is N, so all nodes of the given tree can be on the call stack at once. So the overall space complexity is O(N).

Key Takeaways

Now you have seen how we can apply DP on trees through one example. You must be thrilled to learn how these two concepts blend together, but don't stop here: a good coder never stops practicing. Head over to our practice platform CodeStudio to practice top problems on DP on trees and many other concepts. Till then, Happy Coding!
