X Tutup
Skip to content
Merged
214 changes: 109 additions & 105 deletions language-extensions/index.d.ts

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/lualib/Iterator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function __TS__IteratorStringStep(this: string, index: number): [number, string]
function __TS__Iterator<T>(
this: void,
iterable: string | GeneratorIterator | Iterable<T> | readonly T[]
): [(...args: any[]) => [any, any] | [], ...any[]] {
): [(...args: any[]) => [any, any] | [], ...any[]] | LuaIterable<LuaMultiReturn<[number, T]>> {
if (typeof iterable === "string") {
return [__TS__IteratorStringStep, iterable, 0];
} else if ("____coroutine" in iterable) {
Expand All @@ -36,6 +36,6 @@ function __TS__Iterator<T>(
const iterator = iterable[Symbol.iterator]();
return [__TS__IteratorIteratorStep, iterator];
} else {
return ipairs(iterable as readonly T[]) as any;
return ipairs(iterable as readonly T[]);
}
}
8 changes: 1 addition & 7 deletions src/lualib/declarations/global.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,4 @@ declare function unpack<T>(list: T[], i?: number, j?: number): T[];
declare function select<T>(index: number, ...args: T[]): T;
declare function select<T>(index: "#", ...args: T[]): number;

/**
* @luaIterator
* @tupleReturn
*/
type LuaTupleIterator<T extends any[]> = Iterable<T> & { " LuaTupleIterator": never };

declare function ipairs<T>(t: Record<number, T>): LuaTupleIterator<[number, T]>;
declare function ipairs<T>(t: Record<number, T>): LuaIterable<LuaMultiReturn<[number, T]>, Record<number, T>>;
4 changes: 4 additions & 0 deletions src/transformation/utils/diagnostics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ export const luaIteratorForbiddenUsage = createErrorDiagnosticFactory(
"the '@tupleReturn' annotation."
);

export const invalidMultiIterableWithoutDestructuring = createErrorDiagnosticFactory(
"LuaIterable with a LuaMultiReturn return value type must be destructured."
);

export const unsupportedAccessorInObjectLiteral = createErrorDiagnosticFactory(
"Accessors in object literal are not supported."
);
Expand Down
123 changes: 58 additions & 65 deletions src/transformation/utils/language-extensions.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import * as ts from "typescript";
import * as path from "path";
import { TransformationContext } from "../context";

export enum ExtensionKind {
MultiFunction = "MultiFunction",
MultiType = "MultiType",
RangeFunction = "RangeFunction",
IterableType = "IterableType",
AdditionOperatorType = "AdditionOperatorType",
AdditionOperatorMethodType = "AdditionOperatorMethodType",
SubtractionOperatorType = "SubtractionOperatorType",
Expand Down Expand Up @@ -43,74 +44,66 @@ export enum ExtensionKind {
LengthOperatorMethodType = "LengthOperatorMethodType",
}

const functionNameToExtensionKind: { [name: string]: ExtensionKind } = {
$multi: ExtensionKind.MultiFunction,
$range: ExtensionKind.RangeFunction,
const extensionKindToFunctionName: { [T in ExtensionKind]?: string } = {
[ExtensionKind.MultiFunction]: "$multi",
[ExtensionKind.RangeFunction]: "$range",
};

const typeNameToExtensionKind: { [name: string]: ExtensionKind } = {
LuaMultiReturn: ExtensionKind.MultiType,
LuaAddition: ExtensionKind.AdditionOperatorType,
LuaAdditionMethod: ExtensionKind.AdditionOperatorMethodType,
LuaSubtraction: ExtensionKind.SubtractionOperatorType,
LuaSubtractionMethod: ExtensionKind.SubtractionOperatorMethodType,
LuaMultiplication: ExtensionKind.MultiplicationOperatorType,
LuaMultiplicationMethod: ExtensionKind.MultiplicationOperatorMethodType,
LuaDivision: ExtensionKind.DivisionOperatorType,
LuaDivisionMethod: ExtensionKind.DivisionOperatorMethodType,
LuaModulo: ExtensionKind.ModuloOperatorType,
LuaModuloMethod: ExtensionKind.ModuloOperatorMethodType,
LuaPower: ExtensionKind.PowerOperatorType,
LuaPowerMethod: ExtensionKind.PowerOperatorMethodType,
LuaFloorDivision: ExtensionKind.FloorDivisionOperatorType,
LuaFloorDivisionMethod: ExtensionKind.FloorDivisionOperatorMethodType,
LuaBitwiseAnd: ExtensionKind.BitwiseAndOperatorType,
LuaBitwiseAndMethod: ExtensionKind.BitwiseAndOperatorMethodType,
LuaBitwiseOr: ExtensionKind.BitwiseOrOperatorType,
LuaBitwiseOrMethod: ExtensionKind.BitwiseOrOperatorMethodType,
LuaBitwiseExclusiveOr: ExtensionKind.BitwiseExclusiveOrOperatorType,
LuaBitwiseExclusiveOrMethod: ExtensionKind.BitwiseExclusiveOrOperatorMethodType,
LuaBitwiseLeftShift: ExtensionKind.BitwiseLeftShiftOperatorType,
LuaBitwiseLeftShiftMethod: ExtensionKind.BitwiseLeftShiftOperatorMethodType,
LuaBitwiseRightShift: ExtensionKind.BitwiseRightShiftOperatorType,
LuaBitwiseRightShiftMethod: ExtensionKind.BitwiseRightShiftOperatorMethodType,
LuaConcat: ExtensionKind.ConcatOperatorType,
LuaConcatMethod: ExtensionKind.ConcatOperatorMethodType,
LuaLessThan: ExtensionKind.LessThanOperatorType,
LuaLessThanMethod: ExtensionKind.LessThanOperatorMethodType,
LuaGreaterThan: ExtensionKind.GreaterThanOperatorType,
LuaGreaterThanMethod: ExtensionKind.GreaterThanOperatorMethodType,
LuaNegation: ExtensionKind.NegationOperatorType,
LuaNegationMethod: ExtensionKind.NegationOperatorMethodType,
LuaBitwiseNot: ExtensionKind.BitwiseNotOperatorType,
LuaBitwiseNotMethod: ExtensionKind.BitwiseNotOperatorMethodType,
LuaLength: ExtensionKind.LengthOperatorType,
LuaLengthMethod: ExtensionKind.LengthOperatorMethodType,
const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
[ExtensionKind.MultiFunction]: "__luaMultiFunctionBrand",
[ExtensionKind.MultiType]: "__luaMultiReturnBrand",
[ExtensionKind.RangeFunction]: "__luaRangeFunctionBrand",
[ExtensionKind.IterableType]: "__luaIterableBrand",
[ExtensionKind.AdditionOperatorType]: "__luaAdditionBrand",
[ExtensionKind.AdditionOperatorMethodType]: "__luaAdditionMethodBrand",
[ExtensionKind.SubtractionOperatorType]: "__luaSubtractionBrand",
[ExtensionKind.SubtractionOperatorMethodType]: "__luaSubtractionMethodBrand",
[ExtensionKind.MultiplicationOperatorType]: "__luaMultiplicationBrand",
[ExtensionKind.MultiplicationOperatorMethodType]: "__luaMultiplicationMethodBrand",
[ExtensionKind.DivisionOperatorType]: "__luaDivisionBrand",
[ExtensionKind.DivisionOperatorMethodType]: "__luaDivisionMethodBrand",
[ExtensionKind.ModuloOperatorType]: "__luaModuloBrand",
[ExtensionKind.ModuloOperatorMethodType]: "__luaModuloMethodBrand",
[ExtensionKind.PowerOperatorType]: "__luaPowerBrand",
[ExtensionKind.PowerOperatorMethodType]: "__luaPowerMethodBrand",
[ExtensionKind.FloorDivisionOperatorType]: "__luaFloorDivisionBrand",
[ExtensionKind.FloorDivisionOperatorMethodType]: "__luaFloorDivisionMethodBrand",
[ExtensionKind.BitwiseAndOperatorType]: "__luaBitwiseAndBrand",
[ExtensionKind.BitwiseAndOperatorMethodType]: "__luaBitwiseAndMethodBrand",
[ExtensionKind.BitwiseOrOperatorType]: "__luaBitwiseOrBrand",
[ExtensionKind.BitwiseOrOperatorMethodType]: "__luaBitwiseOrMethodBrand",
[ExtensionKind.BitwiseExclusiveOrOperatorType]: "__luaBitwiseExclusiveOrBrand",
[ExtensionKind.BitwiseExclusiveOrOperatorMethodType]: "__luaBitwiseExclusiveOrMethodBrand",
[ExtensionKind.BitwiseLeftShiftOperatorType]: "__luaBitwiseLeftShiftBrand",
[ExtensionKind.BitwiseLeftShiftOperatorMethodType]: "__luaBitwiseLeftShiftMethodBrand",
[ExtensionKind.BitwiseRightShiftOperatorType]: "__luaBitwiseRightShiftBrand",
[ExtensionKind.BitwiseRightShiftOperatorMethodType]: "__luaBitwiseRightShiftMethodBrand",
[ExtensionKind.ConcatOperatorType]: "__luaConcatBrand",
[ExtensionKind.ConcatOperatorMethodType]: "__luaConcatMethodBrand",
[ExtensionKind.LessThanOperatorType]: "__luaLessThanBrand",
[ExtensionKind.LessThanOperatorMethodType]: "__luaLessThanMethodBrand",
[ExtensionKind.GreaterThanOperatorType]: "__luaGreaterThanBrand",
[ExtensionKind.GreaterThanOperatorMethodType]: "__luaGreaterThanMethodBrand",
[ExtensionKind.NegationOperatorType]: "__luaNegationBrand",
[ExtensionKind.NegationOperatorMethodType]: "__luaNegationMethodBrand",
[ExtensionKind.BitwiseNotOperatorType]: "__luaBitwiseNotBrand",
[ExtensionKind.BitwiseNotOperatorMethodType]: "__luaBitwiseNotMethodBrand",
[ExtensionKind.LengthOperatorType]: "__luaLengthBrand",
[ExtensionKind.LengthOperatorMethodType]: "__luaLengthMethodBrand",
};

function isSourceFileFromLanguageExtensions(sourceFile: ts.SourceFile): boolean {
const extensionDirectory = path.resolve(__dirname, "../../../language-extensions");
const sourceFileDirectory = path.dirname(path.normalize(sourceFile.fileName));
return extensionDirectory === sourceFileDirectory;
export function isExtensionType(type: ts.Type, extensionKind: ExtensionKind): boolean {
const typeBrand = extensionKindToTypeBrand[extensionKind];
return typeBrand !== undefined && type.getProperty(typeBrand) !== undefined;
}

export function getExtensionKind(declaration: ts.Declaration): ExtensionKind | undefined {
const sourceFile = declaration.getSourceFile();
if (isSourceFileFromLanguageExtensions(sourceFile)) {
if (ts.isFunctionDeclaration(declaration) && declaration.name?.text) {
const extensionKind = functionNameToExtensionKind[declaration.name.text];
if (extensionKind) {
return extensionKind;
}
}

if (ts.isTypeAliasDeclaration(declaration)) {
const extensionKind = typeNameToExtensionKind[declaration.name.text];
if (extensionKind) {
return extensionKind;
}
}

throw new Error("Unknown extension kind");
}
export function isExtensionFunction(
context: TransformationContext,
symbol: ts.Symbol,
extensionKind: ExtensionKind
): boolean {
return (
symbol.getName() === extensionKindToFunctionName[extensionKind] &&
symbol.declarations.some(d => isExtensionType(context.checker.getTypeAtLocation(d), extensionKind))
);
}
82 changes: 82 additions & 0 deletions src/transformation/visitors/language-extensions/iterable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import * as ts from "typescript";
import * as lua from "../../../LuaAST";
import * as extensions from "../../utils/language-extensions";
import { TransformationContext } from "../../context";
import { getVariableDeclarationBinding, transformForInitializer } from "../loops/utils";
import { transformArrayBindingElement } from "../variable-declaration";
import { invalidMultiIterableWithoutDestructuring } from "../../utils/diagnostics";
import { cast } from "../../../utils";
import { isMultiReturnType } from "./multi";

export function isIterableType(type: ts.Type): boolean {
return extensions.isExtensionType(type, extensions.ExtensionKind.IterableType);
}

export function returnsIterableType(context: TransformationContext, node: ts.CallExpression): boolean {
const signature = context.checker.getResolvedSignature(node);
const type = signature?.getReturnType();
return type ? isIterableType(type) : false;
}

export function isIterableExpression(context: TransformationContext, expression: ts.Expression): boolean {
const type = context.checker.getTypeAtLocation(expression);
return isIterableType(type);
}

function transformForOfMultiIterableStatement(
context: TransformationContext,
statement: ts.ForOfStatement,
block: lua.Block
): lua.Statement {
const luaIterator = context.transformExpression(statement.expression);
let identifiers: lua.Identifier[] = [];

if (ts.isVariableDeclarationList(statement.initializer)) {
// Variables declared in for loop
// for ${initializer} in ${iterable} do
const binding = getVariableDeclarationBinding(context, statement.initializer);
if (ts.isArrayBindingPattern(binding)) {
identifiers = binding.elements.map(e => transformArrayBindingElement(context, e));
} else {
context.diagnostics.push(invalidMultiIterableWithoutDestructuring(binding));
}
} else if (ts.isArrayLiteralExpression(statement.initializer)) {
// Variables NOT declared in for loop - catch iterator values in temps and assign
// for ____value0 in ${iterable} do
// ${initializer} = ____value0
identifiers = statement.initializer.elements.map((_, i) => lua.createIdentifier(`____value${i}`));
if (identifiers.length > 0) {
block.statements.unshift(
lua.createAssignmentStatement(
statement.initializer.elements.map(e =>
cast(context.transformExpression(e), lua.isAssignmentLeftHandSideExpression)
),
identifiers
)
);
}
} else {
context.diagnostics.push(invalidMultiIterableWithoutDestructuring(statement.initializer));
}

if (identifiers.length === 0) {
identifiers.push(lua.createAnonymousIdentifier());
}

return lua.createForInStatement(block, identifiers, [luaIterator], statement);
}

export function transformForOfIterableStatement(
context: TransformationContext,
statement: ts.ForOfStatement,
block: lua.Block
): lua.Statement {
const type = context.checker.getTypeAtLocation(statement.expression);
if (type.aliasTypeArguments?.length === 2 && isMultiReturnType(type.aliasTypeArguments[0])) {
return transformForOfMultiIterableStatement(context, statement, block);
}

const luaIterator = context.transformExpression(statement.expression);
const identifier = transformForInitializer(context, statement.initializer, block);
return lua.createForInStatement(block, [identifier], [luaIterator], statement);
}
26 changes: 12 additions & 14 deletions src/transformation/visitors/language-extensions/multi.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import * as ts from "typescript";
import * as extensions from "../../utils/language-extensions";
import { TransformationContext } from "../../context";
import { invalidMultiFunctionUse } from "../../utils/diagnostics";
import { findFirstNodeAbove } from "../../utils/typescript";

const isMultiFunctionDeclaration = (declaration: ts.Declaration): boolean =>
extensions.getExtensionKind(declaration) === extensions.ExtensionKind.MultiFunction;

const isMultiTypeDeclaration = (declaration: ts.Declaration): boolean =>
extensions.getExtensionKind(declaration) === extensions.ExtensionKind.MultiType;
import { isIterableExpression } from "./iterable";
import { invalidMultiFunctionUse } from "../../utils/diagnostics";

export function isMultiReturnType(type: ts.Type): boolean {
return type.aliasSymbol?.declarations?.some(isMultiTypeDeclaration) ?? false;
return extensions.isExtensionType(type, extensions.ExtensionKind.MultiType);
}

export function isMultiFunctionCall(context: TransformationContext, expression: ts.CallExpression): boolean {
const type = context.checker.getTypeAtLocation(expression.expression);
return type.symbol?.declarations?.some(isMultiFunctionDeclaration) ?? false;
return isMultiFunctionNode(context, expression.expression);
}

export function returnsMultiType(context: TransformationContext, node: ts.CallExpression): boolean {
Expand All @@ -30,8 +24,8 @@ export function isMultiReturnCall(context: TransformationContext, expression: ts
}

export function isMultiFunctionNode(context: TransformationContext, node: ts.Node): boolean {
const type = context.checker.getTypeAtLocation(node);
return type.symbol?.declarations?.some(isMultiFunctionDeclaration) ?? false;
const symbol = context.checker.getSymbolAtLocation(node);
return symbol ? extensions.isExtensionFunction(context, symbol, extensions.ExtensionKind.MultiFunction) : false;
}

export function isInMultiReturnFunction(context: TransformationContext, node: ts.Node) {
Expand Down Expand Up @@ -86,6 +80,11 @@ export function shouldMultiReturnCallBeWrapped(context: TransformationContext, n
return false;
}

// LuaIterable in for...of
if (ts.isForOfStatement(node.parent) && isIterableExpression(context, node)) {
return false;
}

return true;
}

Expand All @@ -99,8 +98,7 @@ export function findMultiAssignmentViolations(
if (!ts.isShorthandPropertyAssignment(element)) continue;
const valueSymbol = context.checker.getShorthandAssignmentValueSymbol(element);
if (valueSymbol) {
const declaration = valueSymbol.valueDeclaration;
if (declaration && isMultiFunctionDeclaration(declaration)) {
if (extensions.isExtensionFunction(context, valueSymbol, extensions.ExtensionKind.MultiFunction)) {
context.diagnostics.push(invalidMultiFunctionUse(element));
result.push(element);
}
Expand Down
21 changes: 4 additions & 17 deletions src/transformation/visitors/language-extensions/operators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ const unaryOperatorMappings = new Map<extensions.ExtensionKind, lua.UnaryOperato
[extensions.ExtensionKind.LengthOperatorMethodType, lua.SyntaxKind.LengthOperator],
]);

const operatorMapExtensions = new Set<extensions.ExtensionKind>([
...binaryOperatorMappings.keys(),
...unaryOperatorMappings.keys(),
]);
const operatorMapExtensions = [...binaryOperatorMappings.keys(), ...unaryOperatorMappings.keys()];

const bitwiseOperatorMapExtensions = new Set<extensions.ExtensionKind>([
extensions.ExtensionKind.BitwiseAndOperatorType,
Expand Down Expand Up @@ -84,25 +81,15 @@ function getOperatorMapExtensionKindForCall(context: TransformationContext, node
if (!typeDeclaration) {
return;
}
const mapping = extensions.getExtensionKind(typeDeclaration);
if (mapping !== undefined && operatorMapExtensions.has(mapping)) {
return mapping;
}
}

function isOperatorMapDeclaration(declaration: ts.Declaration) {
const typeDeclaration = getTypeDeclaration(declaration);
if (typeDeclaration) {
const extensionKind = extensions.getExtensionKind(typeDeclaration);
return extensionKind !== undefined ? operatorMapExtensions.has(extensionKind) : false;
}
const type = context.checker.getTypeFromTypeNode(typeDeclaration.type);
return operatorMapExtensions.find(extensionKind => extensions.isExtensionType(type, extensionKind));
}

function isOperatorMapType(context: TransformationContext, type: ts.Type): boolean {
if (type.isUnionOrIntersection()) {
return type.types.some(t => isOperatorMapType(context, t));
} else {
return type.symbol?.declarations?.some(isOperatorMapDeclaration);
return operatorMapExtensions.some(extensionKind => extensions.isExtensionType(type, extensionKind));
}
}

Expand Down
8 changes: 2 additions & 6 deletions src/transformation/visitors/language-extensions/range.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@ import { transformArguments } from "../call";
import { assert } from "../../../utils";
import { invalidRangeControlVariable } from "../../utils/diagnostics";

const isRangeFunctionDeclaration = (declaration: ts.Declaration): boolean =>
extensions.getExtensionKind(declaration) === extensions.ExtensionKind.RangeFunction;

export function isRangeFunction(context: TransformationContext, expression: ts.CallExpression): boolean {
const type = context.checker.getTypeAtLocation(expression.expression);
return type.symbol?.declarations?.some(isRangeFunctionDeclaration) ?? false;
return isRangeFunctionNode(context, expression.expression);
}

export function isRangeFunctionNode(context: TransformationContext, node: ts.Node): boolean {
const symbol = context.checker.getSymbolAtLocation(node);
return symbol?.declarations?.some(isRangeFunctionDeclaration) ?? false;
return symbol ? extensions.isExtensionFunction(context, symbol, extensions.ExtensionKind.RangeFunction) : false;
}

function getControlVariable(context: TransformationContext, statement: ts.ForOfStatement) {
Expand Down
Loading
X Tutup