/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package grpc

import (
	"errors"
	"fmt"
	"math/rand"
	"net"
	"sync"
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/naming"
)

// Client API for LoadBalancer service.
// Mostly copied from generated pb.go file.
// To avoid circular dependency.
//
// loadBalancerClient is a hand-written client for the
// grpc.lb.v1.LoadBalancer service. It is kept in this package instead of
// using the generated client, presumably because the generated package would
// import grpc and create an import cycle — NOTE(review): confirm.
type loadBalancerClient struct {
	// cc is the connection to the remote load balancer that BalanceLoad
	// streams are opened on.
	cc *ClientConn
}

// BalanceLoad opens a bidirectional BalanceLoad stream to the remote load
// balancer over c.cc and wraps it with typed Send/Recv helpers.
// Any error from creating the stream is returned unchanged.
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption) (*balanceLoadClientStream, error) {
	const fullMethod = "/grpc.lb.v1.LoadBalancer/BalanceLoad"
	streamDesc := &StreamDesc{
		StreamName:    "BalanceLoad",
		ClientStreams: true,
		ServerStreams: true,
	}
	cs, err := NewClientStream(ctx, streamDesc, c.cc, fullMethod, opts...)
	if err != nil {
		return nil, err
	}
	return &balanceLoadClientStream{ClientStream: cs}, nil
}

// balanceLoadClientStream is the client side of a BalanceLoad stream. It
// embeds the untyped ClientStream and adds Send/Recv methods that work with
// the LoadBalanceRequest/LoadBalanceResponse message types.
type balanceLoadClientStream struct {
	ClientStream
}

// Send writes one LoadBalanceRequest to the BalanceLoad stream.
func (x *balanceLoadClientStream) Send(m *lbmpb.LoadBalanceRequest) error {
	cs := x.ClientStream
	return cs.SendMsg(m)
}

