diff options
| author | Paul Buetow <git@mx.buetow.org> | 2020-11-15 13:04:55 +0000 |
|---|---|---|
| committer | Paul Buetow <git@mx.buetow.org> | 2020-11-15 13:04:55 +0000 |
| commit | c8915f36887f82ee9b092393289895338df3d7c0 (patch) | |
| tree | 1a3340c377fa86f0414a6b35cf92dac351bb75bc | |
| parent | e912e477538658b535a4ab0adfa7f86b6a33a290 (diff) | |
fixed bst implementation and refactored unit tests
| -rw-r--r-- | go.mod | 2 | ||||
| -rw-r--r-- | set/balancedbst.go | 142 | ||||
| -rw-r--r-- | set/bst.go | 30 | ||||
| -rw-r--r-- | set/set_test.go | 87 |
4 files changed, 91 insertions, 170 deletions
@@ -1,3 +1,5 @@ module github.com/snonux/algorithms go 1.14 + +require golang.org/x/tools/gopls v0.5.2 // indirect diff --git a/set/balancedbst.go b/set/balancedbst.go index 92fe06e..12673c5 100644 --- a/set/balancedbst.go +++ b/set/balancedbst.go @@ -1,143 +1,5 @@ package set -type node struct { - key int - val int - left *node - right *node +func NewBalancedBST() *BST { + return NewBST() } - -type BalancedBST struct { - root *node -} - -func NewBalancedBST() *BalancedBST { - return &BalancedBST{} -} - -func (t *BalancedBST) Empty() bool { - return t.root == nil -} - -func (t *BalancedBST) Set(key, val int) { - if t.Empty() { - t.root = &node{key, val, nil, nil} - return - } - ptr, _, err := t.search(&t.root, key) - switch err { - case nil: - // key already in the tree - return - case NotFound: - *ptr = &node{key, val, nil, nil} - return - default: - panic(err) - } -} - -func (t *BalancedBST) Get(key int) (int, error) { - _, n, err := t.search(&t.root, key) - return n.val, err -} - -func (t *BalancedBST) Del(key int) (int, error) { - ptr, n, err := t.search(&t.root, key) - if err != nil { - return 0, err - } - - // Case 1: n is leaf node - // Case 2: n has one child - // Case 3: n has two childs - - switch { - case n.left == nil: - if n.right == nil { - // I am a leaf node - *ptr = nil - return n.val, nil - } - // I have a right child - *ptr = n.right - return n.val, nil - - case n.right == nil: - // I have a left child - *ptr = n.left - return n.val, nil - default: - // I have two children! - - o, err := t.deleteMin(&n.right) - if err != nil { - return 0, err - } - - o.left = n.left - o.right = n.right - *ptr = o - - return n.val, nil - } -} - -func (t *BalancedBST) search(ptr **node, key int) (**node, *node, error) { - n := *ptr - if n == nil { - return ptr, nil, NotFound - } - - switch { - case key < n.key: - return t.search(&n.left, key) - case n.key < key: - return t.search(&n.right, key) - default: - return ptr, n, nil - } -} - -func (t *BalancedBST) deleteMin(ptr **node) (*node, error) { - ptr, n, err := t.min(ptr) - if err != nil { - return nil, err - } - - *ptr = n.right - n.right = nil - return n, nil -} - -func (t *BalancedBST) min(ptr **node) (**node, *node, error) { - n := *ptr - if n == nil { - return nil, nil, NotFound - } - - for { - if n.left == nil { - return ptr, n, nil - } - ptr = &n.left - n = *ptr - } -} - -/* -func (t *BalancedBST) max(ptr **node) (**node, *node, error) { - n := *ptr - if n == nil { - return nil, nil, NotFound - } - - for { - if n.right == nil { - return ptr, n, nil - } - ptr = &n.right - n = *ptr - } -} -*/ @@ -1,5 +1,7 @@ package set +import "fmt" + type node struct { key int val int @@ -7,6 +9,21 @@ type node struct { right *node } +func (n *node) String() string { + recurse := func(n *node) string { + if n == nil { + return "" + } + return n.String() + } + + return fmt.Sprintf("node{%d:%d,%s,%s}", + n.key, + n.val, + recurse(n.left), + recurse(n.right)) +} + type BST struct { root *node } @@ -15,6 +32,14 @@ func NewBST() *BST { return &BST{} } +func (t *BST) String() string { + if t.Empty() { + return "BST{}" + } + + return fmt.Sprintf("BST{%s}", t.root) +} + func (t *BST) Empty() bool { return t.root == nil } @@ -39,7 +64,10 @@ func (t *BST) Set(key, val int) { func (t *BST) Get(key int) (int, error) { _, n, err := t.search(&t.root, key) - return n.val, err + if err != nil { + return 0, err + } + return n.val, nil } func (t *BST) Del(key int) (int, error) { diff --git a/set/set_test.go b/set/set_test.go index 9c5f102..098947d 100644 --- a/set/set_test.go +++ b/set/set_test.go @@ -10,15 +10,14 @@ import ( const factor int = 10 const minLength int = 1 -const maxLength int = 10 +const maxLength int = 10000 // Store results here to avoid compiler optimizations var benchResult int func TestElementary(t *testing.T) { - s := NewElementary() for i := minLength; i <= maxLength; i *= factor { - test(s, i, t) + test(NewElementary(), i, t) } } @@ -28,41 +27,71 @@ func TestBST(t *testing.T) { } } -func TestBalancedBST(t *testing.T) { - for i := minLength; i <= maxLength; i *= factor { - test(NewBalancedBST(), i, t) - } -} - func test(s Set, l int, t *testing.T) { - cb := func(t *testing.T) { - vals := ds.NewRandomArrayList(l, -1) - keys := ds.NewRandomArrayList(l, -1) - mapping := make(map[int]int, l) - - for i, key := range keys { - val := vals[i] - mapping[key] = val - - t.Log("Inserting", key, val) - s.Set(key, val) + keys := ds.NewRandomArrayList(l, l) + randoms := ds.NewRandomArrayList(l, -1) + mapping := make(map[int]int, l) + + get := func(key int, del bool) int { + var val int + var err error + switch { + case del: + defer delete(mapping, key) + val, err = s.Del(key) + //t.Log("Del", key, val, err) + default: + val, err = s.Get(key) + //t.Log("Get", key, val, err) } - for _, key := range sort.Shuffle(keys) { - val, err := s.Get(key) + if mVal, ok := mapping[key]; ok { if err != nil { - t.Errorf("Element %v->%v: %v\n", key, val, err) + t.Error("Could not get element", key, val, mVal, err) } - - val2 := mapping[key] - if val2 != val { - t.Errorf("Element is %v->%v but expected %v\n", key, val, val2) + if mVal != val { + t.Error("Got wrong value for element", key, val, mVal) } - t.Log("Got", key, val, s) + return val + } + + if err == nil { + t.Error("Got element but expected not to", key, val) } + return val + } + testGet := func(key int) int { return get(key, false) } + testDel := func(key int) int { return get(key, true) } + + testSet := func(key, val int) { + s.Set(key, val) + mapping[key] = val + //t.Log("Set", key, val) + testGet(key) } - t.Run(fmt.Sprintf("%d", l), cb) + t.Log("Set random key-values", l) + var prevKey int + for _, key := range sort.Shuffle(keys) { + testSet(key, randoms[key]) + testGet(prevKey) + prevKey = key + } + t.Log("Del random key-values", l) + for _, key := range sort.Shuffle(keys) { + testDel(key) + testGet(prevKey) + prevKey = key + } + if !s.Empty() { + t.Error("Expected set to be empty", l) + } +} + +func TestBalancedBST(t *testing.T) { + for i := minLength; i <= maxLength; i *= factor { + test(NewBalancedBST(), i, t) + } } func BenchmarkElementary(t *testing.B) { |
