diff options
Diffstat (limited to 'search')
| -rw-r--r-- | search/redblackbst.go | 96 | ||||
| -rw-r--r-- | search/search_test.go | 8 |
2 files changed, 75 insertions, 29 deletions
diff --git a/search/redblackbst.go b/search/redblackbst.go index c1d0600..6644aea 100644 --- a/search/redblackbst.go +++ b/search/redblackbst.go @@ -1,6 +1,8 @@ package search -import "fmt" +import ( + "fmt" +) type linkColor bool @@ -10,12 +12,14 @@ const ( ) type rbNode struct { - key int - val int - color linkColor - size int - left *rbNode - right *rbNode + key int + val int + color linkColor + capacity int + left *rbNode + right *rbNode + // Just mark a node as deleted if deleted. Not fully implemented in lecture. + deleted bool } func (n *rbNode) String() string { @@ -23,14 +27,19 @@ func (n *rbNode) String() string { if n == nil { return "" } - return n.String() + return fmt.Sprintf("\n%s", n.String()) } - return fmt.Sprintf("rbNode{%d:%d,%v,%d,%s,%s}", + color := "red" + if n.color == black { + color = "black" + } + return fmt.Sprintf("rbNode{%v;%d:%d,%s,%d,%s,%s}", + n.deleted, n.key, n.val, - n.color, - n.size, + color, + n.capacity, recurse(n.left), recurse(n.right), ) @@ -43,15 +52,16 @@ func (n *rbNode) isRed() bool { return n.color == red } -func (n *rbNode) Size() int { +func (n *rbNode) Capacity() int { if n == nil { return 0 } - return n.size + return n.capacity } type RedBlackBST struct { root *rbNode + size int } func NewRedBlackBST() *RedBlackBST { @@ -60,21 +70,25 @@ func NewRedBlackBST() *RedBlackBST { func (t *RedBlackBST) String() string { if t.Empty() { - return "RedBlackBST{}" + return fmt.Sprintf("RedBlackBST{%d:%d}", t.Size(), t.Capacity()) } - return fmt.Sprintf("RedBlackBST{%s}", t.root) + return fmt.Sprintf("RedBlackBST{%d:%d;%s}", t.Size(), t.Capacity(), t.root) +} + +func (t *RedBlackBST) Size() int { + return t.size } func (t *RedBlackBST) Empty() bool { - return t.root == nil + return t.Size() == 0 } -func (t *RedBlackBST) Size() int { - if t.Empty() { +func (t *RedBlackBST) Capacity() int { + if t.root == nil { return 0 } - return t.root.size + return t.root.capacity } func (t *RedBlackBST) Put(key, val int) { @@ -84,7 +98,8 @@ func (t *RedBlackBST) Put(key, val int) { func (t *RedBlackBST) put(n *rbNode, key, val int) *rbNode { if n == nil { - return &rbNode{key, val, red, 1, nil, nil} + t.size++ + return &rbNode{key, val, red, 1, nil, nil, false} } switch { @@ -93,6 +108,10 @@ func (t *RedBlackBST) put(n *rbNode, key, val int) *rbNode { case key > n.key: n.right = t.put(n.right, key, val) default: + if n.deleted { + n.deleted = false + } + t.size++ n.val = val } @@ -105,7 +124,7 @@ func (t *RedBlackBST) put(n *rbNode, key, val int) *rbNode { t.flipColors(n) } - n.size = 1 + n.left.Size() + n.right.Size() + n.capacity = 1 + n.left.Capacity() + n.right.Capacity() return n } @@ -124,12 +143,37 @@ func (t *RedBlackBST) get(n *rbNode, key int) (int, error) { case key > n.key: return t.get(n.right, key) default: + if n.deleted { + return 0, NotFound + } return n.val, nil } } func (t *RedBlackBST) Del(key int) (int, error) { - panic("Not yet implemented") + return t.del(t.root, key) +} + +func (t *RedBlackBST) del(n *rbNode, key int) (int, error) { + if n == nil { + return 0, NotFound + } + + switch { + case key < n.key: + return t.del(n.left, key) + case key > n.key: + return t.del(n.right, key) + default: + if n.deleted { + return 0, NotFound + } + t.size-- + n.deleted = true + val := n.val + n.val = -1 + return val, nil + } } func (t *RedBlackBST) rotateLeft(n *rbNode) *rbNode { @@ -140,8 +184,8 @@ func (t *RedBlackBST) rotateLeft(n *rbNode) *rbNode { x.color = n.color n.color = red - x.size = n.size - n.size = 1 + n.left.Size() + n.right.Size() + x.capacity = n.capacity + n.capacity = 1 + n.left.Capacity() + n.right.Capacity() return x } @@ -154,8 +198,8 @@ func (t *RedBlackBST) rotateRight(n *rbNode) *rbNode { x.color = n.color n.color = red - x.size = n.size - n.size = 1 + n.left.Size() + n.right.Size() + x.capacity = n.capacity + n.capacity = 1 + n.left.Capacity() + n.right.Capacity() return x } diff --git a/search/search_test.go b/search/search_test.go index bf1d1b5..09ed906 100644 --- a/search/search_test.go +++ b/search/search_test.go @@ -10,7 +10,9 @@ import ( const factor int = 10 const minLength int = 1 -const maxLength int = 10000 +const maxLength int = 10 + +//const maxLength int = 10000 // Store results here to avoid compiler optimizations var benchResult int @@ -39,10 +41,10 @@ func test(s Put, l int, t *testing.T) { case del: defer delete(mapping, key) val, err = s.Del(key) - //t.Log("Del", key, val, err) + t.Log("Del", key, val, err) default: val, err = s.Get(key) - //t.Log("Get", key, val, err) + t.Log("Get", key, val, err) } if mVal, ok := mapping[key]; ok { |
