X Tutup
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/transformation/utils/lua-ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ export function createLocalOrExportedOrGlobalDeclaration(
const isTopLevelVariable = scope.type === ScopeType.File;

if (context.isModule || !isTopLevelVariable) {
if (scope.type === ScopeType.Switch || (!isFunctionDeclaration && hasMultipleReferences(scope, lhs))) {
if (!isFunctionDeclaration && hasMultipleReferences(scope, lhs)) {
// Split declaration and assignment of identifiers that reference themselves in their declaration
declaration = lua.createVariableDeclarationStatement(lhs, undefined, tsOriginal);
if (rhs) {
Expand All @@ -185,15 +185,13 @@ export function createLocalOrExportedOrGlobalDeclaration(
declaration = lua.createVariableDeclarationStatement(lhs, rhs, tsOriginal);
}

// Remember local variable declarations for hoisting later
if (!scope.variableDeclarations) {
scope.variableDeclarations = [];
}

scope.variableDeclarations.push(declaration);
if (!isFunctionDeclaration) {
// Remember local variable declarations for hoisting later
if (!scope.variableDeclarations) {
scope.variableDeclarations = [];
}

if (scope.type === ScopeType.Switch) {
declaration = undefined;
scope.variableDeclarations.push(declaration);
}
} else if (rhs) {
// global
Expand Down
118 changes: 85 additions & 33 deletions src/transformation/utils/scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ export interface Scope {
functionReturned?: boolean;
}

export interface HoistingResult {
statements: lua.Statement[];
hoistedStatements: lua.Statement[];
hoistedIdentifiers: lua.Identifier[];
}

const scopeStacks = new WeakMap<TransformationContext, Scope[]>();
function getScopeStack(context: TransformationContext): Scope[] {
return getOrUpdate(scopeStacks, context, () => []);
Expand Down Expand Up @@ -133,16 +139,47 @@ export function isFunctionScopeWithDefinition(scope: Scope): scope is Scope & {
return scope.node !== undefined && ts.isFunctionLike(scope.node);
}

export function performHoisting(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] {
export function separateHoistedStatements(context: TransformationContext, statements: lua.Statement[]): HoistingResult {
const scope = peekScope(context);
let result = statements;
result = hoistFunctionDefinitions(context, scope, result);
result = hoistVariableDeclarations(context, scope, result);
result = hoistImportStatements(scope, result);
return result;
const allHoistedStatments: lua.Statement[] = [];
const allHoistedIdentifiers: lua.Identifier[] = [];

let { unhoistedStatements, hoistedStatements, hoistedIdentifiers } = hoistFunctionDefinitions(
context,
scope,
statements
);
allHoistedStatments.push(...hoistedStatements);
allHoistedIdentifiers.push(...hoistedIdentifiers);

({ unhoistedStatements, hoistedIdentifiers } = hoistVariableDeclarations(context, scope, unhoistedStatements));
allHoistedIdentifiers.push(...hoistedIdentifiers);

({ unhoistedStatements, hoistedStatements } = hoistImportStatements(scope, unhoistedStatements));
allHoistedStatments.unshift(...hoistedStatements);

return {
statements: unhoistedStatements,
hoistedStatements: allHoistedStatments,
hoistedIdentifiers: allHoistedIdentifiers,
};
}

export function performHoisting(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] {
const result = separateHoistedStatements(context, statements);
const modifiedStatements = [...result.hoistedStatements, ...result.statements];
if (result.hoistedIdentifiers.length > 0) {
modifiedStatements.unshift(lua.createVariableDeclarationStatement(result.hoistedIdentifiers));
}
return modifiedStatements;
}

function shouldHoistSymbol(context: TransformationContext, symbolId: lua.SymbolId, scope: Scope): boolean {
// Always hoist in top-level of switch statements
if (scope.type === ScopeType.Switch) {
return true;
}

const symbolInfo = getSymbolInfo(context, symbolId);
if (!symbolInfo) {
return false;
Expand Down Expand Up @@ -183,65 +220,80 @@ function hoistVariableDeclarations(
context: TransformationContext,
scope: Scope,
statements: lua.Statement[]
): lua.Statement[] {
): { unhoistedStatements: lua.Statement[]; hoistedIdentifiers: lua.Identifier[] } {
if (!scope.variableDeclarations) {
return statements;
return { unhoistedStatements: statements, hoistedIdentifiers: [] };
}

const result = [...statements];
const hoistedLocals: lua.Identifier[] = [];
const unhoistedStatements = [...statements];
const hoistedIdentifiers: lua.Identifier[] = [];
for (const declaration of scope.variableDeclarations) {
const symbols = declaration.left.map(i => i.symbolId).filter(isNonNull);
if (symbols.some(s => shouldHoistSymbol(context, s, scope))) {
const index = result.indexOf(declaration);
assert(index > -1);
const index = unhoistedStatements.indexOf(declaration);
if (index < 0) {
continue; // statements array may not contain all statements in the scope (switch-case)
}

if (declaration.right) {
const assignment = lua.createAssignmentStatement(declaration.left, declaration.right);
lua.setNodePosition(assignment, declaration); // Preserve position info for sourcemap
result.splice(index, 1, assignment);
unhoistedStatements.splice(index, 1, assignment);
} else {
result.splice(index, 1);
unhoistedStatements.splice(index, 1);
}

hoistedLocals.push(...declaration.left);
} else if (scope.type === ScopeType.Switch) {
assert(!declaration.right);
hoistedLocals.push(...declaration.left);
hoistedIdentifiers.push(...declaration.left);
}
}

if (hoistedLocals.length > 0) {
result.unshift(lua.createVariableDeclarationStatement(hoistedLocals));
}

return result;
return { unhoistedStatements, hoistedIdentifiers };
}

function hoistFunctionDefinitions(
context: TransformationContext,
scope: Scope,
statements: lua.Statement[]
): lua.Statement[] {
): { unhoistedStatements: lua.Statement[]; hoistedStatements: lua.Statement[]; hoistedIdentifiers: lua.Identifier[] } {
if (!scope.functionDefinitions) {
return statements;
return { unhoistedStatements: statements, hoistedStatements: [], hoistedIdentifiers: [] };
}

const result = [...statements];
const hoistedFunctions: Array<lua.VariableDeclarationStatement | lua.AssignmentStatement> = [];
const unhoistedStatements = [...statements];
const hoistedStatements: lua.Statement[] = [];
const hoistedIdentifiers: lua.Identifier[] = [];
for (const [functionSymbolId, functionDefinition] of scope.functionDefinitions) {
assert(functionDefinition.definition);

if (shouldHoistSymbol(context, functionSymbolId, scope)) {
const index = result.indexOf(functionDefinition.definition);
result.splice(index, 1);
hoistedFunctions.push(functionDefinition.definition);
const index = unhoistedStatements.indexOf(functionDefinition.definition);
if (index < 0) {
continue; // statements array may not contain all statements in the scope (switch-case)
}
unhoistedStatements.splice(index, 1);

if (lua.isVariableDeclarationStatement(functionDefinition.definition)) {
// Separate function definition and variable declaration
assert(functionDefinition.definition.right);
hoistedIdentifiers.push(...functionDefinition.definition.left);
hoistedStatements.push(
lua.createAssignmentStatement(
functionDefinition.definition.left,
functionDefinition.definition.right
)
);
} else {
hoistedStatements.push(functionDefinition.definition);
}
}
}

return [...hoistedFunctions, ...result];
return { unhoistedStatements, hoistedStatements, hoistedIdentifiers };
}

function hoistImportStatements(scope: Scope, statements: lua.Statement[]): lua.Statement[] {
return scope.importStatements ? [...scope.importStatements, ...statements] : statements;
function hoistImportStatements(
scope: Scope,
statements: lua.Statement[]
): { unhoistedStatements: lua.Statement[]; hoistedStatements: lua.Statement[] } {
return { unhoistedStatements: statements, hoistedStatements: scope.importStatements ?? [] };
}
59 changes: 49 additions & 10 deletions src/transformation/visitors/switch.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import * as ts from "typescript";
import * as lua from "../../LuaAST";
import { FunctionVisitor, TransformationContext } from "../context";
import { performHoisting, popScope, pushScope, ScopeType } from "../utils/scope";
import { popScope, pushScope, ScopeType, separateHoistedStatements } from "../utils/scope";

const containsBreakOrReturn = (nodes: Iterable<ts.Node>): boolean => {
for (const s of nodes) {
Expand Down Expand Up @@ -55,15 +55,25 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st

// If the switch only has a default clause, wrap it in a single do.
// Otherwise, we need to generate a set of if statements to emulate the switch.
let statements: lua.Statement[] = [];
const statements: lua.Statement[] = [];
const hoistedStatements: lua.Statement[] = [];
const hoistedIdentifiers: lua.Identifier[] = [];
const clauses = statement.caseBlock.clauses;
if (clauses.length === 1 && ts.isDefaultClause(clauses[0])) {
const defaultClause = clauses[0].statements;
if (defaultClause.length) {
statements.push(lua.createDoStatement(context.transformStatements(defaultClause)));
const {
statements: defaultStatements,
hoistedStatements: defaultHoistedStatements,
hoistedIdentifiers: defaultHoistedIdentifiers,
} = separateHoistedStatements(context, context.transformStatements(defaultClause));
hoistedStatements.push(...defaultHoistedStatements);
hoistedIdentifiers.push(...defaultHoistedIdentifiers);
statements.push(lua.createDoStatement(defaultStatements));
}
} else {
// Build up the condition for each if statement
let defaultTransformed = false;
let isInitialCondition = true;
let condition: lua.Expression | undefined = undefined;
for (let i = 0; i < clauses.length; i++) {
Expand Down Expand Up @@ -124,10 +134,21 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
}

// Transform the clause and append the final break statement if necessary
const clauseStatements = context.transformStatements(clause.statements);
const {
statements: clauseStatements,
hoistedStatements: clauseHoistedStatements,
hoistedIdentifiers: clauseHoistedIdentifiers,
} = separateHoistedStatements(context, context.transformStatements(clause.statements));
if (i === clauses.length - 1 && !containsBreakOrReturn(clause.statements)) {
clauseStatements.push(lua.createBreakStatement());
}
hoistedStatements.push(...clauseHoistedStatements);
hoistedIdentifiers.push(...clauseHoistedIdentifiers);

// Remember that we transformed default clause so we don't duplicate hoisted statements later
if (ts.isDefaultClause(clause)) {
defaultTransformed = true;
}

// Push if statement for case
statements.push(lua.createIfStatement(conditionVariable, lua.createBlock(clauseStatements)));
Expand All @@ -145,11 +166,25 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
(clause, index) => index >= start && containsBreakOrReturn(clause.statements)
);

// Combine the default and all fallthrough statements
const defaultStatements: lua.Statement[] = [];
clauses
.slice(start, end >= 0 ? end + 1 : undefined)
.forEach(c => defaultStatements.push(...context.transformStatements(c.statements)));
const {
statements: defaultStatements,
hoistedStatements: defaultHoistedStatements,
hoistedIdentifiers: defaultHoistedIdentifiers,
} = separateHoistedStatements(context, context.transformStatements(clauses[start].statements));

// Only push hoisted statements if this is the first time we're transforming the default clause
if (!defaultTransformed) {
hoistedStatements.push(...defaultHoistedStatements);
hoistedIdentifiers.push(...defaultHoistedIdentifiers);
}

// Combine the fallthrough statements
for (const clause of clauses.slice(start + 1, end >= 0 ? end + 1 : undefined)) {
let statements = context.transformStatements(clause.statements);
// Drop hoisted statements as they were already added when clauses were initially transformed above
({ statements } = separateHoistedStatements(context, statements));
defaultStatements.push(...statements);
}

// Add the default clause if it has any statements
// The switch will always break on the final clause and skip execution if valid to do so
Expand All @@ -160,7 +195,11 @@ export const transformSwitchStatement: FunctionVisitor<ts.SwitchStatement> = (st
}

// Hoist the variable, function, and import statements to the top of the switch
statements = performHoisting(context, statements);
statements.unshift(...hoistedStatements);
if (hoistedIdentifiers.length > 0) {
statements.unshift(lua.createVariableDeclarationStatement(hoistedIdentifiers));
}

popScope(context);

// Add the switch expression after hoisting
Expand Down
76 changes: 76 additions & 0 deletions test/unit/__snapshots__/switch.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,82 @@ end
return ____exports"
`;

exports[`switch hoisting hoisting from default clause is not duplicated when falling through 1`] = `
"local ____exports = {}
function ____exports.__main(self)
local x = 1
local result = \\"\\"
repeat
local ____switch3 = x
local hoisted
function hoisted(self)
return \\"hoisted\\"
end
local ____cond3 = ____switch3 == 1
if ____cond3 then
result = hoisted(nil)
break
end
____cond3 = ____cond3 or (____switch3 == 2)
if ____cond3 then
result = \\"2\\"
end
if ____cond3 then
result = \\"default\\"
end
____cond3 = ____cond3 or (____switch3 == 3)
if ____cond3 then
result = \\"3\\"
break
end
do
result = \\"default\\"
result = \\"3\\"
end
until true
return result
end
return ____exports"
`;

exports[`switch hoisting hoisting from fallthrough clause after default is not duplicated 1`] = `
"local ____exports = {}
function ____exports.__main(self)
local x = 1
local result = \\"\\"
repeat
local ____switch3 = x
local hoisted
function hoisted(self)
return \\"hoisted\\"
end
local ____cond3 = ____switch3 == 1
if ____cond3 then
result = hoisted(nil)
break
end
____cond3 = ____cond3 or (____switch3 == 2)
if ____cond3 then
result = \\"2\\"
end
if ____cond3 then
result = \\"default\\"
end
____cond3 = ____cond3 or (____switch3 == 3)
if ____cond3 then
result = \\"3\\"
break
end
do
result = \\"default\\"
result = \\"3\\"
end
until true
return result
end
return ____exports"
`;

exports[`switch produces optimal output 1`] = `
"require(\\"lualib_bundle\\");
local ____exports = {}
Expand Down
Loading
X Tutup