From 530a23162549a31baa14dfa3b647a9eccee8878f Mon Sep 17 00:00:00 2001 From: siddontang Date: Tue, 31 Mar 2015 21:40:50 +0800 Subject: [PATCH] add semaphore --- sync2/semaphore.go | 65 +++++++++++++++++++++++++++++++++++++++++ sync2/semaphore_test.go | 41 ++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 sync2/semaphore.go create mode 100644 sync2/semaphore_test.go diff --git a/sync2/semaphore.go b/sync2/semaphore.go new file mode 100644 index 0000000..d310da7 --- /dev/null +++ b/sync2/semaphore.go @@ -0,0 +1,65 @@ +package sync2 + +import ( + "sync" + "sync/atomic" + "time" +) + +func NewSemaphore(initialCount int) *Semaphore { + res := &Semaphore{ + counter: int64(initialCount), + } + res.cond.L = &res.lock + return res +} + +type Semaphore struct { + lock sync.Mutex + cond sync.Cond + counter int64 +} + +func (s *Semaphore) Release() { + s.lock.Lock() + s.counter += 1 + if s.counter >= 0 { + s.cond.Signal() + } + s.lock.Unlock() +} + +func (s *Semaphore) Acquire() { + s.lock.Lock() + for s.counter < 1 { + s.cond.Wait() + } + s.counter -= 1 + s.lock.Unlock() +} + +func (s *Semaphore) AcquireTimeout(timeout time.Duration) bool { + done := make(chan bool, 1) + // Gate used to communicate between the threads and decide what the result + // is. If the main thread decides, we have timed out, otherwise we succeed. + decided := new(int32) + go func() { + s.Acquire() + if atomic.SwapInt32(decided, 1) == 0 { + done <- true + } else { + // If we already decided the result, and this thread did not win + s.Release() + } + }() + select { + case <-done: + return true + case <-time.NewTimer(timeout).C: + if atomic.SwapInt32(decided, 1) == 1 { + // The other thread already decided the result + return true + } + return false + } +} diff --git a/sync2/semaphore_test.go b/sync2/semaphore_test.go new file mode 100644 index 0000000..8c48694 --- /dev/null +++ b/sync2/semaphore_test.go @@ -0,0 +1,41 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync2 + +import ( + "testing" + "time" +) + +func TestSemaNoTimeout(t *testing.T) { + s := NewSemaphore(1) + s.Acquire() + released := false + go func() { + time.Sleep(10 * time.Millisecond) + released = true + s.Release() + }() + s.Acquire() + if !released { + t.Errorf("want true, got false") + } +} + +func TestSemaTimeout(t *testing.T) { + s := NewSemaphore(1) + s.Acquire() + go func() { + time.Sleep(10 * time.Millisecond) + s.Release() + }() + if ok := s.AcquireTimeout(5 * time.Millisecond); ok { + t.Errorf("want false, got true") + } + time.Sleep(10 * time.Millisecond) + if ok := s.AcquireTimeout(5 * time.Millisecond); !ok { + t.Errorf("want true, got false") + } +}