package liveshare
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
// mockClientName is the ClientName that makeMockSession passes to Connect and
// the value the notifyCodespaceOfClientActivity mock handlers expect to receive
// as their first request parameter.
const mockClientName = "liveshare-client"
// makeMockSession starts a Live Share test server and connects a Session to it.
//
// The caller-supplied opts are passed to the server first; a session password
// and a stub "workspace.joinWorkspace" handler (returning joinWorkspaceResult{1})
// are appended after them.
//
// On success the caller owns the returned server and must Close it. On failure
// both returned pointers are nil; a server that was started before Connect
// failed is closed here so it does not leak.
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
	const sessionToken = "session-token"

	joinWorkspace := func(_ *jsonrpc2.Request) (interface{}, error) {
		return joinWorkspaceResult{1}, nil
	}

	opts = append(
		opts,
		livesharetest.WithPassword(sessionToken),
		livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
	)

	testServer, err := livesharetest.NewServer(opts...)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating server: %w", err)
	}

	session, err := Connect(context.Background(), Options{
		ClientName:   mockClientName,
		SessionID:    "session-id",
		SessionToken: sessionToken,
		// The test server exposes an https URL; swap the scheme for "sb".
		RelayEndpoint:  "sb" + strings.TrimPrefix(testServer.URL(), "https"),
		RelaySAS:       "relay-sas",
		HostPublicKeys: []string{livesharetest.SSHPublicKey},
		TLSConfig:      &tls.Config{InsecureSkipVerify: true},
		Logger:         newMockLogger(),
	})
	if err != nil {
		// The caller only receives nil pointers on error, so nobody else
		// could release the server.
		testServer.Close()
		return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
	}

	return testServer, session, nil
}
// TestServerStartSharing checks that Session.startSharing sends the port,
// protocol and browse URL that the mock "serverSharing.startSharing" service
// expects, and that the returned stream name and condition are non-empty.
func TestServerStartSharing(t *testing.T) {
	serverPort, serverProtocol := 2222, "sshd"

	// startSharing validates the three positional JSON-RPC arguments.
	startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
		var args []interface{}
		if err := json.Unmarshal(*req.Params, &args); err != nil {
			return nil, fmt.Errorf("error unmarshaling request: %w", err)
		}
		if len(args) < 3 {
			return nil, errors.New("not enough arguments to start sharing")
		}
		// JSON numbers decode into float64 when the target is interface{}.
		if port, ok := args[0].(float64); !ok {
			return nil, errors.New("port argument is not an int")
		} else if port != float64(serverPort) {
			return nil, errors.New("port does not match serverPort")
		}
		if protocol, ok := args[1].(string); !ok {
			return nil, errors.New("protocol argument is not a string")
		} else if protocol != serverProtocol {
			return nil, errors.New("protocol does not match serverProtocol")
		}
		if browseURL, ok := args[2].(string); !ok {
			return nil, errors.New("browse url is not a string")
		} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
			return nil, errors.New("browseURL does not match expected")
		}
		return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
	}

	testServer, session, err := makeMockSession(
		livesharetest.WithService("serverSharing.startSharing", startSharing),
	)
	if err != nil {
		// Nothing below can run without a server and a session.
		t.Fatalf("error creating mock session: %v", err)
	}
	defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()

	ctx := context.Background()

	// share runs the client call and reports the first failure, if any.
	share := func() error {
		streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
		if err != nil {
			return fmt.Errorf("error sharing server: %w", err)
		}
		if streamID.name == "" || streamID.condition == "" {
			return errors.New("stream name or condition is blank")
		}
		return nil
	}

	// Buffered so the goroutine can finish even if the server error wins the
	// select below.
	done := make(chan error, 1)
	go func() { done <- share() }()

	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}
// TestServerGetSharedServers checks that Session.GetSharedServers returns the
// port advertised by the mock "serverSharing.getSharedServers" service with its
// source port, stream name and stream condition intact.
func TestServerGetSharedServers(t *testing.T) {
	sharedServer := Port{
		SourcePort:      2222,
		StreamName:      "stream-name",
		StreamCondition: "stream-condition",
	}
	getSharedServers := func(_ *jsonrpc2.Request) (interface{}, error) {
		return []*Port{&sharedServer}, nil
	}

	testServer, session, err := makeMockSession(
		livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
	)
	if err != nil {
		// Nothing below can run without a server and a session.
		t.Fatalf("error creating mock session: %v", err)
	}
	defer testServer.Close()

	ctx := context.Background()

	// fetch runs the client call and reports the first mismatch, if any.
	fetch := func() error {
		ports, err := session.GetSharedServers(ctx)
		if err != nil {
			return fmt.Errorf("error getting shared servers: %w", err)
		}
		if len(ports) < 1 {
			return errors.New("not enough ports returned")
		}
		got := ports[0]
		switch {
		case got.SourcePort != sharedServer.SourcePort:
			return errors.New("source port does not match")
		case got.StreamName != sharedServer.StreamName:
			return errors.New("stream name does not match")
		case got.StreamCondition != sharedServer.StreamCondition:
			return errors.New("stream condition does not match")
		}
		return nil
	}

	// Buffered so the goroutine can finish even if the server error wins the
	// select below.
	done := make(chan error, 1)
	go func() { done <- fetch() }()

	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}
