diff --git a/cmd/get/main.go b/cmd/get/main.go index 97313d2..9b19fa5 100644 --- a/cmd/get/main.go +++ b/cmd/get/main.go @@ -27,6 +27,7 @@ var cmd = &cobra.Command{ func init() { cmd.PersistentFlags().StringP(cfg.KeyBranch, "b", "", "Branch (or tag) to checkout after cloning.") cmd.PersistentFlags().StringP(cfg.KeyDefaultHost, "t", cfg.Defaults[cfg.KeyDefaultHost], "Host to use when doesn't have a specified host.") + cmd.PersistentFlags().StringP(cfg.KeyDefaultScheme, "c", cfg.Defaults[cfg.KeyDefaultScheme], "Scheme to use when doesn't have a specified scheme.") cmd.PersistentFlags().StringP(cfg.KeyDump, "d", "", "Path to a dump file listing repos to clone. Ignored when argument is used.") cmd.PersistentFlags().BoolP(cfg.KeySkipHost, "s", false, "Don't create a directory for host.") cmd.PersistentFlags().StringP(cfg.KeyReposRoot, "r", cfg.Defaults[cfg.KeyReposRoot], "Path to repos root where repositories are cloned.") @@ -35,6 +36,7 @@ func init() { viper.BindPFlag(cfg.KeyBranch, cmd.PersistentFlags().Lookup(cfg.KeyBranch)) viper.BindPFlag(cfg.KeyDefaultHost, cmd.PersistentFlags().Lookup(cfg.KeyDefaultHost)) + viper.BindPFlag(cfg.KeyDefaultScheme, cmd.PersistentFlags().Lookup(cfg.KeyDefaultScheme)) viper.BindPFlag(cfg.KeyDump, cmd.PersistentFlags().Lookup(cfg.KeyDump)) viper.BindPFlag(cfg.KeyReposRoot, cmd.PersistentFlags().Lookup(cfg.KeyReposRoot)) viper.BindPFlag(cfg.KeySkipHost, cmd.PersistentFlags().Lookup(cfg.KeySkipHost)) @@ -51,12 +53,13 @@ func run(cmd *cobra.Command, args []string) error { cfg.Expand(cfg.KeyReposRoot) config := &pkg.GetCfg{ - Branch: viper.GetString(cfg.KeyBranch), - DefHost: viper.GetString(cfg.KeyDefaultHost), - Dump: viper.GetString(cfg.KeyDump), - SkipHost: viper.GetBool(cfg.KeySkipHost), - Root: viper.GetString(cfg.KeyReposRoot), - URL: url, + Branch: viper.GetString(cfg.KeyBranch), + DefHost: viper.GetString(cfg.KeyDefaultHost), + DefScheme: viper.GetString(cfg.KeyDefaultScheme), + Dump: viper.GetString(cfg.KeyDump), + SkipHost: viper.GetBool(cfg.KeySkipHost), + Root: viper.GetString(cfg.KeyReposRoot), + URL: url, } return pkg.Get(config) } diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 79b9171..2caa2c5 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -17,21 +17,22 @@ const GitgetPrefix = "gitget" // CLI flag keys. var ( - KeyBranch = "branch" - KeyDump = "dump" - KeyDefaultHost = "host" - KeyFetch = "fetch" - KeyOutput = "out" - KeySkipHost = "skip-host" - KeyReposRoot = "root" + KeyBranch = "branch" + KeyDump = "dump" + KeyDefaultHost = "host" + KeyFetch = "fetch" + KeyOutput = "out" + KeyDefaultScheme = "scheme" + KeySkipHost = "skip-host" + KeyReposRoot = "root" ) // Defaults is a map of default values for config keys. var Defaults = map[string]string{ - KeyDefaultHost: "github.com", - KeyOutput: OutTree, - KeyReposRoot: fmt.Sprintf("~%c%s", filepath.Separator, "repositories"), - // KeySkipHost: "false", + KeyDefaultHost: "github.com", + KeyOutput: OutTree, + KeyReposRoot: fmt.Sprintf("~%c%s", filepath.Separator, "repositories"), + KeyDefaultScheme: "ssh", } // Values for the --out flag. diff --git a/pkg/get.go b/pkg/get.go index 93cf25b..407fff8 100644 --- a/pkg/get.go +++ b/pkg/get.go @@ -8,12 +8,13 @@ import ( // GetCfg provides configuration for the Get command. type GetCfg struct { - Branch string - DefHost string - Dump string - Root string - SkipHost bool - URL string + Branch string + DefHost string + DefScheme string + Dump string + Root string + SkipHost bool + URL string } // Get executes the "git get" command. @@ -33,7 +34,7 @@ func Get(c *GetCfg) error { } func cloneSingleRepo(c *GetCfg) error { - url, err := ParseURL(c.URL, c.DefHost) + url, err := ParseURL(c.URL, c.DefHost, c.DefScheme) if err != nil { return err } @@ -56,7 +57,7 @@ func cloneDumpFile(c *GetCfg) error { } for _, line := range parsedLines { - url, err := ParseURL(line.rawurl, c.DefHost) + url, err := ParseURL(line.rawurl, c.DefHost, c.DefScheme) if err != nil { return err } diff --git a/pkg/url.go b/pkg/url.go index d8beefa..5715cf0 100644 --- a/pkg/url.go +++ b/pkg/url.go @@ -17,8 +17,9 @@ var errEmptyURLPath = errors.New("parsed URL path is empty") var scpSyntax = regexp.MustCompile(`^([a-zA-Z0-9_]+)@([a-zA-Z0-9._-]+):(.*)$`) // ParseURL parses given rawURL string into a URL. -// The defaultHost argument defines the host to use (eg, github.com) in case parsed URL has an empty host. -func ParseURL(rawURL string, defaultHost string) (url *urlpkg.URL, err error) { +// When the parsed URL has an empty host, use the defaultHost. +// When the parsed URL has an empty scheme, use the defaultScheme. +func ParseURL(rawURL string, defaultHost string, defaultScheme string) (url *urlpkg.URL, err error) { // If rawURL matches the SCP-like syntax, convert it into a standard ssh Path. // eg, git@github.com:user/repo => ssh://git@github.com/user/repo if m := scpSyntax.FindStringSubmatch(rawURL); m != nil { @@ -26,7 +27,7 @@ func ParseURL(rawURL string, defaultHost string) (url *urlpkg.URL, err error) { Scheme: "ssh", User: urlpkg.User(m[1]), Host: m[2], - Path: m[3], + Path: path.Join("/", m[3]), } } else { url, err = urlpkg.Parse(rawURL) @@ -43,14 +44,21 @@ func ParseURL(rawURL string, defaultHost string) (url *urlpkg.URL, err error) { url.Scheme = "ssh" } - // Default to "git" user when using ssh and no user is provided - if url.Scheme == "ssh" && url.User == nil { - url.User = urlpkg.User("git") - } - // Default to configured defaultHost when host is empty if url.Host == "" { url.Host = defaultHost + // Add a leading slash to path when host is missing. It's needed to correctly compare urlpkg.URL structs. + url.Path = path.Join("/", url.Path) + } + + // Default to configured defaultScheme when scheme is empty + if url.Scheme == "" { + url.Scheme = defaultScheme + } + + // Default to "git" user when using ssh and no user is provided + if url.Scheme == "ssh" && url.User == nil { + url.User = urlpkg.User("git") } // Don't use host when scheme is file://. The fragment detected as url.Host should be a first directory of url.Path @@ -59,11 +67,6 @@ func ParseURL(rawURL string, defaultHost string) (url *urlpkg.URL, err error) { url.Host = "" } - // Default to https when scheme is empty - if url.Scheme == "" { - url.Scheme = "https" - } - return url, nil } diff --git a/pkg/url_test.go b/pkg/url_test.go index 9f12fd7..934c211 100644 --- a/pkg/url_test.go +++ b/pkg/url_test.go @@ -3,6 +3,8 @@ package pkg import ( "git-get/pkg/cfg" "testing" + + "github.com/stretchr/testify/assert" ) // Following URLs are considered valid according to https://git-scm.com/docs/git-clone#_git_urls: @@ -50,16 +52,11 @@ func TestURLParse(t *testing.T) { } for _, test := range tests { - url, err := ParseURL(test.in, cfg.Defaults[cfg.KeyDefaultHost]) - if err != nil { - t.Fatalf("got error: %+v", err) - } + url, err := ParseURL(test.in, cfg.Defaults[cfg.KeyDefaultHost], cfg.Defaults[cfg.KeyDefaultScheme]) + assert.NoError(t, err) got := URLToPath(*url, false) - - if got != test.want { - t.Errorf("wrong result for %q; expected %q; got %q", test.in, test.want, got) - } + assert.Equal(t, test.want, got) } } func TestURLParseSkipHost(t *testing.T) { @@ -95,16 +92,39 @@ func TestURLParseSkipHost(t *testing.T) { } for _, test := range tests { - url, err := ParseURL(test.in, cfg.Defaults[cfg.KeyDefaultHost]) - if err != nil { - t.Fatalf("got error: %+v", err) - } + url, err := ParseURL(test.in, cfg.Defaults[cfg.KeyDefaultHost], cfg.Defaults[cfg.KeyDefaultScheme]) + assert.NoError(t, err) got := URLToPath(*url, true) + assert.Equal(t, test.want, got) + } +} - if got != test.want { - t.Errorf("wrong result for %q; expected %q; got %q", test.in, test.want, got) - } +func TestDefaultScheme(t *testing.T) { + tests := []struct { + in string + scheme string + want string + }{ + {"grdl/git-get", "ssh", "ssh://git@github.com/grdl/git-get"}, + {"grdl/git-get", "https", "https://github.com/grdl/git-get"}, + {"https://github.com/grdl/git-get", "ssh", "https://github.com/grdl/git-get"}, + {"https://github.com/grdl/git-get", "https", "https://github.com/grdl/git-get"}, + {"ssh://github.com/grdl/git-get", "ssh", "ssh://git@github.com/grdl/git-get"}, + {"ssh://github.com/grdl/git-get", "https", "ssh://git@github.com/grdl/git-get"}, + {"git+ssh://github.com/grdl/git-get", "https", "ssh://git@github.com/grdl/git-get"}, + {"git@github.com:grdl/git-get", "ssh", "ssh://git@github.com/grdl/git-get"}, + {"git@github.com:grdl/git-get", "https", "ssh://git@github.com/grdl/git-get"}, + } + + for _, test := range tests { + url, err := ParseURL(test.in, cfg.Defaults[cfg.KeyDefaultHost], test.scheme) + assert.NoError(t, err) + + want, err := url.Parse(test.want) + assert.NoError(t, err) + + assert.Equal(t, url, want) } } @@ -119,9 +139,8 @@ func TestInvalidURLParse(t *testing.T) { } for _, test := range invalidURLs { - got, err := ParseURL(test, cfg.Defaults[cfg.KeyDefaultHost]) - if err == nil { - t.Errorf("expected error; got %q", got) - } + _, err := ParseURL(test, cfg.Defaults[cfg.KeyDefaultHost], cfg.Defaults[cfg.KeyDefaultScheme]) + + assert.Error(t, err) } }