// Recv reads the next LoadBalanceResponse from the BalanceLoad stream.
// On error the response is nil and the stream error is returned unchanged.
func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) {
	resp := &lbmpb.LoadBalanceResponse{}
	err := x.ClientStream.RecvMsg(resp)
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// NewGRPCLBBalancer creates a grpclb load balancer that uses r to discover
// the remote load balancers. r is only stored here; resolution begins when
// Start is called, and Start fails if r is nil.
func NewGRPCLBBalancer(r naming.Resolver) Balancer {
	lb := new(balancer)
	lb.r = r
	return lb
}

// remoteBalancerInfo identifies one remote load balancer learned from the
// name resolver (an update whose metadata has AddrType naming.GRPCLB).
// Values are compared with == to detect whether the balancer in use is still
// present in a new resolver update.
type remoteBalancerInfo struct {
	// addr is the address dialed to reach the remote load balancer.
	addr string
	// the server name used for authentication with the remote LB server.
	// When non-empty it is passed to the transport credentials'
	// OverrideServerName before dialing.
	name string
}

// grpclbAddrInfo consists of the information of a backend server.
// One entry is created per server in the list received from the remote
// balancer; all access goes through balancer.mu.
type grpclbAddrInfo struct {
	// addr is the backend address; its Metadata carries the server's
	// "lb-token".
	addr Address
	// connected is set by Up and cleared by down; only connected entries
	// are picked by Get.
	connected bool
	// dropForRateLimiting indicates whether requests that pick this entry
	// should be dropped by the client for rate limiting.
	dropForRateLimiting bool
	// dropForLoadBalancing indicates whether requests that pick this entry
	// should be dropped by the client for load balancing.
	dropForLoadBalancing bool
}

// balancer is the grpclb implementation of Balancer.
//
// It resolves the target to a set of remote load balancers (rbs), keeps a
// stream open to one of them, and round-robins RPCs over the backend server
// list (addrs) that the remote balancer returns.
//
// Field notes:
//   - r, target, and rand are set in NewGRPCLBBalancer/Start.
//   - mu guards seq, w, addrCh, rbs, addrs, next, waitCh, done, and
//     clientStats; every access to these in this file takes mu.
//   - seq is bumped whenever a new remote balancer is connected or a new
//     server list arrives. Work tagged with an older seq is discarded.
//   - addrCh delivers backend address lists; Notify exposes it.
//   - next is the round-robin cursor into addrs.
//   - waitCh, when non-nil, is closed to wake blocking Get callers.
//   - done is set by Close.
//   - clientStats accumulates per-call counters that are periodically sent
//     to the remote balancer as a load report.
type balancer struct {
	r      naming.Resolver
	target string
	mu     sync.Mutex
	seq    int // a sequence number to make sure addrCh does not get stale addresses.
	w      naming.Watcher
	addrCh chan []Address
	rbs    []remoteBalancerInfo
	addrs  []*grpclbAddrInfo
	next   int
	waitCh chan struct{}
	done   bool
	rand   *rand.Rand

	clientStats lbmpb.ClientStats
}

// watchAddrUpdates blocks for the next batch of updates from w and applies
// them to b.rbs, the set of known remote load balancers. Only additions whose
// metadata is *naming.AddrMetadataGRPCLB with AddrType naming.GRPCLB are
// accepted. Duplicate adds (same addr) and malformed or backend-only entries
// are ignored. Deletes remove the matching addr.
//
// The resulting list is published on ch, replacing any unread previous value.
// ch is expected to have a buffer of at least one, and this method should be
// its only sender. A copy of b.rbs is sent, never b.rbs itself. b.rbs is
// rewritten in place under b.mu (see the Delete case), while the receiver
// reads and reorders its slice without holding the lock.
//
// It returns the watcher's error, or ErrClientConnClosing if the balancer has
// been closed.
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
	updates, err := w.Next()
	if err != nil {
		grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err)
		return err
	}
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		return ErrClientConnClosing
	}
	for _, update := range updates {
		switch update.Op {
		case naming.Add:
			var exist bool
			for _, v := range b.rbs {
				// TODO: Is the same addr with different server name a different balancer?
				if update.Addr == v.addr {
					exist = true
					break
				}
			}
			if exist {
				continue
			}
			md, ok := update.Metadata.(*naming.AddrMetadataGRPCLB)
			if !ok {
				// TODO: Revisit the handling here and may introduce some fallback mechanism.
				grpclog.Errorf("The name resolution contains unexpected metadata %v", update.Metadata)
				continue
			}
			switch md.AddrType {
			case naming.Backend:
				// TODO: Revisit the handling here and may introduce some fallback mechanism.
				grpclog.Errorf("The name resolution does not give grpclb addresses")
				continue
			case naming.GRPCLB:
				b.rbs = append(b.rbs, remoteBalancerInfo{
					addr: update.Addr,
					name: md.ServerName,
				})
			default:
				grpclog.Errorf("Received unknow address type %d", md.AddrType)
				continue
			}
		case naming.Delete:
			for i, v := range b.rbs {
				if update.Addr == v.addr {
					// In-place removal: this rewrites the shared backing array.
					copy(b.rbs[i:], b.rbs[i+1:])
					b.rbs = b.rbs[:len(b.rbs)-1]
					break
				}
			}
		default:
			grpclog.Errorf("Unknown update.Op %v", update.Op)
		}
	}
	// TODO: Fall back to the basic round-robin load balancing if the resulting address is
	// not a load balancer.

	// Snapshot under the lock so the receiver owns its slice outright.
	snapshot := make([]remoteBalancerInfo, len(b.rbs))
	copy(snapshot, b.rbs)
	// Drop a stale, unconsumed list so the send below never blocks.
	select {
	case <-ch:
	default:
	}
	ch <- snapshot
	return nil
}

// convertDuration turns a load-balancer protobuf Duration into a
// time.Duration (seconds plus nanoseconds). A nil d yields 0.
func convertDuration(d *lbmpb.Duration) time.Duration {
	var out time.Duration
	if d != nil {
		secs := time.Duration(d.Seconds) * time.Second
		nanos := time.Duration(d.Nanos) * time.Nanosecond
		out = secs + nanos
	}
	return out
}

