X Tutup
Skip to content

Commit c0fbb7e

Browse files
author
Alan Donovan
authored
Merge pull request cli#99 from github/runcommand
preparatory cleanups to ssh tunnel/port forwarding code
2 parents 4a45feb + 3aad0bb commit c0fbb7e

File tree

4 files changed

+138
-141
lines changed

4 files changed

+138
-141
lines changed

cmd/ghcs/logs.go

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package main
22

33
import (
4-
"bufio"
54
"context"
65
"fmt"
76
"os"
@@ -24,7 +23,7 @@ func newLogsCmd() *cobra.Command {
2423
if len(args) > 0 {
2524
codespaceName = args[0]
2625
}
27-
return logs(tail, codespaceName)
26+
return logs(context.Background(), tail, codespaceName)
2827
},
2928
}
3029

@@ -37,9 +36,12 @@ func init() {
3736
rootCmd.AddCommand(newLogsCmd())
3837
}
3938

40-
func logs(tail bool, codespaceName string) error {
39+
func logs(ctx context.Context, tail bool, codespaceName string) error {
40+
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
41+
ctx, cancel := context.WithCancel(ctx)
42+
defer cancel()
43+
4144
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
42-
ctx := context.Background()
4345
log := output.NewLogger(os.Stdout, os.Stderr, false)
4446

4547
user, err := apiClient.GetUser(ctx)
@@ -57,12 +59,17 @@ func logs(tail bool, codespaceName string) error {
5759
return fmt.Errorf("connecting to Live Share: %v", err)
5860
}
5961

62+
localSSHPort, err := codespaces.UnusedPort()
63+
if err != nil {
64+
return err
65+
}
66+
6067
remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log)
6168
if err != nil {
6269
return fmt.Errorf("error getting ssh server details: %v", err)
6370
}
6471

65-
tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort)
72+
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort)
6673
if err != nil {
6774
return fmt.Errorf("make ssh tunnel: %v", err)
6875
}
@@ -73,42 +80,30 @@ func logs(tail bool, codespaceName string) error {
7380
}
7481

