From 571c58a20097c5cd7ec666ecd24eb75e827a7b3f Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Sat, 5 Dec 2020 09:53:16 +0000 Subject: initial redblack bst --- search/balancedbst.go | 5 -- search/redblackbst.go | 209 ++++++++++++++++++++++++++++++++++++++++++++++++++ search/search_test.go | 130 +++++++++++++++++++++++++++++++ search/set_test.go | 130 ------------------------------- 4 files changed, 339 insertions(+), 135 deletions(-) delete mode 100644 search/balancedbst.go create mode 100644 search/redblackbst.go create mode 100644 search/search_test.go delete mode 100644 search/set_test.go diff --git a/search/balancedbst.go b/search/balancedbst.go deleted file mode 100644 index 12673c5..0000000 --- a/search/balancedbst.go +++ /dev/null @@ -1,5 +0,0 @@ -package set - -func NewBalancedBST() *BST { - return NewBST() -} diff --git a/search/redblackbst.go b/search/redblackbst.go new file mode 100644 index 0000000..167f567 --- /dev/null +++ b/search/redblackbst.go @@ -0,0 +1,209 @@ +package search + +import "fmt" + +type linkColor bool + +const ( + red linkColor = true + black linkColor = false +) + +type rbNode struct { + key int + val int + left *rbNode + right *rbNode + color linkColor +} + +func (n *rbNode) String() string { + recurse := func(n *rbNode) string { + if n == nil { + return "" + } + return n.String() + } + + return fmt.Sprintf("rbNode{%d:%d,%v,%s,%s}", + n.key, + n.val, + n.color, + recurse(n.left), + recurse(n.right), + ) +} + +type RedBlackBST struct { + root *rbNode +} + +func NewRedBlackBST() *RedBlackBST { + return &RedBlackBST{} +} + +func (t *RedBlackBST) String() string { + if t.Empty() { + return "RedBlackBST{}" + } + + return fmt.Sprintf("RedBlackBST{%s}", t.root) +} + +func (t *RedBlackBST) Empty() bool { + return t.root == nil +} + +func (t *RedBlackBST) Put(key, val int) { + if t.Empty() { + t.root = &rbNode{key, val, nil, nil, black} + return + } + parent, ptr, _, err := t.search(nil, &t.root, key) + switch err { + case nil: + // key already in the tree + return + case NotFound: + *ptr = &rbNode{key, val, nil, nil, red} + return + default: + panic(err) + } + + t.balance(parent) +} + +func (t *RedBlackBST) balance(n *rbNode) { + switch { + case n == nil, n.right == nil: + return + case n.right.color == black: + return + case n.left.color == black: + // Left is black and right is red + case n.left.color == red: + // Left and right are red + } +} + +func (t *RedBlackBST) Get(key int) (int, error) { + _, _, n, err := t.search(nil, &t.root, key) + if err != nil { + return 0, err + } + return n.val, nil +} + +func (t *RedBlackBST) Del(key int) (int, error) { + _, ptr, n, err := t.search(nil, &t.root, key) + if err != nil { + return 0, err + } + + // Case 1: n is leaf rbNode + // Case 2: n has one child + // Case 3: n has two childs + + switch { + case n.left == nil: + if n.right == nil { + // I am a leaf rbNode + *ptr = nil + return n.val, nil + } + // I have a right child + *ptr = n.right + return n.val, nil + + case n.right == nil: + // I have a left child + *ptr = n.left + return n.val, nil + default: + // I have two children! + + o, err := t.deleteMin(&n.right) + if err != nil { + return 0, err + } + + o.left = n.left + o.right = n.right + *ptr = o + + return n.val, nil + } +} + +func (t *RedBlackBST) search(parent *rbNode, ptr **rbNode, key int) (*rbNode, **rbNode, *rbNode, error) { + n := *ptr + if n == nil { + return parent, ptr, nil, NotFound + } + + switch { + case key < n.key: + return t.search(n, &n.left, key) + case n.key < key: + return t.search(n, &n.right, key) + default: + return parent, ptr, n, nil + } +} + +func (t *RedBlackBST) deleteMin(ptr **rbNode) (*rbNode, error) { + ptr, n, err := t.min(ptr) + if err != nil { + return nil, err + } + + *ptr = n.right + n.right = nil + return n, nil +} + +func (t *RedBlackBST) min(ptr **rbNode) (**rbNode, *rbNode, error) { + n := *ptr + if n == nil { + return nil, nil, NotFound + } + + for { + if n.left == nil { + return ptr, n, nil + } + ptr = &n.left + n = *ptr + } +} + +func (t *RedBlackBST) rotateLeft(ptr **rbNode) { + x := *ptr + y := x.right + + x.right = y.left + x.left = y + + *ptr = y + x.color = red + y.color = black +} + +func (t *RedBlackBST) rotateRight(ptr **rbNode) { + y := *ptr + x := y.right + + y.left = x.right + x.right = y + + *ptr = x + x.color = black + y.color = red +} + +func (t *RedBlackBST) flipColors(n *rbNode) { + n.left.color = black + n.right.color = black + n.color = red +} diff --git a/search/search_test.go b/search/search_test.go new file mode 100644 index 0000000..098947d --- /dev/null +++ b/search/search_test.go @@ -0,0 +1,130 @@ +package set + +import ( + "fmt" + "testing" + + "github.com/snonux/algorithms/ds" + "github.com/snonux/algorithms/sort" +) + +const factor int = 10 +const minLength int = 1 +const maxLength int = 10000 + +// Store results here to avoid compiler optimizations +var benchResult int + +func TestElementary(t *testing.T) { + for i := minLength; i <= maxLength; i *= factor { + test(NewElementary(), i, t) + } +} + +func TestBST(t *testing.T) { + for i := minLength; i <= maxLength; i *= factor { + test(NewBST(), i, t) + } +} + +func test(s Set, l int, t *testing.T) { + keys := ds.NewRandomArrayList(l, l) + randoms := ds.NewRandomArrayList(l, -1) + mapping := make(map[int]int, l) + + get := func(key int, del bool) int { + var val int + var err error + switch { + case del: + defer delete(mapping, key) + val, err = s.Del(key) + //t.Log("Del", key, val, err) + default: + val, err = s.Get(key) + //t.Log("Get", key, val, err) + } + + if mVal, ok := mapping[key]; ok { + if err != nil { + t.Error("Could not get element", key, val, mVal, err) + } + if mVal != val { + t.Error("Got wrong value for element", key, val, mVal) + } + return val + } + + if err == nil { + t.Error("Got element but expected not to", key, val) + } + return val + } + testGet := func(key int) int { return get(key, false) } + testDel := func(key int) int { return get(key, true) } + + testSet := func(key, val int) { + s.Set(key, val) + mapping[key] = val + //t.Log("Set", key, val) + testGet(key) + } + + t.Log("Set random key-values", l) + var prevKey int + for _, key := range sort.Shuffle(keys) { + testSet(key, randoms[key]) + testGet(prevKey) + prevKey = key + } + t.Log("Del random key-values", l) + for _, key := range sort.Shuffle(keys) { + testDel(key) + testGet(prevKey) + prevKey = key + } + if !s.Empty() { + t.Error("Expected set to be empty", l) + } +} + +func TestBalancedBST(t *testing.T) { + for i := minLength; i <= maxLength; i *= factor { + test(NewBalancedBST(), i, t) + } +} + +func BenchmarkElementary(t *testing.B) { + s := NewElementary() + for i := minLength; i <= maxLength; i *= factor { + benchmark(s, i, t) + } +} + +func BenchmarkBST(t *testing.B) { + s := NewBST() + for i := minLength; i <= maxLength; i *= factor { + benchmark(s, i, t) + } +} + +func BenchmarkBalancedBST(t *testing.B) { + s := NewBalancedBST() + for i := minLength; i <= maxLength; i *= factor { + benchmark(s, i, t) + } +} + +func benchmark(s Set, l int, b *testing.B) { + list := ds.NewRandomArrayList(l, -1) + + b.Run(fmt.Sprintf("random(%d)", l), func(b *testing.B) { + b.ResetTimer() + for i, a := range list { + s.Set(a, i) + } + for _, a := range list { + benchResult, _ = s.Get(a) + } + }) +} diff --git a/search/set_test.go b/search/set_test.go deleted file mode 100644 index 098947d..0000000 --- a/search/set_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package set - -import ( - "fmt" - "testing" - - "github.com/snonux/algorithms/ds" - "github.com/snonux/algorithms/sort" -) - -const factor int = 10 -const minLength int = 1 -const maxLength int = 10000 - -// Store results here to avoid compiler optimizations -var benchResult int - -func TestElementary(t *testing.T) { - for i := minLength; i <= maxLength; i *= factor { - test(NewElementary(), i, t) - } -} - -func TestBST(t *testing.T) { - for i := minLength; i <= maxLength; i *= factor { - test(NewBST(), i, t) - } -} - -func test(s Set, l int, t *testing.T) { - keys := ds.NewRandomArrayList(l, l) - randoms := ds.NewRandomArrayList(l, -1) - mapping := make(map[int]int, l) - - get := func(key int, del bool) int { - var val int - var err error - switch { - case del: - defer delete(mapping, key) - val, err = s.Del(key) - //t.Log("Del", key, val, err) - default: - val, err = s.Get(key) - //t.Log("Get", key, val, err) - } - - if mVal, ok := mapping[key]; ok { - if err != nil { - t.Error("Could not get element", key, val, mVal, err) - } - if mVal != val { - t.Error("Got wrong value for element", key, val, mVal) - } - return val - } - - if err == nil { - t.Error("Got element but expected not to", key, val) - } - return val - } - testGet := func(key int) int { return get(key, false) } - testDel := func(key int) int { return get(key, true) } - - testSet := func(key, val int) { - s.Set(key, val) - mapping[key] = val - //t.Log("Set", key, val) - testGet(key) - } - - t.Log("Set random key-values", l) - var prevKey int - for _, key := range sort.Shuffle(keys) { - testSet(key, randoms[key]) - testGet(prevKey) - prevKey = key - } - t.Log("Del random key-values", l) - for _, key := range sort.Shuffle(keys) { - testDel(key) - testGet(prevKey) - prevKey = key - } - if !s.Empty() { - t.Error("Expected set to be empty", l) - } -} - -func TestBalancedBST(t *testing.T) { - for i := minLength; i <= maxLength; i *= factor { - test(NewBalancedBST(), i, t) - } -} - -func BenchmarkElementary(t *testing.B) { - s := NewElementary() - for i := minLength; i <= maxLength; i *= factor { - benchmark(s, i, t) - } -} - -func BenchmarkBST(t *testing.B) { - s := NewBST() - for i := minLength; i <= maxLength; i *= factor { - benchmark(s, i, t) - } -} - -func BenchmarkBalancedBST(t *testing.B) { - s := NewBalancedBST() - for i := minLength; i <= maxLength; i *= factor { - benchmark(s, i, t) - } -} - -func benchmark(s Set, l int, b *testing.B) { - list := ds.NewRandomArrayList(l, -1) - - b.Run(fmt.Sprintf("random(%d)", l), func(b *testing.B) { - b.ResetTimer() - for i, a := range list { - s.Set(a, i) - } - for _, a := range list { - benchResult, _ = s.Get(a) - } - }) -} -- cgit v1.2.3