// processServerList converts the server list received from the remote
// balancer into backend addresses. Each address carries the server's
// "lb-token" as metadata. The addresses are published on b.addrCh.
//
// The update is discarded if l is nil, if it yields no servers, if the
// balancer is closed, or if seq is older than b.seq (a newer server list or
// remote balancer has superseded it). Publishing replaces b.addrs and resets
// the round-robin cursor b.next. The send on b.addrCh happens while b.mu is
// held.
func (b *balancer) processServerList(l *lbmpb.ServerList, seq int) {
	if l == nil {
		return
	}
	servers := l.GetServers()
	infos := make([]*grpclbAddrInfo, 0, len(servers))
	addrs := make([]Address, 0, len(servers))
	for _, srv := range servers {
		md := metadata.Pairs("lb-token", srv.LoadBalanceToken)
		ip := net.IP(srv.IpAddress)
		host := ip.String()
		if ip.To4() == nil {
			// IPv6 literals need square brackets, otherwise net.Dial() and
			// net.SplitHostPort() fail with a "too many colons" error.
			host = "[" + host + "]"
		}
		a := Address{
			Addr:     fmt.Sprintf("%s:%d", host, srv.Port),
			Metadata: &md,
		}
		infos = append(infos, &grpclbAddrInfo{
			addr:                 a,
			dropForRateLimiting:  srv.DropForRateLimiting,
			dropForLoadBalancing: srv.DropForLoadBalancing,
		})
		addrs = append(addrs, a)
	}

	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done || seq < b.seq || len(infos) == 0 {
		return
	}
	b.next = 0 // restart round-robin on the new list
	b.addrs = infos
	b.addrCh <- addrs
}

// sendLoadReport sends a client load report on s every interval until done
// is closed or a send fails. Each report carries the stats accumulated since
// the previous one, stamped with the current time. The accumulator is reset
// each time it is read.
func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) {
	tick := time.NewTicker(interval)
	defer tick.Stop()
	for {
		select {
		case <-done:
			return
		case <-tick.C:
		}

		// Take the accumulated counters and start a fresh interval.
		b.mu.Lock()
		report := b.clientStats
		b.clientStats = lbmpb.ClientStats{}
		b.mu.Unlock()

		now := time.Now()
		report.Timestamp = &lbmpb.Timestamp{
			Seconds: now.Unix(),
			Nanos:   int32(now.Nanosecond()),
		}
		req := &lbmpb.LoadBalanceRequest{
			LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_ClientStats{
				ClientStats: &report,
			},
		}
		if err := s.Send(req); err != nil {
			grpclog.Errorf("grpclb: failed to send load report: %v", err)
			return
		}
	}
}

// callRemoteBalancer runs one BalanceLoad session against the remote
// balancer reachable through lbc:
//  1. It sends the initial request naming b.target and expects an initial
//     response (delegation is not supported).
//  2. It starts periodic load reports if the response asks for them.
//  3. It feeds every received server list to processServerList.
//
// seq tags this session. The session stops as soon as the balancer is closed
// or a newer seq appears, and each received list bumps b.seq.
//
// The result reports whether the failure looked retryable: true when a send
// or receive on the established stream failed, false otherwise.
func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	stream, err := lbc.BalanceLoad(ctx)
	if err != nil {
		grpclog.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
		return false
	}

	b.mu.Lock()
	closed := b.done
	b.mu.Unlock()
	if closed {
		return false
	}

	initReq := &lbmpb.LoadBalanceRequest{
		LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_InitialRequest{
			InitialRequest: &lbmpb.InitialLoadBalanceRequest{
				Name: b.target,
			},
		},
	}
	if err := stream.Send(initReq); err != nil {
		grpclog.Errorf("grpclb: failed to send init request: %v", err)
		// TODO: backoff on retry?
		return true
	}
	first, err := stream.Recv()
	if err != nil {
		grpclog.Errorf("grpclb: failed to recv init response: %v", err)
		// TODO: backoff on retry?
		return true
	}
	initResp := first.GetInitialResponse()
	switch {
	case initResp == nil:
		grpclog.Errorf("grpclb: reply from remote balancer did not include initial response.")
		return false
	case initResp.LoadBalancerDelegate != "":
		// TODO: Support delegation.
		grpclog.Errorf("TODO: Delegation is not supported yet.")
		return false
	}

	// Closed before cancel() runs (defers are LIFO), which stops the load
	// reporter before the stream's context is cancelled.
	streamDone := make(chan struct{})
	defer close(streamDone)

	b.mu.Lock()
	b.clientStats = lbmpb.ClientStats{} // Clear client stats.
	b.mu.Unlock()
	if interval := convertDuration(initResp.ClientStatsReportInterval); interval > 0 {
		go b.sendLoadReport(stream, interval, streamDone)
	}

	// Retrieve the server list.
	for {
		resp, err := stream.Recv()
		if err != nil {
			grpclog.Errorf("grpclb: failed to recv server list: %v", err)
			return true
		}
		b.mu.Lock()
		if b.done || seq < b.seq {
			b.mu.Unlock()
			return false
		}
		b.seq++ // tick when receiving a new list of servers.
		seq = b.seq
		b.mu.Unlock()
		if sl := resp.GetServerList(); sl != nil {
			b.processServerList(sl, seq)
		}
	}
}

