226. Invert Binary Tree (Easy)

Problem

Given the root of a binary tree, invert the tree and return its root. Inverting swaps the left and right children at every node.

Example

  • root = [4,2,7,1,3,6,9] → [4,7,2,9,6,3,1]

Approach 1: Recursive DFS (canonical)

Swap children at the current node, then recurse into each subtree.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def invert_tree(root):
    if not root:
        return None
    root.left, root.right = invert_tree(root.right), invert_tree(root.left)  # L1: swap + recurse
    return root

Where the time goes, line by line

Variables: n = number of nodes in the tree, h = tree height, w = max tree width.

Line | Per-call cost | Times executed | Contribution
L1 (swap + two recursive calls) | O(1) per node | n | O(n) ← dominates

Every node triggers exactly one swap and two recursive calls. No node is visited more than once.
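
The claim above can be checked with a small instrumented variant. This is only a sketch: `invert_tree_counting` and the `counter` dictionary are hypothetical names introduced for illustration, not part of the original solution. It tallies every call (including the O(1) calls that land on a missing child) and every swap, using the 7-node tree from the example.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def invert_tree_counting(root, counter):
    """Recursive inversion that also tallies calls and swaps (hypothetical helper)."""
    counter["calls"] += 1  # counts every call, including those on None
    if not root:
        return None
    counter["swaps"] += 1  # exactly one swap per real node
    root.left, root.right = (
        invert_tree_counting(root.right, counter),
        invert_tree_counting(root.left, counter),
    )
    return root


# The tree from the example: [4,2,7,1,3,6,9], n = 7 nodes.
root = TreeNode(4,
                TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(7, TreeNode(6), TreeNode(9)))
counter = {"calls": 0, "swaps": 0}
invert_tree_counting(root, counter)
print(counter)  # {'calls': 15, 'swaps': 7}
```

With n = 7 the sketch records 7 swaps and 2n + 1 = 15 calls: one per node plus one per missing child. Both counts are linear in n, which is consistent with the O(n) bound.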

Complexity

  • Time: O(n), driven by L1 visiting each node once.
  • Space: O(h) recursion depth (O(log n) balanced, O(n) skewed).

The one-liner swap is the canonical interview answer.

Approach 2: Iterative BFS with a queue

Level-order walk, swapping children at each popped node.

from collections import deque


def invert_tree(root):
    if not root:
        return None
    q = deque([root])  # L1: O(1) init
    while q:
        node = q.popleft()  # L2: O(1) dequeue
        node.left, node.right = node.right, node.left  # L3: O(1) swap
        if node.left:
            q.append(node.left)  # L4: O(1) enqueue
        if node.right:
            q.append(node.right)  # L5: O(1) enqueue
    return root

Where the time goes, line by line

Variables: n = number of nodes in the tree, h = tree height, w = max tree width.

Line | Per-call cost | Times executed | Contribution
L2 (dequeue) | O(1) | n | O(n)
L3 (swap) | O(1) | n | O(n)
L4/L5 (enqueue) | O(1) | n − 1 in total (every node except the root) | O(n) ← dominates (all lines tie)

Complexity

  • Time: O(n).
  • Space: O(w), max width of the tree.
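
To make the O(w) bound concrete, the sketch below records the peak queue length during the BFS inversion. The names `invert_tree_bfs_peak`, `perfect_tree`, and `left_chain` are hypothetical helpers introduced only for this illustration.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def invert_tree_bfs_peak(root):
    """BFS inversion that also returns the peak queue length (hypothetical helper)."""
    if not root:
        return None, 0
    q = deque([root])
    peak = 1
    while q:
        node = q.popleft()
        node.left, node.right = node.right, node.left
        if node.left:
            q.append(node.left)
        if node.right:
            q.append(node.right)
        peak = max(peak, len(q))  # measure after this node's children are queued
    return root, peak


def perfect_tree(levels):
    """Build a perfect binary tree with the given number of levels."""
    if levels == 0:
        return None
    return TreeNode(0, perfect_tree(levels - 1), perfect_tree(levels - 1))


def left_chain(length):
    """Build a left-skewed chain of `length` nodes."""
    root = node = TreeNode(0)
    for i in range(1, length):
        node.left = TreeNode(i)
        node = node.left
    return root


_, peak = invert_tree_bfs_peak(perfect_tree(4))   # 15 nodes, widest level has 8
_, chain_peak = invert_tree_bfs_peak(left_chain(5))
print(peak, chain_peak)  # 8 1
```

On the perfect tree the queue peaks at 8, the width of the last level. On the chain it never holds more than one node, so the queue cost tracks width rather than height.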

Approach 3: Iterative DFS with a stack

Same as BFS but with a LIFO stack instead of a FIFO queue.

def invert_tree(root):
    if not root:
        return None
    stack = [root]  # L1: O(1) init
    while stack:
        node = stack.pop()  # L2: O(1) pop
        node.left, node.right = node.right, node.left  # L3: O(1) swap
        if node.left:
            stack.append(node.left)  # L4: O(1) push
        if node.right:
            stack.append(node.right)  # L5: O(1) push
    return root

Where the time goes, line by line

Variables: n = number of nodes in the tree, h = tree height, w = max tree width.

Line | Per-call cost | Times executed | Contribution
L2 (pop) | O(1) | n | O(n)
L3 (swap) | O(1) | n | O(n)
L4/L5 (push) | O(1) | n − 1 in total (every node except the root) | O(n) ← dominates (all lines tie)

Complexity

  • Time: O(n).
  • Space: O(h). At most one pending sibling per level waits on the stack, so it never holds more than h + 1 nodes.
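
As a rough check of the stack bound, the sketch below tracks the peak stack length on a perfect tree with 4 levels (15 nodes, widest level 8). `invert_tree_dfs_peak` and `perfect_tree` are hypothetical names used only for this illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def invert_tree_dfs_peak(root):
    """Stack-based inversion that also returns the peak stack length (hypothetical helper)."""
    if not root:
        return None, 0
    stack = [root]
    peak = 1
    while stack:
        node = stack.pop()
        node.left, node.right = node.right, node.left
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
        peak = max(peak, len(stack))  # measure after this node's children are pushed
    return root, peak


def perfect_tree(levels):
    """Build a perfect binary tree with the given number of levels."""
    if levels == 0:
        return None
    return TreeNode(0, perfect_tree(levels - 1), perfect_tree(levels - 1))


_, peak = invert_tree_dfs_peak(perfect_tree(4))
print(peak)  # 4
```

The stack peaks at 4 entries, one per level, even though the widest level holds 8 nodes. That is the sense in which the stack variant scales with height rather than width.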

Summary

Approach | Time | Space
Recursive DFS | O(n) | O(h)
Iterative BFS | O(n) | O(w)
Iterative DFS | O(n) | O(h)

All three are optimal in time. The recursive one-liner is the canonical answer; the iterative variants are useful when recursion depth is a concern on skewed trees.
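
The recursion-depth concern can be illustrated with a deliberately skewed tree. The sketch below assumes CPython, where the default recursion limit is typically 1000. The names `invert_recursive`, `invert_iterative`, and `left_chain` are hypothetical, chosen here to keep the two variants side by side.

```python
import sys


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def invert_recursive(root):
    if not root:
        return None
    root.left, root.right = invert_recursive(root.right), invert_recursive(root.left)
    return root


def invert_iterative(root):
    if not root:
        return None
    stack = [root]
    while stack:
        node = stack.pop()
        node.left, node.right = node.right, node.left
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    return root


def left_chain(length):
    """Build a left-skewed chain iteratively, so construction itself never recurses."""
    root = node = TreeNode(0)
    for i in range(1, length):
        node.left = TreeNode(i)
        node = node.left
    return root


# Assumption: a chain twice as deep as the interpreter's limit is enough to overflow.
depth = sys.getrecursionlimit() * 2
chain = left_chain(depth)

try:
    invert_recursive(chain)
    recursive_ok = True
except RecursionError:
    # The error surfaces before any assignment completes, so the chain is unchanged.
    recursive_ok = False

invert_iterative(chain)  # explicit stack on the heap: no interpreter depth limit
print(recursive_ok, chain.left is None, chain.right.val)
```

Under these assumptions the recursive call fails while the iterative one turns the left-skewed chain into a right-skewed one. Raising the limit with `sys.setrecursionlimit` is another option, at the cost of risking a native stack overflow on very deep trees.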

Test cases

from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(vals):
    # Build a tree from a LeetCode-style level-order list (None = missing child).
    if not vals:
        return None
    root = TreeNode(vals[0])
    q = deque([root])  # deque gives O(1) popleft; list.pop(0) is O(n)
    i = 1
    while q and i < len(vals):
        node = q.popleft()
        if i < len(vals) and vals[i] is not None:
            node.left = TreeNode(vals[i])
            q.append(node.left)
        i += 1
        if i < len(vals) and vals[i] is not None:
            node.right = TreeNode(vals[i])
            q.append(node.right)
        i += 1
    return root


def tree_to_list(root):
    # Serialize back to level-order, trimming trailing None placeholders.
    if not root:
        return []
    result, q = [], deque([root])
    while q:
        node = q.popleft()
        if node:
            result.append(node.val)
            q.append(node.left)
            q.append(node.right)
        else:
            result.append(None)
    while result and result[-1] is None:
        result.pop()
    return result


def invert_tree(root):
    if not root:
        return None
    root.left, root.right = invert_tree(root.right), invert_tree(root.left)
    return root


def _run_tests():
    assert tree_to_list(invert_tree(build_tree([4, 2, 7, 1, 3, 6, 9]))) == [4, 7, 2, 9, 6, 3, 1]
    assert invert_tree(None) is None
    assert tree_to_list(invert_tree(build_tree([1]))) == [1]
    # invert twice = original
    t = build_tree([1, 2, 3])
    assert tree_to_list(invert_tree(invert_tree(t))) == [1, 2, 3]
    print("all tests pass")


if __name__ == "__main__":
    _run_tests()