88 "net/http"
99 "net/http/httptest"
1010 "strings"
11+ "sync"
1112
1213 "github.com/gorilla/websocket"
1314 "github.com/sourcegraph/jsonrpc2"
@@ -42,17 +43,19 @@ IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E
4243Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U=
4344-----END RSA PRIVATE KEY-----`
4445
46+ // Server represents a LiveShare relay host server.
4547type Server struct {
46- password string
47- services map [string ]RPCHandleFunc
48- relaySAS string
49- streams map [string ]io.ReadWriter
50-
48+ password string
49+ services map [string ]RPCHandleFunc
50+ relaySAS string
51+ streams map [string ]io.ReadWriter
5152 sshConfig * ssh.ServerConfig
5253 httptestServer * httptest.Server
5354 errCh chan error
5455}
5556
57+ // NewServer creates a new Server. ServerOptions can be passed to configure
58+ // the SSH password, backing service, secrets and more.
5659func NewServer (opts ... ServerOption ) (* Server , error ) {
5760 server := new (Server )
5861
@@ -71,20 +74,23 @@ func NewServer(opts ...ServerOption) (*Server, error) {
7174 }
7275 server .sshConfig .AddHostKey (privateKey )
7376
74- server .errCh = make (chan error )
77+ server .errCh = make (chan error , 1 )
7578 server .httptestServer = httptest .NewTLSServer (http .HandlerFunc (makeConnection (server )))
7679 return server , nil
7780}
7881
82+ // ServerOption is used to configure the Server.
7983type ServerOption func (* Server ) error
8084
85+ // WithPassword configures the Server password for SSH.
8186func WithPassword (password string ) ServerOption {
8287 return func (s * Server ) error {
8388 s .password = password
8489 return nil
8590 }
8691}
8792
93+ // WithService accepts a mock RPC service for the Server to invoke.
8894func WithService (serviceName string , handler RPCHandleFunc ) ServerOption {
8995 return func (s * Server ) error {
9096 if s .services == nil {
@@ -96,13 +102,15 @@ func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
96102 }
97103}
98104
105+ // WithRelaySAS configures the relay SAS configuration key.
99106func WithRelaySAS (sas string ) ServerOption {
100107 return func (s * Server ) error {
101108 s .relaySAS = sas
102109 return nil
103110 }
104111}
105112
113+ // WithStream allows you to specify a mock data stream for the server.
106114func WithStream (name string , stream io.ReadWriter ) ServerOption {
107115 return func (s * Server ) error {
108116 if s .streams == nil {
@@ -122,10 +130,12 @@ func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (
122130 }
123131}
124132
133+ // Close closes the underlying httptest Server.
125134func (s * Server ) Close () {
126135 s .httptestServer .Close ()
127136}
128137
138+ // URL returns the httptest Server url.
129139func (s * Server ) URL () string {
130140 return s .httptestServer .URL
131141}
@@ -145,73 +155,160 @@ func makeConnection(server *Server) http.HandlerFunc {
145155 // validate the sas key
146156 sasParam := req .URL .Query ().Get ("sb-hc-token" )
147157 if sasParam != server .relaySAS {
148- server .errCh <- errors .New ("error validating sas" )
158+ sendError ( server .errCh , errors .New ("error validating sas" ) )
149159 return
150160 }
151161 }
152162 c , err := upgrader .Upgrade (w , req , nil )
153163 if err != nil {
154- server .errCh <- fmt .Errorf ("error upgrading connection: %w" , err )
164+ sendError ( server .errCh , fmt .Errorf ("error upgrading connection: %w" , err ) )
155165 return
156166 }
157- defer c .Close ()
167+ defer func () {
168+ if err := c .Close (); err != nil {
169+ sendError (server .errCh , err )
170+ }
171+ }()
158172
159173 socketConn := newSocketConn (c )
160174 _ , chans , reqs , err := ssh .NewServerConn (socketConn , server .sshConfig )
161175 if err != nil {
162- server .errCh <- fmt .Errorf ("error creating new ssh conn: %w" , err )
176+ sendError ( server .errCh , fmt .Errorf ("error creating new ssh conn: %w" , err ) )
163177 return
164178 }
165179 go ssh .DiscardRequests (reqs )
166180
167- for newChannel := range chans {
168- ch , reqs , err := newChannel .Accept ()
181+ if err := handleChannels (ctx , server , chans ); err != nil {
182+ sendError (server .errCh , err )
183+ }
184+ }
185+ }
186+
187+ // sendError does a non-blocking send of the error to the err channel.
188+ func sendError (errc chan <- error , err error ) {
189+ select {
190+ case errc <- err :
191+ default :
192+ // channel is blocked with a previous error, so we ignore
193+ // this current error
194+ }
195+ }
196+
197+ // awaitError waits for the context to finish and returns its error (if any).
198+ // It also waits for an err to come through the err channel.
199+ func awaitError (ctx context.Context , errc <- chan error ) error {
200+ select {
201+ case <- ctx .Done ():
202+ return ctx .Err ()
203+ case err := <- errc :
204+ return err
205+ }
206+ }
207+
208+ // handleChannels services the sshChannels channel. For each SSH channel received
209+ // it creates a go routine to service the channel's requests. It returns on the first
210+ // error encountered.
211+ func handleChannels (ctx context.Context , server * Server , sshChannels <- chan ssh.NewChannel ) error {
212+ errc := make (chan error , 1 )
213+ go func () {
214+ for sshCh := range sshChannels {
215+ ch , reqs , err := sshCh .Accept ()
169216 if err != nil {
170- server . errCh <- fmt .Errorf ("error accepting new channel: %w" , err )
217+ sendError ( errc , fmt .Errorf ("failed to accept channel: %w" , err ) )
171218 return
172219 }
173- go handleNewRequests (ctx , server , ch , reqs )
174- go handleNewChannel (server , ch )
220+
221+ go func () {
222+ if err := handleRequests (ctx , server , ch , reqs ); err != nil {
223+ sendError (errc , fmt .Errorf ("failed to handle requests: %w" , err ))
224+ }
225+ }()
226+
227+ handleChannel (server , ch )
175228 }
176- }
229+ }()
230+ return awaitError (ctx , errc )
177231}
178232
179- func handleNewRequests (ctx context.Context , server * Server , channel ssh.Channel , reqs <- chan * ssh.Request ) {
180- for req := range reqs {
181- if req .WantReply {
182- if err := req .Reply (true , nil ); err != nil {
183- server .errCh <- fmt .Errorf ("error replying to channel request: %w" , err )
233+ // handleRequests services the SSH channel requests channel. It replies to requests and
234+ // when stream transport requests are encountered, creates a go routine to create a
235+ // bi-directional data stream between the channel and server stream. It returns on the first error
236+ // encountered.
237+ func handleRequests (ctx context.Context , server * Server , channel ssh.Channel , reqs <- chan * ssh.Request ) error {
238+ errc := make (chan error , 1 )
239+ go func () {
240+ for req := range reqs {
241+ if req .WantReply {
242+ if err := req .Reply (true , nil ); err != nil {
243+ sendError (errc , fmt .Errorf ("error replying to channel request: %w" , err ))
244+ return
245+ }
246+ }
247+
248+ if strings .HasPrefix (req .Type , "stream-transport" ) {
249+ go func () {
250+ if err := forwardStream (ctx , server , req .Type , channel ); err != nil {
251+ sendError (errc , fmt .Errorf ("failed to forward stream: %w" , err ))
252+ }
253+ }()
184254 }
185255 }
186- if strings .HasPrefix (req .Type , "stream-transport" ) {
187- forwardStream (ctx , server , req .Type , channel )
188- }
189- }
256+ }()
257+
258+ return awaitError (ctx , errc )
259+ }
260+
261+ // concurrentStream is a concurrency safe io.ReadWriter.
262+ type concurrentStream struct {
263+ sync.RWMutex
264+ stream io.ReadWriter
265+ }
266+
267+ func newConcurrentStream (rw io.ReadWriter ) * concurrentStream {
268+ return & concurrentStream {stream : rw }
190269}
191270
192- func forwardStream (ctx context.Context , server * Server , streamName string , channel ssh.Channel ) {
271+ func (cs * concurrentStream ) Read (b []byte ) (int , error ) {
272+ cs .RLock ()
273+ defer cs .RUnlock ()
274+ return cs .stream .Read (b )
275+ }
276+
277+ func (cs * concurrentStream ) Write (b []byte ) (int , error ) {
278+ cs .Lock ()
279+ defer cs .Unlock ()
280+ return cs .stream .Write (b )
281+ }
282+
283+ // forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
284+ // runs until an error is encountered.
285+ func forwardStream (ctx context.Context , server * Server , streamName string , channel ssh.Channel ) (err error ) {
193286 simpleStreamName := strings .TrimPrefix (streamName , "stream-transport-" )
194287 stream , found := server .streams [simpleStreamName ]
195288 if ! found {
196- server .errCh <- fmt .Errorf ("stream '%s' not found" , simpleStreamName )
197- return
289+ return fmt .Errorf ("stream '%s' not found" , simpleStreamName )
198290 }
291+ defer func () {
292+ if closeErr := channel .Close (); err == nil && closeErr != io .EOF {
293+ err = closeErr
294+ }
295+ }()
199296
297+ errc := make (chan error , 2 )
200298 copy := func (dst io.Writer , src io.Reader ) {
201299 if _ , err := io .Copy (dst , src ); err != nil {
202- fmt .Println (err )
203- server .errCh <- fmt .Errorf ("io copy: %w" , err )
204- return
300+ errc <- err
205301 }
206302 }
207303
208- go copy (stream , channel )
209- go copy (channel , stream )
304+ csStream := newConcurrentStream (stream )
305+ go copy (csStream , channel )
306+ go copy (channel , csStream )
210307
211- <- ctx . Done () // TODO(josebalius): improve this
308+ return awaitError ( ctx , errc )
212309}
213310
214- func handleNewChannel (server * Server , channel ssh.Channel ) {
311+ func handleChannel (server * Server , channel ssh.Channel ) {
215312 stream := jsonrpc2 .NewBufferedStream (channel , jsonrpc2.VSCodeObjectCodec {})
216313 jsonrpc2 .NewConn (context .Background (), stream , newRPCHandler (server ))
217314}
@@ -226,20 +323,22 @@ func newRPCHandler(server *Server) *rpcHandler {
226323 return & rpcHandler {server }
227324}
228325
326+ // Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
327+ // RPC service method and if found, it invokes the handler and replies to the request.
229328func (r * rpcHandler ) Handle (ctx context.Context , conn * jsonrpc2.Conn , req * jsonrpc2.Request ) {
230329 handler , found := r .server .services [req .Method ]
231330 if ! found {
232- r .server .errCh <- fmt .Errorf ("RPC Method: '%s' not serviced" , req .Method )
331+ sendError ( r .server .errCh , fmt .Errorf ("RPC Method: '%s' not serviced" , req .Method ) )
233332 return
234333 }
235334
236335 result , err := handler (req )
237336 if err != nil {
238- r .server .errCh <- fmt .Errorf ("error handling: '%s': %w" , req .Method , err )
337+ sendError ( r .server .errCh , fmt .Errorf ("error handling: '%s': %w" , req .Method , err ) )
239338 return
240339 }
241340
242341 if err := conn .Reply (ctx , req .ID , result ); err != nil {
243- r .server .errCh <- fmt .Errorf ("error replying: %w" , err )
342+ sendError ( r .server .errCh , fmt .Errorf ("error replying: %w" , err ) )
244343 }
245344}
0 commit comments