7582
dst := fmt.Sprintf("%s@localhost", sshUser)
76-
stdout, err := codespaces.RunCommand(
77-
ctx, tunnelPort, dst, fmt.Sprintf("%v /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
83+
cmd := codespaces.NewRemoteCommand(
84+
ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
7885
)
79-
if err != nil {
80-
return fmt.Errorf("run command: %v", err)
81-
}
8286

83-
done := make(chan error)
84-
go func() {
85-
scanner := bufio.NewScanner(stdout)
86-
for scanner.Scan() {
87-
fmt.Println(scanner.Text())
88-
}
87+
// Error channels are buffered so that neither sending goroutine gets stuck.
8988

90-
if err := scanner.Err(); err != nil {
91-
done <- fmt.Errorf("error scanning: %v", err)
92-
return
93-
}
89+
tunnelClosed := make(chan error, 1)
90+
go func() {
91+
tunnelClosed <- tunnel.Start(ctx) // error is non-nil
92+
}()
9493

95-
if err := stdout.Close(); err != nil {
96-
done <- fmt.Errorf("close stdout: %v", err)
97-
return
98-
}
99-
done <- nil
94+
cmdDone := make(chan error, 1)
95+
go func() {
96+
cmdDone <- cmd.Run()
10097
}()
10198

10299
select {
103-
case err := <-connClosed:
104-
if err != nil {
105-
return fmt.Errorf("connection closed: %v", err)
106-
}
107-
case err := <-done:
100+
case err := <-tunnelClosed:
101+
return fmt.Errorf("connection closed: %v", err)
102+
103+
case err := <-cmdDone:
108104
if err != nil {
109-
return err
105+
return fmt.Errorf("error retrieving logs: %v", err)
110106
}
107+
return nil // success
111108
}
112-
113-
return nil
114109
}

cmd/ghcs/ssh.go

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func newSSHCmd() *cobra.Command {
2323
Short: "SSH into a Codespace",
2424
Args: cobra.NoArgs,
2525
RunE: func(cmd *cobra.Command, args []string) error {
26-
return ssh(sshProfile, codespaceName, sshServerPort)
26+
return ssh(context.Background(), sshProfile, codespaceName, sshServerPort)
2727
},
2828
}
2929

@@ -38,9 +38,12 @@ func init() {
3838
rootCmd.AddCommand(newSSHCmd())
3939
}
4040

41-
func ssh(sshProfile, codespaceName string, sshServerPort int) error {
41+
func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error {
42+
// Ensure all child tasks (e.g. port forwarding) terminate before return.
43+
ctx, cancel := context.WithCancel(ctx)
44+
defer cancel()
45+
4246
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
43-
ctx := context.Background()
4447
log := output.NewLogger(os.Stdout, os.Stderr, false)
4548

4649
user, err := apiClient.GetUser(ctx)
@@ -81,7 +84,16 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error {
8184
}
8285
log.Print("\n")
8386

84-
tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort, remoteSSHServerPort)
87+
usingCustomPort := true
88+
if localSSHServerPort == 0 {
89+
usingCustomPort = false // suppress log of command line in Shell
90+
localSSHServerPort, err = codespaces.UnusedPort()
91+
if err != nil {
92+
return err
93+
}
94+
}
95+
96+
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHServerPort, remoteSSHServerPort)
8597
if err != nil {
8698
return fmt.Errorf("make ssh tunnel: %v", err)
8799
}
@@ -91,22 +103,27 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error {
91103
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
92104
}
93105

94-
usingCustomPort := tunnelPort == sshServerPort
95-
connClosed := codespaces.ConnectToTunnel(ctx, log, tunnelPort, connectDestination, usingCustomPort)
106+
tunnelClosed := make(chan error)
107+
go func() {
108+
tunnelClosed <- tunnel.Start(ctx) // error is always non-nil
109+
}()
110+
111+
shellClosed := make(chan error)
112+
go func() {
113+
shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort)
114+
}()
96115

97116
log.Println("Ready...")
98117
select {
99118
case err := <-tunnelClosed:
119+
return fmt.Errorf("tunnel closed: %v", err)
120+
121+
case err := <-shellClosed:
100122
if err != nil {
101-
return fmt.Errorf("tunnel closed: %v", err)
102-
}
103-
case err := <-connClosed:
104-
if err != nil {
105-
return fmt.Errorf("connection closed: %v", err)
123+
return fmt.Errorf("shell closed: %v", err)
106124
}
125+
return nil // success
107126
}
108-
109-
return nil
110127
}
111128

112129
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {

internal/codespaces/ssh.go

Lines changed: 55 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,54 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"io"
8-
"math/rand"
7+
"net"
98
"os"
109
"os/exec"
1110
"strconv"
1211
"strings"
13-
"time"
1412

1513
"github.com/github/go-liveshare"
1614
)
1715

