diff --git a/bool.go b/bool.go index d4977a2..0e1fc96 100644 --- a/bool.go +++ b/bool.go @@ -2,7 +2,12 @@ // better performance. package abool -import "sync/atomic" +import ( + "encoding/json" + "errors" + "fmt" + "sync/atomic" +) // New creates an AtomicBool with default set to false. func New() *AtomicBool { @@ -69,3 +74,24 @@ func (ab *AtomicBool) SetToIf(old, new bool) (set bool) { } return atomic.CompareAndSwapInt32((*int32)(ab), o, n) } + +// Marshal an AtomicBool into JSON like a normal bool +func (ab *AtomicBool) MarshalJSON() ([]byte, error) { + return json.Marshal(ab.IsSet()) +} + +// Unmarshall normal bool's into AtomicBool +func (ab *AtomicBool) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + switch value := v.(type) { + case bool: + ab.SetTo(value) + return nil + default: + return errors.New(fmt.Sprintf("%s is an invalid JSON representation for an AtomicBool\n", b)) + } +} diff --git a/bool_test.go b/bool_test.go index b3db9f1..16bb3b5 100644 --- a/bool_test.go +++ b/bool_test.go @@ -1,6 +1,7 @@ package abool import ( + "encoding/json" "math" "sync" "sync/atomic" @@ -181,6 +182,52 @@ func TestRace(t *testing.T) { wg.Wait() } +func TestFalseUnmarshall(t *testing.T) { + // Marshall a normal bool into JSON byte slice + b, err := json.Marshal(false) + if err != nil { + t.Error(err) + } + + // Create an AtomicBool + v := New() + + // Try to unmarshall the JSON byte slice of a + // a normal bool into an AtomicBool + err = v.UnmarshalJSON(b) + if err != nil { + t.Error(err) + } + + // Check if our AtomicBool is set to false + if v.IsSet() == true { + t.Errorf("Expected AtomicBool to represent false but IsSet() returns true") + } +} + +func TestTrueUnmarshall(t *testing.T) { + // Marshall a normal bool into JSON byte slice + b, err := json.Marshal(true) + if err != nil { + t.Error(err) + } + + // Create an AtomicBool + v := New() + + // Try to unmarshall the JSON byte slice of a + // a normal bool into an AtomicBool + err = v.UnmarshalJSON(b) + if err != nil { + t.Error(err) + } + + // Check if our AtomicBool is set to false + if v.IsSet() == false { + t.Errorf("Expected AtomicBool to represent true but IsSet() returns false. %+v", v) + } +} + func ExampleAtomicBool() { cond := New() // default to false any := true