Fix indent/outdent tracking

Problem
-------
Given something like
    class A:
      def a() =
        ()

    def a() = 1
the scanner emits only one outdent where two are needed (one to close
the method body and one to close the class body), so it fails to parse
the above.

Solution
--------
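Track `last_indentation_size` in the payload to record the indentation
size seen when the last outdent was emitted (default: -1). A value other
than -1 means an outdent has just been emitted. On the next call, before
advancing the lexer, the scanner checks whether `last_indentation_size`
is still smaller than the indentation on top of the stack; if so, it
emits another outdent.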
pull/481/head
Eugene Yokota 2022-12-15 12:20:50 +07:00
parent fa8d64b77e
commit 83aaa6020e
5 changed files with 85 additions and 30 deletions

@ -23,27 +23,18 @@ package a.b
package c:
object A
package d:
object A
end d
---
(compilation_unit
(package_clause (package_identifier (identifier) (identifier)))
(package_clause (package_identifier (identifier))
(template_body
(object_definition (identifier)))))
(object_definition (identifier))))
================================
Package (Scala 3 syntax with end)
================================
package a.b
package c:
object A
end c
---
(compilation_unit
(package_clause (package_identifier (identifier) (identifier)))
(package_clause (package_identifier (identifier))
(template_body
(object_definition (identifier)))))
@ -582,6 +573,32 @@ class A:
(indented_block
(val_definition (identifier) (integer_literal)) (val_definition (identifier) (integer_literal)) (infix_expression (identifier) (operator_identifier) (identifier)))))))
=======================================
Top-level Definitions (Scala 3 syntax)
=======================================
class A:
def a() =
()
def a() = 1
---
(compilation_unit
(class_definition
(identifier)
(template_body
(function_definition
(identifier)
(parameters)
(indented_block
(unit)))))
(function_definition
(identifier)
(parameters)
(integer_literal)))
=======================================
Initialization expressions
=======================================

@ -73,6 +73,7 @@ module.exports = grammar({
$.package_clause,
$.package_object,
$._definition,
$._end_marker,
),
_definition: $ => choice(
@ -236,7 +237,6 @@ module.exports = grammar({
$._indent,
$._block,
$._outdent,
optional($._end_marker),
)),
seq(
'{',
@ -248,7 +248,19 @@ module.exports = grammar({
_end_marker: $ => prec.left(PREC.end_marker, seq(
'end',
alias($.identifier, '_end_ident'),
choice(
'if',
'while',
'for',
'match',
'try',
'new',
'this',
'given',
'extension',
'val',
alias($.identifier, '_end_ident'),
),
)),
annotation: $ => prec.right(seq(
@ -626,7 +638,6 @@ module.exports = grammar({
'else',
field('alternative', $._indentable_expression),
)),
optional(seq('end', 'if')),
)),
match_expression: $ => prec.left(PREC.postfix, seq(

32
src/scanner.c vendored

@ -91,10 +91,25 @@ static bool scan_string_content(TSLexer *lexer, bool is_multiline, bool has_inte
bool tree_sitter_scala_external_scanner_scan(void *payload, TSLexer *lexer,
const bool *valid_symbols) {
ScannerStack *stack = (ScannerStack *)payload;
int prev = peekStack(stack);
unsigned newline_count = 0;
unsigned indentation_size = 0;
int indentation_size = 0;
LOG("scanner was called at column: %d\n", lexer->get_column(lexer));
// Before advancing the lexer, check if we can double outdent
if (valid_symbols[OUTDENT] &&
(lexer->lookahead == 0 || (
stack->last_indentation_size != -1 &&
prev != -1 &&
stack->last_indentation_size < prev))) {
popStack(stack);
LOG(" pop\n");
LOG(" OUTDENT\n");
lexer->result_symbol = OUTDENT;
return true;
}
stack->last_indentation_size = -1;
while (iswspace(lexer->lookahead)) {
if (lexer->lookahead == '\n') {
newline_count++;
@ -104,27 +119,32 @@ bool tree_sitter_scala_external_scanner_scan(void *payload, TSLexer *lexer,
indentation_size++;
lexer->advance(lexer, true);
}
int prev = peekStack(stack);
printStack(stack, "before");
printStack(stack, " before");
if (valid_symbols[INDENT] && newline_count > 0 &&
(isEmptyStack(stack) || indentation_size > peekStack(stack))) {
pushStack(stack, indentation_size);
lexer->result_symbol = INDENT;
LOG(" INDENT\n");
return true;
}
// This saves the indentation_size into the stack's last_indentation_size
// since sometimes we need to outdent multiple times.
if (valid_symbols[OUTDENT] &&
(lexer->lookahead == 0 || (
newline_count > 0 && prev != -1 && indentation_size < prev))) {
popStack(stack);
LOG("pop\n");
LOG(" pop\n");
LOG(" OUTDENT\n");
lexer->result_symbol = OUTDENT;
stack->last_indentation_size = indentation_size;
return true;
}
printStack(stack, "after");
printStack(stack, " after");
LOG("indentation_size: %d, newline_count: %d, column: %d, indent_is_valid: %d, dedent_is_valid: %d\n", indentation_size,
LOG(" indentation_size: %d, newline_count: %d, column: %d, indent_is_valid: %d, dedent_is_valid: %d\n", indentation_size,
newline_count, lexer->get_column(lexer), valid_symbols[INDENT], valid_symbols[OUTDENT]);
if (valid_symbols[AUTOMATIC_SEMICOLON] && newline_count > 0) {

17
src/stack.h vendored

@ -5,9 +5,9 @@
#include <string.h>
#ifdef DEBUG
#define LOG(args...) fprintf(stderr, args);
#define LOG(...) fprintf(stderr, __VA_ARGS__)
#else
#define LOG(args...)
#define LOG(...)
#endif
#define STACK_SIZE 1024
@ -15,12 +15,14 @@
typedef struct ScannerStack {
unsigned int stack[STACK_SIZE];
int top;
int last_indentation_size;
} ScannerStack;
ScannerStack* createStack() {
ScannerStack* ptr = (ScannerStack*) malloc(sizeof(ScannerStack));
ptr -> top = 0;
ptr -> last_indentation_size = -1;
memset(ptr -> stack, STACK_SIZE, (0));
return ptr;
@ -58,9 +60,10 @@ void printStack(ScannerStack *stack, char *msg) {
unsigned serialiseStack(ScannerStack *stack, char *buf) {
unsigned elements = isEmptyStack(stack) ? 0 : stack->top;
unsigned result_length = elements * sizeof(int);
unsigned result_length = (elements + 1) * sizeof(int);
int *placement = (int *)buf;
memcpy(placement, stack->stack, elements * sizeof(int));
placement[elements] = stack->last_indentation_size;
return result_length;
}
@ -69,10 +72,14 @@ void deserialiseStack(ScannerStack* stack, const char* buf, unsigned n) {
if (n != 0) {
int *intBuf = (int *)buf;
int elements = n / sizeof(int);
unsigned elements = n / sizeof(int) - 1;
stack->top = elements;
memcpy(stack->stack, intBuf, elements * sizeof(int));
stack->last_indentation_size = intBuf[elements];
}
}
void resetStack(ScannerStack *p) { p->top = 0; }
void resetStack(ScannerStack *p) {
p->top = 0;
p->last_indentation_size = -1;
}

@ -32,11 +32,11 @@ int main() {
pushStack(stack, i);
}
assert(serialiseStack(stack, buf) == sizeof(int) * 250);
assert(serialiseStack(stack, buf) == sizeof(int) * 251);
ScannerStack *newStack = createStack();
deserialiseStack(newStack, buf, sizeof(int) * 250);
deserialiseStack(newStack, buf, sizeof(int) * 251);
assert(newStack -> top == 250);
assert(popStack(newStack) == 249);