Red Black BST Algorithm
A red – black tree is a kind of self-balancing binary search tree in computer science. When the tree is modify, the new tree is subsequently rearranged and repainted to restore the coloring property. In a 1978 paper," A Dichromatic Framework for Balanced Trees", Leonidas J. Guibas and Robert Sedgewick derived the red-black tree from the symmetric binary B-tree. Sedgewick originally allowed nodes whose two children are red, make his trees more like 2-3-4 trees, but later this restriction was added, make new trees more like 2-3 trees. These trees maintained all paths from root to leaf with the same number of nodes, make perfectly balanced trees.
package org.gs.symboltable
import math.Ordering
import scala.annotation.tailrec
/** Red Black Node
*
* @tparam A generic key
* @tparam B generic value
* @param count number of subtrees
* @param red true if link to parent is red false if black
* @see [[https://algs4.cs.princeton.edu/33balanced/RedBlackBST.java.html]]
* @author Scala translation by Gary Struthers from Java by Robert Sedgewick and Kevin Wayne.
*/
sealed class Node[A, B](var key: A, var value: B, var count: Int = 1, var red: Boolean = true) {
var left = null.asInstanceOf[Node[A, B]]
var right = null.asInstanceOf[Node[A, B]]
}
/** Balanced search tree with Red/Black nodes
*
* @tparam A generic key type
* @tparam B generic value type
* @param ord implicit Ordering
* @author Scala translation by Gary Struthers from Java by Robert Sedgewick and Kevin Wayne.
*/
class RedBlackBST[A, B](implicit ord: Ordering[A]) {
private var root = null.asInstanceOf[Node[A, B]]
private def isRed(x: Node[A, B]): Boolean = if ((x == null) || (x.red == false)) false else true
/** Make h.right the new root of subtree
*
* if h.right is red rotate left so h becomes the left child and h.right becomes the parent
*/
private def rotateLeft(h: Node[A, B]): Node[A, B] = {
assert(h != null && isRed(h.right), "error: black or null passed to rotateLeft")
val x = h.right
h.right = x.left
x.left = h
x.red = x.left.red
x.left.red = true
x.count = h.count
h.count = 1 + size(h.left) + size(h.right)
x
}
/** Make h.left the new root of subtree
*
* if h.left is red rotate right so h becomes the right child and h.left becomes the parent
*/
private def rotateRight(h: Node[A, B]): Node[A, B] = {
assert(h != null && isRed(h.left), "error: black or null passed to rotateRight")
val x = h.left
h.left = x.right
x.right = h
x.red = x.right.red
x.right.red = true
x.count = h.count
h.count = 1 + size(h.left) + size(h.right)
x
}
/** if both children are red make them black and h red */
private def flipColors(h: Node[A, B]): Unit = {
assert(h != null && h.left != null && h.right != null, "null node passed to flip colors")
assert(!isRed(h) && isRed(h.left) && isRed(h.right) ||
isRed(h) && !isRed(h.left) && !isRed(h.right), "error: flipColors root color must not equal both child colors")
h.red = !h.red
h.left.red = !h.left.red
h.right.red = !h.right.red
}
/** insert key value into tree
*
* overwrite if key already there
*/
def put(key: A, value: B): Unit = {
def loop(x: Node[A, B]): Node[A, B] = {
if (x == null) new Node(key, value) else {
ord.compare(key, x.key) match {
case 0 => x.value = value
case n if (n < 0) => x.left = loop(x.left)
case _ => x.right = loop(x.right)
}
x.count += 1
if (isRed(x.right) && !isRed(x.left)) rotateLeft(x) else {
val j = if (isRed(x.left) && isRed(x.left.left)) rotateRight(x) else x
if (isRed(j.left) && isRed(j.right)) {
flipColors(j)
}
j
}
}
}
root = loop(root)
root.red = false
}
/** get value for key if present */
def get(key: A): Option[B] = {
@tailrec
def loop(x: Node[A, B]): Option[B] = if (x == null) None else {
ord.compare(key, x.key) match {
case 0 => Some(x.value)
case n if (n < 0) => loop(x.left)
case _ => loop(x.right)
}
}
loop(root)
}
/** delete the key */
def delete(key: A): Unit = {
if (!isRed(root.left) && !isRed(root.right)) {
root.red = true
}
def loop(x: Node[A, B], key: A): Node[A, B] = {
val h = ord.compare(key, x.key) match {
case n if (n < 0) => {
val j = if (!isRed(x.left) && !isRed(x.left.left)) moveRedLeft(x) else x
j.left = loop(j.left, key)
j
}
case _ => {
val j = if (isRed(x.left)) rotateRight(x) else x
ord.compare(key, j.key) match {
case 0 if (j.right == null) => return null
case _ => {
val k = if (!isRed(j.right) && !isRed(j.right.left)) moveRedRight(j) else j
if (ord.compare(key, k.key) == 0) {
val y = min(k.right)
k.key = y.key
k.value = y.value
k.right = deleteMin(k.right)
} else {
k.right = loop(k.right, key)
}
k
}
}
}
}
balance(h)
}
root = loop(root, key)
if (!isEmpty) root.red = false
}
private def moveRedRight(h: Node[A, B]): Node[A, B] = {
assert(h != null, "null passed to moveRedRight")
assert(isRed(h) && !isRed(h.right) && !isRed(h.right.left), "error: moveRedRight colors")
flipColors(h)
if (!isRed(h.left.left)) rotateRight(h) else h
}
private def moveRedLeft(h: Node[A, B]): Node[A, B] = {
assert(h != null, "null passed to moveRedLeft")
assert(isRed(h) && !isRed(h.left) && !isRed(h.left), "error: moveRedLeft colors")
flipColors(h)
if (isRed(h.right.left)) {
h.right = rotateRight(h.right)
rotateLeft(h)
} else h
}
private def balance(h: Node[A, B]): Node[A, B] = {
assert(h != null, "null passed to balance")
val x = if (isRed(h.right)) rotateLeft(h) else h
val y = if (isRed(x.left) && isRed(x.left.left)) rotateRight(x) else x
if (isRed(y.left) && isRed(y.right)) flipColors(y)
y.count = 1 + size(y.left) + size(y.right)
y
}
private def deleteMin(h: Node[A, B]): Node[A, B] = {
if (h.left == null) null else {
val j = if (!isRed(h.left) && !isRed(h.left.left)) moveRedLeft(h) else h
j.left = deleteMin(j.left)
balance(j)
}
}
/** delete minimum key */
def deleteMin(): Unit = {
if (!isRed(root.left) && !isRed(root.right)) root.red = true
root = deleteMin(root)
if (!isEmpty) root.red = false
}
/** delete maximum key */
def deleteMax(): Unit = {
def deleteMax(h: Node[A, B]): Node[A, B] = {
val j = if (isRed(h.left)) rotateRight(h) else h
if (j.right == null) null.asInstanceOf[Node[A, B]] else {
val m = if (!isRed(j.right) && !isRed(j.right.left)) moveRedRight(j) else j
m.right = deleteMax(m.right)
balance(m)
}
}
if (!isRed(root.left) && !isRed(root.right)) root.red = true
root = deleteMax(root)
if (!isEmpty) root.red = false
}
private def size(x: Node[A, B]): Int = if (x == null) 0 else x.count
/** subtree count */
def size(): Int = size(root)
/** number of keys in lo..hi */
def size(lo: A, hi: A): Int = {
if (ord.compare(lo, hi) > 0) 0 else if (contains(hi)) rank(hi) - rank(lo) + 1 else rank(hi) - rank(lo)
}
/** is key present */
def contains(key: A): Boolean = get(key) match {
case None => false
case Some(x) => if (x == null) false else true
}
/** are any keys in tree */
def isEmpty(): Boolean = if (root == null) true else false
@tailrec
private def min(x: Node[A, B]): Node[A, B] = {
assert(x != null, "null passed to min")
if (x.left == null) x else min(x.left)
}
/** returns smallest key */
def min(): A = min(root).key
/** returns largest key */
def max(): A = {
@tailrec
def max(x: Node[A, B]): Node[A, B] = {
assert(x != null, "null passed to max")
if (x.right == null) x else max(x.right)
}
max(root).key
}
/** returns largest key less than or equal to key */
def floor(key: A): A = {
def loop(x: Node[A, B]): Node[A, B] = {
if (x == null) null else ord.compare(key, x.key) match {
case 0 => x
case n if (n < 0) => loop(x.left)
case _ => {
val t = loop(x.right)
if (t != null) t else x // t in right tree, x is subroot
}
}
}
loop(root).key
}
/** returns smallest key greater than or equal to key */
def ceiling(key: A): A = {
def loop(x: Node[A, B]): Node[A, B] = if (x == null) null.asInstanceOf[Node[A, B]] else {
val cmp = ord.compare(key, x.key)
if (cmp == 0) x else {
if (cmp < 0) {
val t = loop(x.left)
if (t != null) t else x
} else loop(x.right)
}
}
loop(root).key
}
/** returns number of keys less than key */
def rank(key: A): Int = {
def loop(x: Node[A, B]): Int = if (x == null) 0
else {
ord.compare(key, x.key) match {
case cmp if(cmp == 0) => size(x.left)
case cmp if(cmp < 0) => loop(x.left)
case _ => 1 + size(x.left) + loop(x.right)
}
}
loop(root)
}
/** returns key of rank k */
def select(rank: Int): Option[A] = {
def select(x: Node[A, B], k: Int): Node[A, B] = {
val t = size(x.left)
if (t > k) select(x.left, k) else if (t < k) select(x.right, k - t - 1) else x
}
val node = select(root, rank)
if (node == null) None else Some(node.key)
}
import scala.collection.mutable.Queue
/** returns all keys */
def keys(): List[A] = {
val q = Queue[A]()
val lo = min
val hi = max
def loop(x: Node[A, B]) {
if (x != null) {
val cmpLo = ord.compare(lo, x.key)
if (cmpLo < 0) loop(x.left)
val cmpHi = ord.compare(hi, x.key)
if (cmpLo <= 0 && cmpHi >= 0) q.enqueue(x.key)
if (cmpHi > 0) loop(x.right)
}
}
loop(root)
q.toList
}
/** debugging code */
/** print nodes left to right
* @param full all node arguments if true just key if false
* @return tree nodes as string
*/
def inorderTreeWalk(full: Boolean = false): String = {
val sb = new StringBuilder
def loop(x: Node[A, B]) {
if (x != null) {
loop(x.left)
if (full) sb append (s" key:${x.key} value:${x.value} count:${x.count} red:${x.red}")
else sb append (s" ${x.key}")
loop(x.right)
}
}
sb append (s" root ${root.key} ")
loop(root)
sb.toString
}
/** returns string of nodes from top down by level */
def levelOrderTreeWalk(): String = {
val sb = new StringBuilder()
val q = new Queue[Node[A, B]]()
q.enqueue(root)
@tailrec
def loop(): Unit = {
if (!q.isEmpty) {
val n = q.dequeue
sb append (f" ${n.key} ${n.red}, ${n.count} \n")
if (n.left != null) q.enqueue(n.left)
if (n.right != null) q.enqueue(n.right)
loop()
}
}
loop()
sb.toString
}
/** returns all node args left to right as string */
override def toString(): String = inorderTreeWalk(true)
/** does this satisfy requirements for balanced search tree */
def isBST(): Boolean = {
def loop(x: Node[A, B], min: A, max: A): Boolean = {
if (x == null) true else {
if (min != null && ord.compare(x.key, min) <= 0) false
else if (max != null && ord.compare(x.key, max) >= 0) false else {
loop(x.left, min, x.key) && loop(x.right, x.key, max)
}
}
}
loop(root, null.asInstanceOf[A], null.asInstanceOf[A])
}
/** are field sizes correct */
def isSizeConsistent(): Boolean = {
def loop(x: Node[A, B]): Boolean = x match {
case _ if(x == null) => true
case _ if (x.count != size(x.left) + size(x.right) + 1) => false
case _ => loop(x.left) && loop(x.right)
}
loop(root)
}
def isRankConsistent(): Boolean = {
def checkRank: Boolean = {
@tailrec
def loop(i: Int): Boolean = {
if (i == size) true else {
val goodRank = select(i) match {
case None => true
case Some(x) if (rank(x) != i) => false
case Some(x) => true
}
if (goodRank) loop(i + 1) else false
}
}
loop(0)
}
def checkKeys: Boolean = {
@tailrec
def loop(keys: List[A]): Boolean = keys match {
case Nil => true
case x :: xs => {
val goodKey = select(rank(x)) match {
case None => true
case Some(k) if (ord.compare(x, k) != 0) => false
case Some(k) => true
}
if (goodKey) loop(xs) else false
}
}
loop(keys)
}
checkRank && checkKeys
}
/** are red and black links correct */
def is23(): Boolean = {
def loop(x: Node[A, B]): Boolean = x match {
case _ if(x == null) => true
case _ if(isRed(x.right)) => false
case _ if(x != root && isRed(x) && isRed(x.left)) => false
case _ => loop(x.left) && loop(x.right)
}
loop(root)
}
/** do all paths from root have same number of black links */
def isBalanced(): Boolean = {
@tailrec
def loopRB(n: Node[A, B], black: Int): Int = n match {
case _ if(n == null) => black
case _ if(!isRed(n)) => loopRB(n.left, black + 1)
case _ => loopRB(n.left, black)
}
val black = loopRB(root, 0)
def loop(x: Node[A, B], black: Int): Boolean = x match {
case _ if(x == null) => black == 0
case _ if(!isRed(x)) => loop(x.left, black - 1) && loop(x.right, black - 1)
case _ => loop(x.left, black) && loop(x.right, black)
}
loop(root, black)
}
}