diff --git a/worker/cmd_provider.go b/worker/cmd_provider.go index 36cc700..f6ee4aa 100644 --- a/worker/cmd_provider.go +++ b/worker/cmd_provider.go @@ -20,7 +20,8 @@ type cmdConfig struct { type cmdProvider struct { baseProvider cmdConfig - cmd []string + command []string + cmd *exec.Cmd session *sh.Session } @@ -43,7 +44,7 @@ func newCmdProvider(c cmdConfig) (*cmdProvider, error) { if err != nil { return nil, err } - provider.cmd = cmd + provider.command = cmd return provider, nil } @@ -70,17 +71,16 @@ func newEnviron(env map[string]string, inherit bool) []string { //map[string]str // TODO: implement this func (p *cmdProvider) Run() error { - var cmd *exec.Cmd - if len(p.cmd) == 1 { - cmd = exec.Command(p.cmd[0]) - } else if len(p.cmd) > 1 { - c := p.cmd[0] - args := p.cmd[1:] - cmd = exec.Command(c, args...) - } else if len(p.cmd) == 0 { + if len(p.command) == 1 { + p.cmd = exec.Command(p.command[0]) + } else if len(p.command) > 1 { + c := p.command[0] + args := p.command[1:] + p.cmd = exec.Command(c, args...) + } else if len(p.command) == 0 { panic("Command length should be at least 1!") } - cmd.Dir = p.WorkingDir() + p.cmd.Dir = p.WorkingDir() env := map[string]string{ "TUNASYNC_MIRROR_NAME": p.Name(), @@ -91,16 +91,20 @@ func (p *cmdProvider) Run() error { for k, v := range p.env { env[k] = v } - cmd.Env = newEnviron(env, true) + p.cmd.Env = newEnviron(env, true) - logFile, err := os.OpenFile(p.LogFile(), os.O_WRONLY, 0644) + logFile, err := os.OpenFile(p.LogFile(), os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return err } - cmd.Stdout = logFile - cmd.Stderr = logFile + p.cmd.Stdout = logFile + p.cmd.Stderr = logFile - return cmd.Start() + return p.cmd.Start() +} + +func (p *cmdProvider) Wait() error { + return p.cmd.Wait() } // TODO: implement this diff --git a/worker/provider_test.go b/worker/provider_test.go index ce09316..aa2ada2 100644 --- a/worker/provider_test.go +++ b/worker/provider_test.go @@ -1,6 +1,10 @@ package worker import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" "testing" . "github.com/smartystreets/goconvey/convey" @@ -58,3 +62,77 @@ func TestRsyncProvider(t *testing.T) { }) } + +func TestCmdProvider(t *testing.T) { + Convey("Command Provider should work", t, func() { + tmpDir, err := ioutil.TempDir("", "tunasync") + defer os.RemoveAll(tmpDir) + So(err, ShouldBeNil) + scriptFile := filepath.Join(tmpDir, "cmd.sh") + tmpFile := filepath.Join(tmpDir, "log_file") + + c := cmdConfig{ + name: "tuna-cmd", + upstreamURL: "http://mirrors.tuna.moe/", + command: "bash " + scriptFile, + workingDir: tmpDir, + logDir: tmpDir, + logFile: tmpFile, + interval: 600, + } + + provider, err := newCmdProvider(c) + So(err, ShouldBeNil) + + So(provider.Name(), ShouldEqual, c.name) + So(provider.WorkingDir(), ShouldEqual, c.workingDir) + So(provider.LogDir(), ShouldEqual, c.logDir) + So(provider.LogFile(), ShouldEqual, c.logFile) + So(provider.Interval(), ShouldEqual, c.interval) + + Convey("Let's try to run a simple command", func() { + scriptContent := `#!/bin/bash +echo $TUNASYNC_WORKING_DIR +echo $TUNASYNC_MIRROR_NAME +echo $TUNASYNC_UPSTREAM_URL +echo $TUNASYNC_LOG_FILE +` + exceptedOutput := fmt.Sprintf( + "%s\n%s\n%s\n%s\n", + provider.WorkingDir(), + provider.Name(), + provider.upstreamURL, + provider.LogFile(), + ) + err = ioutil.WriteFile(scriptFile, []byte(scriptContent), 0755) + So(err, ShouldBeNil) + readedScriptContent, err := ioutil.ReadFile(scriptFile) + So(err, ShouldBeNil) + So(readedScriptContent, ShouldResemble, []byte(scriptContent)) + + err = provider.Run() + So(err, ShouldBeNil) + err = provider.cmd.Wait() + So(err, ShouldBeNil) + + loggedContent, err := ioutil.ReadFile(provider.LogFile()) + So(err, ShouldBeNil) + So(string(loggedContent), ShouldEqual, exceptedOutput) + }) + + Convey("If a command fails", func() { + scriptContent := `exit 1` + err = ioutil.WriteFile(scriptFile, []byte(scriptContent), 0755) + So(err, ShouldBeNil) + readedScriptContent, err := ioutil.ReadFile(scriptFile) + So(err, ShouldBeNil) + So(readedScriptContent, ShouldResemble, []byte(scriptContent)) + + err = provider.Run() + So(err, ShouldBeNil) + err = provider.cmd.Wait() + So(err, ShouldNotBeNil) + + }) + }) +}