X Tutup
Skip to content

Commit a0f11b6

Browse files
josebaliusmislav
authored andcommitted
Handle concurrency in tests and logger
- Live Share tests - Logger implementation for ghcs
1 parent 0e98b30 commit a0f11b6

File tree

2 files changed

+152
-38
lines changed

2 files changed

+152
-38
lines changed

cmd/ghcs/output/logger.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package output
33
import (
44
"fmt"
55
"io"
6+
"sync"
67
)
78

89
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
@@ -19,6 +20,7 @@ func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
1920
// If not enabled, Print functions will noop but Error functions will continue
2021
// to write to the stderr writer.
2122
type Logger struct {
23+
mu sync.Mutex // guards the writers
2224
out io.Writer
2325
errout io.Writer
2426
enabled bool
@@ -29,6 +31,9 @@ func (l *Logger) Print(v ...interface{}) (int, error) {
2931
if !l.enabled {
3032
return 0, nil
3133
}
34+
35+
l.mu.Lock()
36+
defer l.mu.Unlock()
3237
return fmt.Fprint(l.out, v...)
3338
}
3439

@@ -37,6 +42,9 @@ func (l *Logger) Println(v ...interface{}) (int, error) {
3742
if !l.enabled {
3843
return 0, nil
3944
}
45+
46+
l.mu.Lock()
47+
defer l.mu.Unlock()
4048
return fmt.Fprintln(l.out, v...)
4149
}
4250

@@ -45,15 +53,22 @@ func (l *Logger) Printf(f string, v ...interface{}) (int, error) {
4553
if !l.enabled {
4654
return 0, nil
4755
}
56+
57+
l.mu.Lock()
58+
defer l.mu.Unlock()
4859
return fmt.Fprintf(l.out, f, v...)
4960
}
5061

5162
// Errorf writes the formatted arguments to the stderr writer.
5263
func (l *Logger) Errorf(f string, v ...interface{}) (int, error) {
64+
l.mu.Lock()
65+
defer l.mu.Unlock()
5366
return fmt.Fprintf(l.errout, f, v...)
5467
}
5568

5669
// Errorln writes the arguments to the stderr writer with a newline at the end.
5770
func (l *Logger) Errorln(v ...interface{}) (int, error) {
71+
l.mu.Lock()
72+
defer l.mu.Unlock()
5873
return fmt.Fprintln(l.errout, v...)
5974
}

internal/liveshare/test/server.go

Lines changed: 137 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/http/httptest"
1010
"strings"
11+
"sync"
1112

1213
"github.com/gorilla/websocket"
1314
"github.com/sourcegraph/jsonrpc2"
@@ -42,17 +43,19 @@ IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E
4243
Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U=
4344
-----END RSA PRIVATE KEY-----`
4445

46+
// Server represents a LiveShare relay host server.
4547
type Server struct {
46-
password string
47-
services map[string]RPCHandleFunc
48-
relaySAS string
49-
streams map[string]io.ReadWriter
50-
48+
password string
49+
services map[string]RPCHandleFunc
50+
relaySAS string
51+
streams map[string]io.ReadWriter
5152
sshConfig *ssh.ServerConfig
5253
httptestServer *httptest.Server
5354
errCh chan error
5455
}
5556

57+
// NewServer creates a new Server. ServerOptions can be passed to configure
58+
// the SSH password, backing service, secrets and more.
5659
func NewServer(opts ...ServerOption) (*Server, error) {
5760
server := new(Server)
5861

@@ -71,20 +74,23 @@ func NewServer(opts ...ServerOption) (*Server, error) {
7174
}
7275
server.sshConfig.AddHostKey(privateKey)
7376

74-
server.errCh = make(chan error)
77+
server.errCh = make(chan error, 1)
7578
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
7679
return server, nil
7780
}
7881

82+
// ServerOption is used to configure the Server.
7983
type ServerOption func(*Server) error
8084

85+
// WithPassword configures the Server password for SSH.
8186
func WithPassword(password string) ServerOption {
8287
return func(s *Server) error {
8388
s.password = password
8489
return nil
8590
}
8691
}
8792

93+
// WithService accepts a mock RPC service for the Server to invoke.
8894
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
8995
return func(s *Server) error {
9096
if s.services == nil {
@@ -96,13 +102,15 @@ func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
96102
}
97103
}
98104

105+
// WithRelaySAS configures the relay SAS configuration key.
99106
func WithRelaySAS(sas string) ServerOption {
100107
return func(s *Server) error {
101108
s.relaySAS = sas
102109
return nil
103110
}
104111
}
105112

113+
// WithStream allows you to specify a mock data stream for the server.
106114
func WithStream(name string, stream io.ReadWriter) ServerOption {
107115
return func(s *Server) error {
108116
if s.streams == nil {
@@ -122,10 +130,12 @@ func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (
122130
}
123131
}
124132

133+
// Close closes the underlying httptest Server.
125134
func (s *Server) Close() {
126135
s.httptestServer.Close()
127136
}
128137

138+
// URL returns the httptest Server url.
129139
func (s *Server) URL() string {
130140
return s.httptestServer.URL
131141
}
@@ -145,73 +155,160 @@ func makeConnection(server *Server) http.HandlerFunc {
145155
// validate the sas key
146156
sasParam := req.URL.Query().Get("sb-hc-token")
147157
if sasParam != server.relaySAS {
148-
server.errCh <- errors.New("error validating sas")
158+
sendError(server.errCh, errors.New("error validating sas"))
149159
return
150160
}
151161
}
152162
c, err := upgrader.Upgrade(w, req, nil)
153163
if err != nil {
154-
server.errCh <- fmt.Errorf("error upgrading connection: %w", err)
164+
sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err))
155165
return
156166
}
157-
defer c.Close()
167+
defer func() {
168+
if err := c.Close(); err != nil {
169+
sendError(server.errCh, err)
170+
}
171+
}()
158172

159173
socketConn := newSocketConn(c)
160174
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
161175
if err != nil {
162-
server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err)
176+
sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err))
163177
return
164178
}
165179
go ssh.DiscardRequests(reqs)
166180

167-
for newChannel := range chans {
168-
ch, reqs, err := newChannel.Accept()
181+
if err := handleChannels(ctx, server, chans); err != nil {
182+
sendError(server.errCh, err)
183+
}
184+
}
185+
}
186+
187+
// sendError does a non-blocking send of the error to the err channel.
188+
func sendError(errc chan<- error, err error) {
189+
select {
190+
case errc <- err:
191+
default:
192+
// channel is blocked with a previous error, so we ignore
193+
// this current error
194+
}
195+
}
196+
197+
// awaitError waits for the context to finish and returns its error (if any).
198+
// It also waits for an err to come through the err channel.
199+
func awaitError(ctx context.Context, errc <-chan error) error {
200+
select {
201+
case <-ctx.Done():
202+
return ctx.Err()
203+
case err := <-errc:
204+
return err
205+
}
206+
}
207+
208+
// handleChannels services the sshChannels channel. For each SSH channel received
209+
// it creates a go routine to service the channel's requests. It returns on the first
210+
// error encountered.
211+
func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error {
212+
errc := make(chan error, 1)
213+
go func() {
214+
for sshCh := range sshChannels {
215+
ch, reqs, err := sshCh.Accept()
169216
if err != nil {
170-
server.errCh <- fmt.Errorf("error accepting new channel: %w", err)
217+
sendError(errc, fmt.Errorf("failed to accept channel: %w", err))
171218
return
172219
}
173-
go handleNewRequests(ctx, server, ch, reqs)
174-
go handleNewChannel(server, ch)
220+
221+
go func() {
222+
if err := handleRequests(ctx, server, ch, reqs); err != nil {
223+
sendError(errc, fmt.Errorf("failed to handle requests: %w", err))
224+
}
225+
}()
226+
227+
handleChannel(server, ch)
175228
}
176-
}
229+
}()
230+
return awaitError(ctx, errc)
177231
}
178232

179-
func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
180-
for req := range reqs {
181-
if req.WantReply {
182-
if err := req.Reply(true, nil); err != nil {
183-
server.errCh <- fmt.Errorf("error replying to channel request: %w", err)
233+
// handleRequests services the SSH channel requests channel. It replies to requests and
234+
// when stream transport requests are encountered, creates a go routine to create a
235+
// bi-directional data stream between the channel and server stream. It returns on the first error
236+
// encountered.
237+
func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error {
238+
errc := make(chan error, 1)
239+
go func() {
240+
for req := range reqs {
241+
if req.WantReply {
242+
if err := req.Reply(true, nil); err != nil {
243+
sendError(errc, fmt.Errorf("error replying to channel request: %w", err))
244+
return
245+
}
246+
}
247+
248+
if strings.HasPrefix(req.Type, "stream-transport") {
249+
go func() {
250+
if err := forwardStream(ctx, server, req.Type, channel); err != nil {
251+
sendError(errc, fmt.Errorf("failed to forward stream: %w", err))
252+
}
253+
}()
184254
}
185255
}
186-
if strings.HasPrefix(req.Type, "stream-transport") {
187-
forwardStream(ctx, server, req.Type, channel)
188-
}
189-
}
256+
}()
257+
258+
return awaitError(ctx, errc)
259+
}
260+
261+
// concurrentStream is a concurrency safe io.ReadWriter.
262+
type concurrentStream struct {
263+
sync.RWMutex
264+
stream io.ReadWriter
265+
}
266+
267+
func newConcurrentStream(rw io.ReadWriter) *concurrentStream {
268+
return &concurrentStream{stream: rw}
190269
}
191270

192-
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) {
271+
func (cs *concurrentStream) Read(b []byte) (int, error) {
272+
cs.RLock()
273+
defer cs.RUnlock()
274+
return cs.stream.Read(b)
275+
}
276+
277+
func (cs *concurrentStream) Write(b []byte) (int, error) {
278+
cs.Lock()
279+
defer cs.Unlock()
280+
return cs.stream.Write(b)
281+
}
282+
283+
// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
284+
// runs until an error is encountered.
285+
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) {
193286
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
194287
stream, found := server.streams[simpleStreamName]
195288
if !found {
196-
server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName)
197-
return
289+
return fmt.Errorf("stream '%s' not found", simpleStreamName)
198290
}
291+
defer func() {
292+
if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
293+
err = closeErr
294+
}
295+
}()
199296

297+
errc := make(chan error, 2)
200298
copy := func(dst io.Writer, src io.Reader) {
201299
if _, err := io.Copy(dst, src); err != nil {
202-
fmt.Println(err)
203-
server.errCh <- fmt.Errorf("io copy: %w", err)
204-
return
300+
errc <- err
205301
}
206302
}
207303

208-
go copy(stream, channel)
209-
go copy(channel, stream)
304+
csStream := newConcurrentStream(stream)
305+
go copy(csStream, channel)
306+
go copy(channel, csStream)
210307

211-
<-ctx.Done() // TODO(josebalius): improve this
308+
return awaitError(ctx, errc)
212309
}
213310

214-
func handleNewChannel(server *Server, channel ssh.Channel) {
311+
func handleChannel(server *Server, channel ssh.Channel) {
215312
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
216313
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
217314
}
@@ -226,20 +323,22 @@ func newRPCHandler(server *Server) *rpcHandler {
226323
return &rpcHandler{server}
227324
}
228325

326+
// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
327+
// RPC service method and if found, it invokes the handler and replies to the request.
229328
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
230329
handler, found := r.server.services[req.Method]
231330
if !found {
232-
r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method)
331+
sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method))
233332
return
234333
}
235334

236335
result, err := handler(req)
237336
if err != nil {
238-
r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err)
337+
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
239338
return
240339
}
241340

242341
if err := conn.Reply(ctx, req.ID, result); err != nil {
243-
r.server.errCh <- fmt.Errorf("error replying: %w", err)
342+
sendError(r.server.errCh, fmt.Errorf("error replying: %w", err))
244343
}
245344
}

0 commit comments

Comments
 (0)
X Tutup