// Start resolves target to a set of remote load balancers and spawns two
// goroutines:
//   - one that feeds name-resolution updates into balancerAddrsCh via
//     watchAddrUpdates, closing it when the watcher fails;
//   - one that keeps a ClientConn to one remote balancer at a time and runs
//     callRemoteBalancer on it.
//
// If the remote balancer in use is still present in an update, its
// connection is kept and it is moved to the front of the list. Otherwise a
// balancer is taken from the new list and dialed. When a balancer is being
// replaced, the pick is random. When the connection to a balancer ends, the
// next one in the list is tried.
//
// Start returns an error if no resolver is installed, the balancer has
// already been closed, or resolving target fails.
func (b *balancer) Start(target string, config BalancerConfig) error {
	b.rand = rand.New(rand.NewSource(time.Now().Unix()))
	// TODO: Fall back to the basic direct connection if there is no name resolver.
	if b.r == nil {
		return errors.New("there is no name resolver installed")
	}
	b.target = target
	b.mu.Lock()
	if b.done {
		b.mu.Unlock()
		return ErrClientConnClosing
	}
	b.addrCh = make(chan []Address)
	w, err := b.r.Resolve(target)
	if err != nil {
		b.mu.Unlock()
		grpclog.Errorf("grpclb: failed to resolve address: %v, err: %v", target, err)
		return err
	}
	b.w = w
	b.mu.Unlock()
	balancerAddrsCh := make(chan []remoteBalancerInfo, 1)
	// Spawn a goroutine to monitor the name resolution of remote load balancer.
	go func() {
		for {
			if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil {
				grpclog.Warningf("grpclb: the naming watcher stops working due to %v.\n", err)
				close(balancerAddrsCh)
				return
			}
		}
	}()
	// Spawn a goroutine to talk to the remote load balancer.
	go func() {
		// rbs is a private copy owned by this goroutine. rb always points
		// into the current rbs, and rbIdx is its index there.
		var (
			cc *ClientConn
			// ccError is closed when there is an error in the current cc.
			// A new rb should be picked from rbs and connected.
			ccError chan struct{}
			rb      *remoteBalancerInfo
			rbs     []remoteBalancerInfo
			rbIdx   int
		)

		defer func() {
			if ccError != nil {
				select {
				case <-ccError:
				default:
					close(ccError)
				}
			}
			if cc != nil {
				cc.Close()
			}
		}()

		for {
			select {
			case update, ok := <-balancerAddrsCh:
				if !ok {
					return
				}
				// Work on a private copy. The received slice may share its
				// backing array with b.rbs, which must only be modified
				// while holding b.mu.
				rbs = append([]remoteBalancerInfo(nil), update...)
				foundIdx := -1
				if rb != nil {
					for i, trb := range rbs {
						if trb == *rb {
							foundIdx = i
							break
						}
					}
				}
				switch {
				case foundIdx >= 0:
					// Keep the current balancer: move it to the front so that
					// failover walks the remainder of the new list, and re-point
					// rb/rbIdx at its new position.
					rbs[0], rbs[foundIdx] = rbs[foundIdx], rbs[0]
					rbIdx = 0
					rb = &rbs[0]
					continue // If found, don't dial new cc.
				case len(rbs) > 0:
					// Pick a random one from the list, instead of always using the first one.
					if l := len(rbs); l > 1 && rb != nil {
						tmpIdx := b.rand.Intn(l) // any index in [0, l)
						rbs[0], rbs[tmpIdx] = rbs[tmpIdx], rbs[0]
					}
					rbIdx = 0
					rb = &rbs[0]
				default:
					// foundIdx < 0 && len(rbs) == 0.
					rb = nil
				}
			case <-ccError:
				ccError = nil
				if rbIdx < len(rbs)-1 {
					rbIdx++
					rb = &rbs[rbIdx]
				} else {
					rb = nil
				}
			}

			if rb == nil {
				continue
			}

			if cc != nil {
				cc.Close()
			}
			// Talk to the remote load balancer to get the server list.
			var (
				err   error
				dopts []DialOption
			)
			if creds := config.DialCreds; creds != nil {
				if rb.name != "" {
					if err := creds.OverrideServerName(rb.name); err != nil {
						grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v", err)
						continue
					}
				}
				dopts = append(dopts, WithTransportCredentials(creds))
			} else {
				dopts = append(dopts, WithInsecure())
			}
			if dialer := config.Dialer; dialer != nil {
				// WithDialer takes a different type of function, so we instead use a special DialOption here.
				dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer })
			}
			ccError = make(chan struct{})
			cc, err = Dial(rb.addr, dopts...)
			if err != nil {
				grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
				close(ccError)
				continue
			}
			b.mu.Lock()
			b.seq++ // tick when getting a new balancer address
			seq := b.seq
			b.next = 0
			b.mu.Unlock()
			go func(cc *ClientConn, ccError chan struct{}) {
				lbc := &loadBalancerClient{cc}
				b.callRemoteBalancer(lbc, seq)
				cc.Close()
				select {
				case <-ccError:
				default:
					close(ccError)
				}
			}(cc, ccError)
		}
	}()
	return nil
}