18-
func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort int, remoteSSHPort int) (int, <-chan error, error) {
19-
tunnelClosed := make(chan error)
16+
// UnusedPort returns the number of a local TCP port that is currently
17+
// unbound, or an error if none was available.
18+
//
19+
// Use of this function carries an inherent risk of a time-of-check to
20+
// time-of-use race against other processes.
21+
func UnusedPort() (int, error) {
22+
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
23+
if err != nil {
24+
return 0, fmt.Errorf("internal error while choosing port: %v", err)
25+
}
2026

21-
server, err := liveshare.NewServer(lsclient)
27+
l, err := net.ListenTCP("tcp", addr)
2228
if err != nil {
23-
return 0, nil, fmt.Errorf("new Live Share server: %v", err)
29+
return 0, fmt.Errorf("choosing available port: %v", err)
2430
}
31+
defer l.Close()
32+
return l.Addr().(*net.TCPAddr).Port, nil
33+
}
2534

26-
rand.Seed(time.Now().Unix())
27-
port := rand.Intn(9999-2000) + 2000 // improve this obviously
28-
if localSSHPort != 0 {
29-
port = localSSHPort
35+
// NewPortForwarder returns a new port forwarder for traffic between
36+
// the Live Share client and the specified local and remote ports.
37+
//
38+
// The session name is used (along with the port) to generate
39+
// names for streams, and may appear in error messages.
40+
func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) {
41+
if localSSHPort == 0 {
42+
return nil, fmt.Errorf("a local port must be provided")
3043
}
3144

32-
if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil {
33-
return 0, nil, fmt.Errorf("sharing sshd port: %v", err)
45+
server, err := liveshare.NewServer(client)
46+
if err != nil {
47+
return nil, fmt.Errorf("new liveshare server: %v", err)
3448
}
3549

36-
go func() {
37-
portForwarder := liveshare.NewPortForwarder(lsclient, server, port)
38-
if err := portForwarder.Start(ctx); err != nil {
39-
tunnelClosed <- fmt.Errorf("forwarding port: %v", err)
40-
return
41-
}
42-
tunnelClosed <- nil
43-
}()
50+
if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil {
51+
return nil, fmt.Errorf("sharing sshd port: %v", err)
52+
}
4453

45-
return port, tunnelClosed, nil
54+
return liveshare.NewPortForwarder(client, server, localSSHPort), nil
4655
}
4756

4857
// StartSSHServer installs (if necessary) and starts the SSH in the codespace.
@@ -72,72 +81,41 @@ func StartSSHServer(ctx context.Context, client *liveshare.Client, log logger) (
7281
return portInt, sshServerStartResult.User, nil
7382
}
7483

75-
func makeSSHArgs(port int, dst, cmd string) ([]string, []string) {
76-
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
77-
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
78-
79-
if cmd != "" {
80-
cmdArgs = append(cmdArgs, cmd)
81-
}
82-
83-
return cmdArgs, connArgs
84-
}
85-
86-
func ConnectToTunnel(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) <-chan error {
87-
connClosed := make(chan error)
88-
args, connArgs := makeSSHArgs(port, destination, "")
84+
// Shell runs an interactive secure shell over an existing
85+
// port-forwarding session. It runs until the shell is terminated
86+
// (including by cancellation of the context).
87+
func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error {
88+
cmd, connArgs := newSSHCommand(ctx, port, destination, "")
8989

9090
if usingCustomPort {
9191
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
9292
}
9393

94-
cmd := exec.CommandContext(ctx, "ssh", args...)
95-
cmd.Stdout = os.Stdout
96-
cmd.Stdin = os.Stdin
97-
cmd.Stderr = os.Stderr
98-
99-
go func() {
100-
connClosed <- cmd.Run()
101-
}()
102-
103-
return connClosed
104-
}
105-
106-
type command struct {
107-
Cmd *exec.Cmd
108-
StdoutPipe io.ReadCloser
94+
return cmd.Run()
10995
}
11096

111-
func newCommand(cmd *exec.Cmd) (*command, error) {
112-
stdoutPipe, err := cmd.StdoutPipe()
113-
if err != nil {
114-
return nil, fmt.Errorf("create stdout pipe: %v", err)
115-
}
116-
117-
if err := cmd.Start(); err != nil {
118-
return nil, fmt.Errorf("cmd start: %v", err)
119-
}
120-
121-
return &command{
122-
Cmd: cmd,
123-
StdoutPipe: stdoutPipe,
124-
}, nil
97+
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
98+
// command on the remote machine.
99+
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
100+
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
101+
return cmd
125102
}
126103

127-
func (c *command) Read(p []byte) (int, error) {
128-
return c.StdoutPipe.Read(p)
129-
}
104+
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
105+
// an interactive shell) over ssh.
106+
func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) {
107+
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
108+
// TODO(adonovan): eliminate X11 and X11Trust flags where unneeded.
109+
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
130110

131-
func (c *command) Close() error {
132-
if err := c.StdoutPipe.Close(); err != nil {
133-
return fmt.Errorf("close stdout: %v", err)
111+
if command != "" {
112+
cmdArgs = append(cmdArgs, command)
134113
}
135114

136-
return c.Cmd.Wait()
137-
}
115+
cmd := exec.CommandContext(ctx, "ssh", cmdArgs...)
116+
cmd.Stdout = os.Stdout
117+
cmd.Stdin = os.Stdin
118+
cmd.Stderr = os.Stderr
138119

139-
func RunCommand(ctx context.Context, tunnelPort int, destination, cmdString string) (io.ReadCloser, error) {
140-
args, _ := makeSSHArgs(tunnelPort, destination, cmdString)
141-
cmd := exec.CommandContext(ctx, "ssh", args...)
142-
return newCommand(cmd)
120+
return cmd, connArgs
143121
}

0 commit comments

Comments
 (0)
X Tutup