diff --git a/cobra/cmd/init.go b/cobra/cmd/init.go index edddc2c..504a478 100644 --- a/cobra/cmd/init.go +++ b/cobra/cmd/init.go @@ -32,42 +32,17 @@ var ( Long: `Initialize (cobra init) will create a new application, with a license and the appropriate structure for a Cobra-based CLI application. - * If a name is provided, it will be created in the current directory; + * If a name is provided, a directory with that name will be created in the current directory; * If no name is provided, the current directory will be assumed; - * If a relative path is provided, it will be created inside $GOPATH - (e.g. github.com/spf13/hugo); - * If an absolute path is provided, it will be created; - * If the directory already exists but is empty, it will be used. +`, -Init will not use an existing directory with contents.`, + Run: func(_ *cobra.Command, args []string) { - Run: func(cmd *cobra.Command, args []string) { - - wd, err := os.Getwd() + projectPath, err := initializeProject(args) if err != nil { er(err) } - - if len(args) > 0 { - if args[0] != "." { - wd = fmt.Sprintf("%s/%s", wd, args[0]) - } - } - - project := &Project{ - AbsolutePath: wd, - PkgName: pkgName, - Legal: getLicense(), - Copyright: copyrightLine(), - Viper: viper.GetBool("useViper"), - AppName: path.Base(pkgName), - } - - if err := project.Create(); err != nil { - er(err) - } - - fmt.Printf("Your Cobra application is ready at\n%s\n", project.AbsolutePath) + fmt.Printf("Your Cobra application is ready at\n%s\n", projectPath) }, } ) @@ -76,3 +51,31 @@ func init() { initCmd.Flags().StringVar(&pkgName, "pkg-name", "", "fully qualified pkg name") initCmd.MarkFlagRequired("pkg-name") } + +func initializeProject(args []string) (string, error) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + + if len(args) > 0 { + if args[0] != "." { + wd = fmt.Sprintf("%s/%s", wd, args[0]) + } + } + + project := &Project{ + AbsolutePath: wd, + PkgName: pkgName, + Legal: getLicense(), + Copyright: copyrightLine(), + Viper: viper.GetBool("useViper"), + AppName: path.Base(pkgName), + } + + if err := project.Create(); err != nil { + return "", err + } + + return project.AbsolutePath, nil +} diff --git a/cobra/cmd/init_test.go b/cobra/cmd/init_test.go index 8ee3910..c4b3f09 100644 --- a/cobra/cmd/init_test.go +++ b/cobra/cmd/init_test.go @@ -2,9 +2,12 @@ package cmd import ( "fmt" + "io/ioutil" "os" "path/filepath" "testing" + + "github.com/spf13/viper" ) func getProject() *Project { @@ -20,20 +23,72 @@ func getProject() *Project { } func TestGoldenInitCmd(t *testing.T) { - project := getProject() - defer os.RemoveAll(project.AbsolutePath) - if err := project.Create(); err != nil { + dir, err := ioutil.TempDir("", "cobra-init") + if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) - expectedFiles := []string{"LICENSE", "main.go", "cmd/root.go"} - for _, f := range expectedFiles { - generatedFile := fmt.Sprintf("%s/%s", project.AbsolutePath, f) - goldenFile := fmt.Sprintf("testdata/%s.golden", filepath.Base(f)) - err := compareFiles(generatedFile, goldenFile) - if err != nil { - t.Fatal(err) - } + tests := []struct { + name string + args []string + pkgName string + expectErr bool + }{ + { + name: "successfully creates a project with name", + args: []string{"testproject"}, + pkgName: "github.com/spf13/testproject", + expectErr: false, + }, + { + name: "returns error when passing an absolute path for project", + args: []string{dir}, + pkgName: "github.com/spf13/testproject", + expectErr: true, + }, + { + name: "returns error when passing an relative path for project", + args: []string{"github.com/spf13/testproject"}, + pkgName: "github.com/spf13/testproject", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + initCmd.Flags().Set("pkg-name", tt.pkgName) + viper.Set("useViper", true) + projectPath, err := initializeProject(tt.args) + defer func() { + if projectPath != "" { + os.RemoveAll(projectPath) + } + }() + + if !tt.expectErr && err != nil { + t.Fatalf("did not expect an error, got %s", err) + } + if tt.expectErr { + if err == nil { + t.Fatal("expected an error but got none") + } else { + // got an expected error nothing more to do + return + } + } + + expectedFiles := []string{"LICENSE", "main.go", "cmd/root.go"} + for _, f := range expectedFiles { + generatedFile := fmt.Sprintf("%s/%s", projectPath, f) + goldenFile := fmt.Sprintf("testdata/%s.golden", filepath.Base(f)) + err := compareFiles(generatedFile, goldenFile) + if err != nil { + t.Fatal(err) + } + } + }) } }