// down marks the first entry of b.addrs equal to addr as disconnected so
// Get stops picking it. err is currently unused.
func (b *balancer) down(addr Address, err error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	for i := range b.addrs {
		if info := b.addrs[i]; info.addr == addr {
			info.connected = false
			return
		}
	}
}

// Up marks every entry of b.addrs equal to addr as connected. It returns a
// function that marks addr down again. The function is returned even when
// addr is not in the current list.
//
// Up returns nil if the balancer is closed or if a matching entry is found
// already connected. When the update leaves exactly one connected,
// non-dropping entry, Up closes b.waitCh to wake blocking Get callers.
func (b *balancer) Up(addr Address) func(error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		return nil
	}
	usable := 0
	for _, info := range b.addrs {
		if info.addr == addr {
			if info.connected {
				return nil
			}
			info.connected = true
		}
		if info.connected && !(info.dropForRateLimiting || info.dropForLoadBalancing) {
			usable++
		}
	}
	// addr is the only usable connected address: release waiting Get() calls.
	if usable == 1 && b.waitCh != nil {
		close(b.waitCh)
		b.waitCh = nil
	}
	return func(err error) { b.down(addr, err) }
}

// Get picks a backend address for a new RPC, round-robin over the server
// list received from the remote balancer.
//
// Only entries marked connected by Up are eligible. How entries marked for
// dropping (rate limiting or load balancing) are handled depends on
// opts.BlockingWait:
//   - when it is set, they are skipped;
//   - otherwise, picking one fails the RPC with codes.Unavailable and records
//     the drop in the client stats.
//
// If no eligible entry is connected:
//   - non-blocking (failfast) calls get the next address in the list
//     regardless of its state, or codes.Unavailable if the list is empty;
//   - blocking calls wait until Up reports a usable address, ctx is done
//     (returning ctx.Err()), or the balancer is closed
//     (ErrClientConnClosing).
//
// On success, put records the call's completion in the client stats. Nothing
// is recorded if the balancer is closed or b.seq has advanced past the value
// seen here.
func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
	b.mu.Lock()
	if b.done {
		b.mu.Unlock()
		err = ErrClientConnClosing
		return
	}
	seq := b.seq

	defer func() {
		if err != nil {
			return
		}
		put = func() {
			s, ok := rpcInfoFromContext(ctx)
			if !ok {
				return
			}
			b.mu.Lock()
			defer b.mu.Unlock()
			if b.done || seq < b.seq {
				return
			}
			b.clientStats.NumCallsFinished++
			if !s.bytesSent {
				b.clientStats.NumCallsFinishedWithClientFailedToSend++
			} else if s.bytesReceived {
				b.clientStats.NumCallsFinishedKnownReceived++
			}
		}
	}()

	b.clientStats.NumCallsStarted++
	if a := b.pickConnectedLocked(opts.BlockingWait); a != nil {
		if !a.dropForRateLimiting && !a.dropForLoadBalancing {
			addr = a.addr
			b.mu.Unlock()
			return
		}
		// A drop entry is only ever returned for non-blocking calls.
		if a.dropForLoadBalancing {
			b.clientStats.NumCallsFinished++
			b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
		} else if a.dropForRateLimiting {
			b.clientStats.NumCallsFinished++
			b.clientStats.NumCallsFinishedWithDropForRateLimiting++
		}
		b.mu.Unlock()
		err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr)
		return
	}
	if !opts.BlockingWait {
		if len(b.addrs) == 0 {
			b.clientStats.NumCallsFinished++
			b.clientStats.NumCallsFinishedWithClientFailedToSend++
			b.mu.Unlock()
			err = Errorf(codes.Unavailable, "there is no address available")
			return
		}
		// Returns the next addr on b.addrs for a failfast RPC.
		// pickConnectedLocked has already normalized b.next into range.
		addr = b.addrs[b.next].addr
		b.next++
		b.mu.Unlock()
		return
	}
	// Wait on b.waitCh for non-failfast RPCs.
	ch := b.waitChLocked()
	b.mu.Unlock()
	for {
		select {
		case <-ctx.Done():
			b.mu.Lock()
			b.clientStats.NumCallsFinished++
			b.clientStats.NumCallsFinishedWithClientFailedToSend++
			b.mu.Unlock()
			err = ctx.Err()
			return
		case <-ch:
			b.mu.Lock()
			if b.done {
				b.clientStats.NumCallsFinished++
				b.clientStats.NumCallsFinishedWithClientFailedToSend++
				b.mu.Unlock()
				err = ErrClientConnClosing
				return
			}
			// Drop entries are skipped here, so any result is usable.
			if a := b.pickConnectedLocked(true); a != nil {
				addr = a.addr
				b.mu.Unlock()
				return
			}
			// The newly added addr got removed by down() again; keep waiting.
			ch = b.waitChLocked()
			b.mu.Unlock()
		}
	}
}

