/**
 * In-order traversal template: walks the left subtree, then handles the
 * current node, then walks the right subtree.
 *
 * @param root root of the subtree to walk; a null root does nothing
 */
private void traverse (TreeNode root) {
    if (root != null) {
        traverse(root.left);
        // Per-node work goes here, between the two subtree walks.
        traverse(root.right);
    }
}
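// Space: O(h) for the recursion stack, where h is the tree height — O(log n) for a balanced tree, O(n) in the worst case.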
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
public class Solution {
    /** Predecessor side of the first in-order inversion (the misplaced larger value). */
    TreeNode firstNode = null;
    /** Successor side of the last in-order inversion (the misplaced smaller value). */
    TreeNode secondNode = null;
    /** In-order predecessor of the node being visited; null before the first visit. */
    TreeNode prevNode = null;

    /**
     * Restores a binary search tree in which the values of exactly two nodes were
     * swapped, by swapping those two values back in place.
     *
     * <p>An in-order walk of a valid BST never descends. A single swap produces either
     * one descent (the two nodes are adjacent in order) or two descents. The first
     * misplaced node is the predecessor of the first descent. The second is the
     * successor of the last descent.
     *
     * <p>If no descent is found (null or already-valid tree), the tree is left untouched.
     *
     * @param root root of the tree to repair; may be null
     */
    public void recoverTree(TreeNode root) {
        // Clear state from any previous call so the same instance can be reused.
        firstNode = null;
        secondNode = null;
        prevNode = null;

        traverse(root);

        // Nothing out of order: avoid dereferencing null nodes.
        if (firstNode == null || secondNode == null) {
            return;
        }
        int tmp = firstNode.val;
        firstNode.val = secondNode.val;
        secondNode.val = tmp;
    }

    /**
     * In-order walk that records the misplaced nodes in {@code firstNode} and
     * {@code secondNode}, tracking the predecessor in {@code prevNode}.
     *
     * <p>Recursion depth equals the tree height: O(log n) when balanced, O(n) worst case.
     *
     * @param root root of the subtree to walk; a null root does nothing
     */
    private void traverse(TreeNode root) {
        if (root == null) {
            return;
        }
        traverse(root.left);
        // A strict descent marks an inversion. A null predecessor (rather than an
        // Integer.MIN_VALUE sentinel) means a real MIN_VALUE node is never misread
        // as an inversion, and equal neighbours are not flagged.
        if (prevNode != null && prevNode.val > root.val) {
            if (firstNode == null) {
                firstNode = prevNode;
            }
            // Overwritten on a second descent; correct for the adjacent case too.
            secondNode = root;
        }
        prevNode = root;
        traverse(root.right);
    }
}