From 17ec4703405edb7534621276291752de416fea52 Mon Sep 17 00:00:00 2001 From: Grzegorz Dlugoszewski Date: Tue, 19 May 2020 10:24:25 +0200 Subject: [PATCH] Add upstream info to branch status, simplify unit tests --- branch.go | 32 ++++++++++++++++++++++++-------- branch_test.go | 22 +++++++++++++--------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/branch.go b/branch.go index 88961f2..ee7504c 100644 --- a/branch.go +++ b/branch.go @@ -33,7 +33,7 @@ func Branches(repo *git.Repository) ([]BranchStatus, error) { var statuses []BranchStatus for _, branch := range branches { - status, err := NewBranchStatus(branch) + status, err := NewBranchStatus(repo, branch) if err != nil { // TODO: handle error continue @@ -44,7 +44,7 @@ func Branches(repo *git.Repository) ([]BranchStatus, error) { return statuses, nil } -func NewBranchStatus(branch *git.Branch) (BranchStatus, error) { +func NewBranchStatus(repo *git.Repository, branch *git.Branch) (BranchStatus, error) { var status BranchStatus name, err := branch.Name() @@ -55,12 +55,28 @@ func NewBranchStatus(branch *git.Branch) (BranchStatus, error) { status.IsRemote = branch.IsRemote() - _, err = branch.Upstream() - if err != nil { - if git.IsErrorCode(err, git.ErrNotFound) { - status.HasUpstream = false - } else { - return status, errors.Wrap(err, "Failed getting branch upstream") + upstream, err := branch.Upstream() + if err != nil && !git.IsErrorCode(err, git.ErrNotFound) { + return status, errors.Wrap(err, "Failed getting branch upstream") + } + + if upstream != nil { + status.HasUpstream = true + + ahead, behind, err := repo.AheadBehind(branch.Target(), upstream.Target()) + if err != nil { + return status, errors.Wrap(err, "Failed getting ahead/behind information") + } + + status.Ahead = ahead + status.Behind = behind + + if ahead > 0 { + status.NeedsPush = true + } + + if behind > 0 { + status.NeedsPull = true } } diff --git a/branch_test.go b/branch_test.go index 15fffa5..9eb2d99 100644 --- a/branch_test.go +++ b/branch_test.go @@ -2,11 +2,9 @@ package main import ( "testing" - - "github.com/pkg/errors" ) -func TestNewBranch(t *testing.T) { +func TestNewLocalBranch(t *testing.T) { repo := newTestRepo(t) createFile(t, repo, "file") @@ -14,14 +12,20 @@ func TestNewBranch(t *testing.T) { createCommit(t, repo, "Initial commit") branch := createBranch(t, repo, "branch") - status, err := NewBranchStatus(branch) - checkFatal(t, errors.Wrap(err, "Failed getting branch status")) + status, err := NewBranchStatus(repo, branch) + checkFatal(t, err) - if status.Name != "branch" { - t.Errorf("Wrong branch name, got %s; want %s", status.Name, "branch") + want := BranchStatus{ + Name: "branch", + IsRemote: false, + HasUpstream: false, + NeedsPull: false, + NeedsPush: false, + Ahead: 0, + Behind: 0, } - if status.IsRemote != false { - t.Errorf("Branch should be local") + if status != want { + t.Errorf("Wrong branch status, got %+v; want %+v", status, want) } }