// pickConnectedLocked scans b.addrs round-robin, starting at b.next, and
// returns the first connected entry, or nil after one full pass finds none.
// When skipDrops is true, connected entries marked for dropping are passed
// over. On a hit, b.next is advanced past the returned entry. On a miss,
// b.next is only normalized into range. b.mu must be held.
func (b *balancer) pickConnectedLocked(skipDrops bool) *grpclbAddrInfo {
	n := len(b.addrs)
	if n == 0 {
		return nil
	}
	if b.next >= n {
		b.next = 0
	}
	start := b.next
	i := start
	for {
		a := b.addrs[i]
		i = (i + 1) % n
		if a.connected && (!skipDrops || (!a.dropForRateLimiting && !a.dropForLoadBalancing)) {
			b.next = i
			return a
		}
		if i == start {
			// Has iterated all the possible addresses but none is usable.
			return nil
		}
	}
}

// waitChLocked returns the channel blocking Get callers wait on, creating it
// if none is pending. b.mu must be held.
func (b *balancer) waitChLocked() chan struct{} {
	if b.waitCh == nil {
		b.waitCh = make(chan struct{})
	}
	return b.waitCh
}

// Notify returns the channel on which new backend address lists are
// published. The channel is created by Start (nil before that) and closed by
// Close.
func (b *balancer) Notify() <-chan []Address {
	updates := b.addrCh
	return updates
}

// Close shuts the balancer down:
//   - it marks the balancer done;
//   - it wakes any blocking Get callers by closing waitCh;
//   - it closes the address channel returned by Notify;
//   - it stops the name-resolution watcher.
//
// Calling Close more than once returns errBalancerClosed.
func (b *balancer) Close() error {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		return errBalancerClosed
	}
	b.done = true
	if wait := b.waitCh; wait != nil {
		close(wait)
	}
	if updates := b.addrCh; updates != nil {
		close(updates)
	}
	if watcher := b.w; watcher != nil {
		watcher.Close()
	}
	return nil
}