// TestServerUpdateSharedServerPrivacy checks that
// Session.UpdateSharedServerPrivacy(ctx, 80, "public") sends port 80 and the
// privacy string "public" to the mock
// "serverSharing.updateSharedServerPrivacy" service.
func TestServerUpdateSharedServerPrivacy(t *testing.T) {
	// updateSharedVisibility validates the two positional JSON-RPC arguments.
	updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
		var req []interface{}
		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
			return nil, fmt.Errorf("unmarshal req: %w", err)
		}
		if len(req) < 2 {
			return nil, errors.New("request arguments is less than 2")
		}

		// JSON numbers decode into float64 when the target is interface{}.
		port, ok := req[0].(float64)
		if !ok {
			return nil, errors.New("port param is not a float64")
		}
		if port != 80.0 {
			return nil, errors.New("port param is not expected value")
		}

		privacy, ok := req[1].(string)
		if !ok {
			return nil, fmt.Errorf("expected privacy param to be a string but got %T", req[1])
		}
		if privacy != "public" {
			return nil, fmt.Errorf("expected privacy param to be public but got %q", privacy)
		}
		return nil, nil
	}

	testServer, session, err := makeMockSession(
		livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
	)
	if err != nil {
		// Nothing below can run without a server and a session.
		t.Fatalf("creating mock session: %v", err)
	}
	defer testServer.Close()

	ctx := context.Background()

	// Buffered so the goroutine can finish even if the server error wins the
	// select below.
	done := make(chan error, 1)
	go func() {
		done <- session.UpdateSharedServerPrivacy(ctx, 80, "public")
	}()

	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}
