15. 3Sum (Medium)
Problem
Given an integer array nums, return all unique triplets [nums[i], nums[j], nums[k]] such that i, j, and k are pairwise distinct (i != j, i != k, and j != k) and nums[i] + nums[j] + nums[k] == 0. The solution set must not contain duplicate triplets.
Example
- nums = [-1, 0, 1, 2, -1, -4] → [[-1, -1, 2], [-1, 0, 1]]
- nums = [0, 1, 1] → []
- nums = [0, 0, 0] → [[0, 0, 0]]
Approach 1: Brute force, every triplet + set dedup
Try every (i, j, k) with i < j < k. Canonicalize each qualifying triplet (sorted tuple) into a set.

    def three_sum(nums: list[int]) -> list[list[int]]:
        n = len(nums)                                     # L1: O(1)
        found = set()                                     # L2: O(1)
        for i in range(n):                                # L3: outer loop
            for j in range(i + 1, n):                     # L4: middle loop
                for k in range(j + 1, n):                 # L5: inner loop
                    if nums[i] + nums[j] + nums[k] == 0:  # L6: O(1) check
                        found.add(tuple(sorted((nums[i], nums[j], nums[k]))))  # L7: O(1) add
        return [list(t) for t in found]                   # L8: O(m)

Where the time goes, line by line
Variables: n = len(nums), m = number of unique triplets found.
| Line | Per-call cost | Times executed | Contribution |
|---|---|---|---|
| L3 (outer loop) | O(1) | n | O(n) |
| L4 (middle loop) | O(1) | n² | O(n²) |
| L5, L6, L7 (inner loop + check) | O(1) | ~n³/6 | O(n³) ← dominates |
| L8 (collect) | O(m) | 1 | O(m) |
Triple nested loops, one per triplet position.
Complexity
- Time: O(n³), driven by L5/L6 (triple nested loop).
- Space: O(m) where m is the number of unique triplets.
Too slow for the problem’s n ≤ 3000 constraint: at that size the innermost check runs on the order of 10⁹ times.
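As a quick sanity check on that claim, the number of index triples the brute force visits at the upper bound can be counted directly. This is only a back-of-the-envelope sketch; `N_MAX` is a name introduced here for the constraint quoted above.

```python
import math

N_MAX = 3000  # upper bound on len(nums) quoted above (name is illustrative)

# Index triples with i < j < k are exactly "n choose 3".
triplets = math.comb(N_MAX, 3)
print(f"{triplets:,}")  # 4,495,501,000
```

Each of those triples costs one sum and one comparison, which is why the cubic approach is not practical at this size.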
Approach 2: Fix i, hash-set for two-sum
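Sort first so that equal values sit next to each other and duplicates are easy to skip. Then, for each anchor i, scan the rest of the array while keeping a hash set of the values seen so far, looking for pairs (x, y) such that x + y = -nums[i].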

    def three_sum(nums: list[int]) -> list[list[int]]:
        nums.sort()                                       # L1: O(n log n)
        n = len(nums)                                     # L2: O(1)
        result = []                                       # L3: O(1)
        for i in range(n - 2):                            # L4: outer loop, n-2 iters
            if i > 0 and nums[i] == nums[i - 1]:          # L5: O(1) skip duplicate anchor
                continue
            seen = set()                                  # L6: O(1) fresh set per anchor
            j = i + 1                                     # L7: O(1)
            while j < n:                                  # L8: inner scan
                need = -nums[i] - nums[j]                 # L9: O(1) complement
                if need in seen:                          # L10: O(1) lookup
                    result.append([nums[i], need, nums[j]])      # L11: O(1)
                    while j + 1 < n and nums[j + 1] == nums[j]:  # L12: skip j-dupes
                        j += 1
                seen.add(nums[j])                         # L13: O(1)
                j += 1                                    # L14: O(1)
        return result

Where the time goes, line by line
Variables: n = len(nums).
| Line | Per-call cost | Times executed | Contribution |
|---|---|---|---|
| L1 (sort) | O(n log n) | 1 | O(n log n) |
| L4 (outer loop) | O(1) | n | O(n) |
| L8-L14 (inner scan) | O(1) per step | n per outer | O(n²) ← dominates |
The sort is O(n log n); the double loop is O(n²). The loop dominates for large n.
Complexity
- Time: O(n²), driven by L4/L8: each of the n anchors triggers an O(n) inner scan, which outweighs the O(n log n) sort.
- Space: O(n). The hash set is rebuilt for each anchor and holds at most n values at a time.
Approach 3: Sort + two pointers (optimal)
Sort, then for each anchor i, converge pointers l and r from both ends of the remaining suffix. Same O(n²) time but O(1) extra space.

    def three_sum(nums: list[int]) -> list[list[int]]:
        nums.sort()                                       # L1: O(n log n)
        n = len(nums)                                     # L2: O(1)
        result = []                                       # L3: O(1)
        for i in range(n - 2):                            # L4: outer loop
            if nums[i] > 0:                               # L5: O(1) prune: anchor and all later values > 0
                break
            if i > 0 and nums[i] == nums[i - 1]:          # L6: O(1) skip duplicate anchor
                continue
            l, r = i + 1, n - 1                           # L7: O(1) init pointers
            while l < r:                                  # L8: two-pointer scan
                s = nums[i] + nums[l] + nums[r]           # L9: O(1) sum
                if s < 0:                                 # L10: O(1)
                    l += 1                                # L11: O(1)
                elif s > 0:                               # L12: O(1)
                    r -= 1                                # L13: O(1)
                else:
                    result.append([nums[i], nums[l], nums[r]])  # L14: O(1)
                    l += 1                                # L15: O(1)
                    r -= 1                                # L16: O(1)
                    while l < r and nums[l] == nums[l - 1]:     # L17: skip l-dupes
                        l += 1
                    while l < r and nums[r] == nums[r + 1]:     # L18: skip r-dupes
                        r -= 1
        return result

Where the time goes, line by line
Variables: n = len(nums).
| Line | Per-call cost | Times executed | Contribution |
|---|---|---|---|
| L1 (sort) | O(n log n) | 1 | O(n log n) |
| L4 (outer loop) | O(1) | n | O(n) |
| L8-L18 (two-pointer scan) | O(1) per step | n per outer | O(n²) ← dominates |
Each outer iteration does at most one linear pass of the suffix (l and r converge). Dedup skips are absorbed into the linear pass.
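The linear-pass claim can be checked empirically with a small instrumented variant. The sketch below is a hypothetical helper (`count_pointer_steps` is not part of the solution); it drops the pruning and dedup skips for brevity, which only ever move the pointers further and so cannot raise the count. Every iteration of the while loop shrinks the gap r - l by at least one, and the gap starts at n - 2 - i, so that is the most steps an anchor can take.

```python
def count_pointer_steps(nums: list[int]) -> list[int]:
    """Return, for each anchor i, how many times the two-pointer loop body runs."""
    nums = sorted(nums)  # work on a sorted copy; the caller's list is untouched
    n = len(nums)
    steps_per_anchor = []
    for i in range(n - 2):
        l, r = i + 1, n - 1
        steps = 0
        while l < r:
            steps += 1
            s = nums[i] + nums[l] + nums[r]
            if s < 0:
                l += 1
            elif s > 0:
                r -= 1
            else:
                # A hit moves both pointers, closing the gap by two.
                l += 1
                r -= 1
        steps_per_anchor.append(steps)
    return steps_per_anchor


steps = count_pointer_steps([-1, 0, 1, 2, -1, -4])
print(steps)  # [4, 2, 2, 1]

# No anchor exceeds the initial gap between the pointers, n - 2 - i.
n = 6
assert all(count <= n - 2 - i for i, count in enumerate(steps))
```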
Complexity
- Time: O(n²), driven by L4/L8: each of the n anchors triggers an O(n) two-pointer sweep, which outweighs the O(n log n) sort.
- Space: O(1) extra, not counting the sort’s own working memory or the output list.
Summary
| Approach | Time | Space |
|---|---|---|
| Brute force + set dedup | O(n³) | O(m) |
| Sort + hash two-sum | O(n²) | O(n) |
| Sort + two pointers | O(n²) | O(1) extra |
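Sort + two pointers is the canonical answer: it handles duplicates cleanly and needs only O(1) extra memory beyond the sort and the output. The template generalizes to 4Sum and kSum by nesting additional anchor loops around the same two-pointer base case.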
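One way the generalization might look is sketched below. This is an illustrative sketch rather than a reference solution: the function name `k_sum`, its `target` parameter, and the inner helper `solve` are all names introduced here, and the code assumes k ≥ 2. Each recursion level fixes one anchor and skips duplicate anchors; the base case is the same two-pointer sweep used in Approach 3.

```python
def k_sum(nums: list[int], target: int, k: int) -> list[list[int]]:
    """Return all unique k-element combinations of nums that sum to target (k >= 2)."""
    nums = sorted(nums)  # sorted copy, so the caller's list is not mutated
    n = len(nums)

    def solve(start: int, k: int, target: int) -> list[list[int]]:
        result = []
        if k == 2:
            # Base case: two-pointer sweep over nums[start:].
            l, r = start, n - 1
            while l < r:
                s = nums[l] + nums[r]
                if s < target:
                    l += 1
                elif s > target:
                    r -= 1
                else:
                    result.append([nums[l], nums[r]])
                    l += 1
                    r -= 1
                    while l < r and nums[l] == nums[l - 1]:  # skip l-dupes
                        l += 1
                    while l < r and nums[r] == nums[r + 1]:  # skip r-dupes
                        r -= 1
            return result
        # Recursive case: fix one anchor, solve (k-1)-sum on the suffix.
        for i in range(start, n - k + 1):
            if i > start and nums[i] == nums[i - 1]:  # skip duplicate anchor
                continue
            for rest in solve(i + 1, k - 1, target - nums[i]):
                result.append([nums[i]] + rest)
        return result

    return solve(0, k, target)


print(k_sum([-1, 0, 1, 2, -1, -4], 0, 3))  # [[-1, -1, 2], [-1, 0, 1]]
print(k_sum([1, 0, -1, 0, -2, 2], 0, 4))   # [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```

With k - 2 nested anchor levels above a linear sweep, the running time of this sketch is O(n^(k-1)), which matches the O(n²) bound above for k = 3.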
Test cases

    # Quick smoke tests, paste into a REPL or save as test_3sum.py and run.
    # Uses the canonical implementation (Approach 3: sort + two pointers).
    def three_sum(nums: list[int]) -> list[list[int]]:
        nums.sort()
        n = len(nums)
        result = []
        for i in range(n - 2):
            if nums[i] > 0:
                break
            if i > 0 and nums[i] == nums[i - 1]:
                continue
            l, r = i + 1, n - 1
            while l < r:
                s = nums[i] + nums[l] + nums[r]
                if s < 0:
                    l += 1
                elif s > 0:
                    r -= 1
                else:
                    result.append([nums[i], nums[l], nums[r]])
                    l += 1
                    r -= 1
                    while l < r and nums[l] == nums[l - 1]:
                        l += 1
                    while l < r and nums[r] == nums[r + 1]:
                        r -= 1
        return result

    def _run_tests():
        def normalize(result):
            return sorted(tuple(t) for t in result)

        assert normalize(three_sum([-1, 0, 1, 2, -1, -4])) == [(-1, -1, 2), (-1, 0, 1)]
        assert three_sum([0, 1, 1]) == []
        assert three_sum([0, 0, 0]) == [[0, 0, 0]]
        assert three_sum([]) == []
        assert three_sum([-2, 0, 0, 2, 2]) == [[-2, 0, 2]]
        print("all tests pass")

    if __name__ == "__main__":
        _run_tests()

Related data structures
- Arrays, sort + two-pointer anchor sweep
- Hash Tables, alternative inner loop via complement lookup