// TestInvalidHostKey checks that Connect returns an error when Options carries
// an empty HostPublicKeys list.
func TestInvalidHostKey(t *testing.T) {
	const sessionToken = "session-token"

	joinWorkspace := func(_ *jsonrpc2.Request) (interface{}, error) {
		return joinWorkspaceResult{1}, nil
	}
	opts := []livesharetest.ServerOption{
		livesharetest.WithPassword(sessionToken),
		livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
	}

	testServer, err := livesharetest.NewServer(opts...)
	if err != nil {
		// testServer.URL() is needed below, so the test cannot continue.
		t.Fatalf("error creating server: %v", err)
	}
	defer testServer.Close()

	_, err = Connect(context.Background(), Options{
		SessionID:      "session-id",
		SessionToken:   sessionToken,
		RelayEndpoint:  "sb" + strings.TrimPrefix(testServer.URL(), "https"),
		RelaySAS:       "relay-sas",
		HostPublicKeys: []string{}, // deliberately empty: no key can match
		TLSConfig:      &tls.Config{InsecureSkipVerify: true},
	})
	if err == nil {
		t.Error("expected invalid host key error, got: nil")
	}
}
// TestKeepAliveNonBlocking calls keepAlive twice on a Session whose
// keepAliveReason channel has capacity 1 and no receiver. The test has no
// assertions: it passes simply by returning.
func TestKeepAliveNonBlocking(t *testing.T) {
	s := &Session{keepAliveReason: make(chan string, 1)}

	// The first call can fill the single buffer slot; the second call finds
	// the channel full.
	s.keepAlive("io")
	s.keepAlive("io")

	// If keepAlive blocked on the full channel, execution would never get
	// here and the test would fail by timing out.
}
// TestNotifyHostOfActivity checks that Session.notifyHostOfActivity(ctx,
// "input") sends the client name and the activity list [input] to the mock
// "ICodespaceHostService.notifyCodespaceOfClientActivity" service.
func TestNotifyHostOfActivity(t *testing.T) {
	// notifyHostOfActivity validates the two positional JSON-RPC arguments.
	notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
		var req []interface{}
		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
			return nil, fmt.Errorf("unmarshal req: %w", err)
		}
		if len(req) < 2 {
			return nil, errors.New("request arguments is less than 2")
		}

		clientName, ok := req[0].(string)
		if !ok {
			return nil, errors.New("clientName param is not a string")
		}
		if clientName != mockClientName {
			return nil, fmt.Errorf(
				"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
			)
		}

		acs, ok := req[1].([]interface{})
		if !ok {
			return nil, errors.New("activities param is not a slice")
		}
		if fmt.Sprintf("%s", acs) != "[input]" {
			return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
		}
		return nil, nil
	}

	svc := livesharetest.WithService(
		"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
	)
	testServer, session, err := makeMockSession(svc)
	if err != nil {
		// Nothing below can run without a server and a session.
		t.Fatalf("creating mock session: %v", err)
	}
	defer testServer.Close()

	ctx := context.Background()

	// Buffered so the goroutine can finish even if the server error wins the
	// select below.
	done := make(chan error, 1)
	go func() {
		done <- session.notifyHostOfActivity(ctx, "input")
	}()

	// %v rather than %w: error wrapping is only understood by fmt.Errorf.
	select {
	case err := <-testServer.Err():
		t.Errorf("error from server: %v", err)
	case err := <-done:
		if err != nil {
			t.Errorf("error from client: %v", err)
		}
	}
}
// TestSessionHeartbeat runs Session.heartbeat with a 50ms interval, calls
// keepAlive("input") twice about 200ms apart, and then expects exactly two
// occurrences of "input" in the session log and exactly two requests to the
// mock "ICodespaceHostService.notifyCodespaceOfClientActivity" service.
func TestSessionHeartbeat(t *testing.T) {
	var (
		requestsMu sync.Mutex
		requests   int // number of handler invocations; guarded by requestsMu
	)

	// notifyHostOfActivity counts each call and validates the two positional
	// JSON-RPC arguments.
	notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
		requestsMu.Lock()
		requests++
		requestsMu.Unlock()

		var req []interface{}
		if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
			return nil, fmt.Errorf("unmarshal req: %w", err)
		}
		if len(req) < 2 {
			return nil, errors.New("request arguments is less than 2")
		}

		clientName, ok := req[0].(string)
		if !ok {
			return nil, errors.New("clientName param is not a string")
		}
		if clientName != mockClientName {
			return nil, fmt.Errorf(
				"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
			)
		}

		acs, ok := req[1].([]interface{})
		if !ok {
			return nil, errors.New("activities param is not a slice")
		}
		if fmt.Sprintf("%s", acs) != "[input]" {
			return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
		}
		return nil, nil
	}

	svc := livesharetest.WithService(
		"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
	)
	testServer, session, err := makeMockSession(svc)
	if err != nil {
		// Nothing below can run without a server and a session.
		t.Fatalf("creating mock session: %v", err)
	}
	defer testServer.Close()

	// Cancelling ctx on return stops the heartbeat goroutine.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Swap in a fresh logger so only heartbeat output is counted.
	logger := newMockLogger()
	session.logger = logger

	go session.heartbeat(ctx, 50*time.Millisecond)

	// done is closed (not sent on) so the driver goroutine never blocks, even
	// if the server error wins the select below.
	done := make(chan struct{})
	go func() {
		defer close(done)
		session.keepAlive("input")
		// time.Sleep instead of <-time.Tick: Tick's ticker cannot be stopped.
		time.Sleep(200 * time.Millisecond)
		session.keepAlive("input")
		time.Sleep(100 * time.Millisecond)
	}()

	select {
	case err := <-testServer.Err():
		// %v rather than %w: error wrapping is only understood by fmt.Errorf.
		t.Errorf("error from server: %v", err)
	case <-done:
		activityCount := strings.Count(logger.String(), "input")
		if activityCount != 2 {
			t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount)
		}

		// Snapshot under the lock and report the snapshot, not the shared
		// counter.
		requestsMu.Lock()
		rc := requests
		requestsMu.Unlock()
		if rc != 2 {
			t.Errorf("unexpected number of requests, expected: 2, got: %d", rc)
		}
	}
}
// mockLogger collects everything logged through it in an in-memory buffer so
// that tests can inspect the output afterwards. Every method takes the
// embedded mutex before touching the buffer.
type mockLogger struct {
	// Mutex guards buf.
	sync.Mutex
	// buf accumulates all formatted output in call order.
	buf *bytes.Buffer
}

// newMockLogger returns a mockLogger with an empty buffer.
func newMockLogger() *mockLogger {
	return &mockLogger{buf: &bytes.Buffer{}}
}

// Printf appends the fmt.Sprintf-style rendering of format and v to the buffer.
func (m *mockLogger) Printf(format string, v ...interface{}) {
	m.Lock()
	defer m.Unlock()

	fmt.Fprintf(m.buf, format, v...)
}

// Println appends the fmt.Sprintln-style rendering of v (operands separated by
// spaces, followed by a newline) to the buffer.
func (m *mockLogger) Println(v ...interface{}) {
	m.Lock()
	defer m.Unlock()

	fmt.Fprintln(m.buf, v...)
}

// String returns everything logged so far.
func (m *mockLogger) String() string {
	m.Lock()
	defer m.Unlock()

	return m.buf.String()
}