feat: `given` pattern

Resolves #322

Summary
----
- `given_pattern` now handles cases like:
```scala
for
  given Int <- Some(1)
yield summon[Int]
```
- `corpus/patterns.txt` reformatted with `tree-sitter test -u`
pull/659/head
susliko 2023-08-07 22:27:19 +07:00
parent a2f36c2477
commit c14e2d4aad
2 changed files with 218 additions and 87 deletions

@ -1,30 +1,36 @@
========================= ================================================================================
Alternative patterns Alternative patterns
========================= ================================================================================
val x = y match { val x = y match {
case 1 | a => b case 1 | a => b
case "c" | "d" | "e" => f case "c" | "d" | "e" => f
} }
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(val_definition (val_definition
(identifier) (identifier)
(match_expression (identifier) (case_block (match_expression
(case_clause (identifier)
(alternative_pattern (integer_literal) (identifier)) (case_block
(identifier)) (case_clause
(case_clause (alternative_pattern
(alternative_pattern (integer_literal)
(alternative_pattern (string) (string)) (identifier))
(string)) (identifier))
(identifier)))))) (case_clause
(alternative_pattern
(alternative_pattern
(string)
(string))
(string))
(identifier))))))
========================= ================================================================================
Typed patterns Typed patterns
========================= ================================================================================
val x = y match { val x = y match {
case 1 : Int => 2 case 1 : Int => 2
@ -33,29 +39,44 @@ val x = y match {
case Object.Constant => 3 case Object.Constant => 3
} }
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(val_definition (val_definition
(identifier) (identifier)
(match_expression (identifier) (case_block (match_expression
(case_clause (identifier)
(typed_pattern (integer_literal) (type_identifier)) (integer_literal)) (case_block
(case_clause (case_clause
(typed_pattern (identifier) (compound_type (type_identifier) (type_identifier))) (typed_pattern
(identifier)) (integer_literal)
(case_clause (type_identifier))
(alternative_pattern (integer_literal))
(typed_pattern (wildcard) (type_identifier)) (case_clause
(typed_pattern (wildcard) (type_identifier))) (typed_pattern
(integer_literal)) (identifier)
(case_clause (compound_type
(stable_identifier (identifier) (identifier)) (type_identifier)
(integer_literal)))))) (type_identifier)))
(identifier))
============================ (case_clause
(alternative_pattern
(typed_pattern
(wildcard)
(type_identifier))
(typed_pattern
(wildcard)
(type_identifier)))
(integer_literal))
(case_clause
(stable_identifier
(identifier)
(identifier))
(integer_literal))))))
================================================================================
Tuple patterns Tuple patterns
============================ ================================================================================
val (a, b) = if (c) (d, e) else (f, g) val (a, b) = if (c) (d, e) else (f, g)
@ -63,24 +84,36 @@ val x = y match {
case (A, B) => X case (A, B) => X
} }
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(val_definition (val_definition
(tuple_pattern (identifier) (identifier)) (tuple_pattern
(identifier)
(identifier))
(if_expression (if_expression
(parenthesized_expression (identifier)) (parenthesized_expression
(tuple_expression (identifier) (identifier)) (identifier))
(tuple_expression (identifier) (identifier)))) (tuple_expression
(val_definition (identifier) (identifier)
(match_expression (identifier) (identifier))
(tuple_expression
(identifier)
(identifier))))
(val_definition
(identifier)
(match_expression
(identifier)
(case_block (case_block
(case_clause (case_clause
(tuple_pattern (identifier) (identifier)) (identifier)))))) (tuple_pattern
(identifier)
(identifier))
(identifier))))))
============================ ================================================================================
Case class patterns Case class patterns
============================ ================================================================================
def showNotification(notification: Notification): String = { def showNotification(notification: Notification): String = {
notification match { notification match {
@ -93,49 +126,98 @@ def showNotification(notification: Notification): String = {
} }
} }
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(function_definition (function_definition
(identifier) (identifier)
(parameters (parameter (identifier) (type_identifier))) (parameters
(parameter
(identifier)
(type_identifier)))
(type_identifier) (type_identifier)
(block (block
(match_expression (identifier) (case_block (match_expression
(case_clause (identifier)
(case_class_pattern (type_identifier) (identifier) (identifier) (wildcard)) (case_block
(interpolated_string_expression (identifier) (interpolated_string (interpolation (identifier)) (interpolation (identifier))))) (case_clause
(case_clause (case_class_pattern
(case_class_pattern (type_identifier) (identifier) (identifier)) (type_identifier)
(interpolated_string_expression (identifier) (interpolated_string (interpolation (identifier)) (interpolation (identifier))))) (identifier)
(case_clause (identifier)
(case_class_pattern (type_identifier) (identifier) (identifier)) (wildcard))
(interpolated_string_expression (identifier) (interpolated_string (interpolation (identifier)) (interpolation (identifier)))))))))) (interpolated_string_expression
(identifier)
(interpolated_string
(interpolation
(identifier))
(interpolation
(identifier)))))
(case_clause
(case_class_pattern
(type_identifier)
(identifier)
(identifier))
(interpolated_string_expression
(identifier)
(interpolated_string
(interpolation
(identifier))
(interpolation
(identifier)))))
(case_clause
(case_class_pattern
(type_identifier)
(identifier)
(identifier))
(interpolated_string_expression
(identifier)
(interpolated_string
(interpolation
(identifier))
(interpolation
(identifier))))))))))
============================ ================================================================================
Infix patterns Infix patterns
============================ ================================================================================
def first(x: Seq[Int]) = x match { def first(x: Seq[Int]) = x match {
case e :+ _ => Some(e) case e :+ _ => Some(e)
case _ => None case _ => None
} }
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(function_definition (identifier) (function_definition
(parameters (parameter (identifier) (generic_type (type_identifier) (type_arguments (type_identifier))))) (identifier)
(match_expression (identifier) (parameters
(parameter
(identifier)
(generic_type
(type_identifier)
(type_arguments
(type_identifier)))))
(match_expression
(identifier)
(case_block (case_block
(case_clause (infix_pattern (identifier) (operator_identifier) (wildcard)) (case_clause
(call_expression (identifier) (arguments (identifier)))) (infix_pattern
(case_clause (wildcard) (identifier)
(operator_identifier)
(wildcard))
(call_expression
(identifier)
(arguments
(identifier))))
(case_clause
(wildcard)
(identifier)))))) (identifier))))))
============================ ================================================================================
Capture patterns Capture patterns
============================ ================================================================================
val x = y match { val x = y match {
case a @ B(1) => a case a @ B(1) => a
@ -144,7 +226,7 @@ val x = y match {
case Array(a: Type, _@_*) => y case Array(a: Type, _@_*) => y
} }
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(val_definition (val_definition
@ -153,23 +235,41 @@ val x = y match {
(identifier) (identifier)
(case_block (case_block
(case_clause (case_clause
(capture_pattern (identifier) (case_class_pattern (type_identifier) (integer_literal))) (capture_pattern
(identifier)
(case_class_pattern
(type_identifier)
(integer_literal)))
(identifier)) (identifier))
(case_clause (case_clause
(capture_pattern (identifier) (capture_pattern
(case_class_pattern (type_identifier) (identifier)
(capture_pattern (identifier) (case_class_pattern
(type_identifier)
(capture_pattern
(identifier)
(tuple_pattern (tuple_pattern
(capture_pattern (identifier) (identifier)) (capture_pattern
(typed_pattern (wildcard) (type_identifier)))))) (identifier)
(identifier))
(typed_pattern
(wildcard)
(type_identifier))))))
(identifier)) (identifier))
(case_clause (case_clause
(infix_pattern (infix_pattern
(infix_pattern (infix_pattern
(capture_pattern (identifier) (capture_pattern
(tuple_pattern (alternative_pattern (identifier) (identifier)))) (identifier)
(operator_identifier) (identifier)) (operator_identifier) (string)) (tuple_pattern
(integer_literal)) (alternative_pattern
(identifier)
(identifier))))
(operator_identifier)
(identifier))
(operator_identifier)
(string))
(integer_literal))
(case_clause (case_clause
(case_class_pattern (case_class_pattern
(type_identifier) (type_identifier)
@ -182,25 +282,53 @@ val x = y match {
(wildcard)))) (wildcard))))
(identifier)))))) (identifier))))))
============================ ================================================================================
Quoted patterns (Scala 3 syntax) Quoted patterns (Scala 3 syntax)
============================ ================================================================================
def foo = def foo =
x match x match
case '{ $boolExpr } => Some(true) case '{ $boolExpr } => Some(true)
case _ => None case _ => None
--- --------------------------------------------------------------------------------
(compilation_unit (compilation_unit
(function_definition (identifier) (function_definition
(identifier)
(indented_block (indented_block
(match_expression (identifier) (match_expression
(indented_cases (identifier)
(case_clause (indented_cases
(quote_expression (identifier)) (case_clause
(call_expression (identifier) (arguments (boolean_literal)))) (quote_expression
(case_clause (wildcard) (identifier))
(identifier))))))) (call_expression
(identifier)
(arguments
(boolean_literal))))
(case_clause
(wildcard)
(identifier)))))))
================================================================================
Given pattern (Scala 3 syntax)
================================================================================
for
given Int <- Some(1)
yield ()
--------------------------------------------------------------------------------
(compilation_unit
(for_expression
(enumerators
(enumerator
(given_pattern
(type_identifier))
(call_expression
(identifier)
(arguments
(integer_literal)))))
(unit)))

@ -954,6 +954,7 @@ module.exports = grammar({
$.infix_pattern, $.infix_pattern,
$.alternative_pattern, $.alternative_pattern,
$.typed_pattern, $.typed_pattern,
$.given_pattern,
$.quote_expression, $.quote_expression,
$.literal, $.literal,
$.wildcard, $.wildcard,
@ -995,6 +996,8 @@ module.exports = grammar({
seq(field("pattern", $._pattern), ":", field("type", $._type)), seq(field("pattern", $._pattern), ":", field("type", $._type)),
), ),
given_pattern: $ => seq("given", field("type", $._type)),
// TODO: Flatten this. // TODO: Flatten this.
alternative_pattern: $ => prec.left(-1, seq($._pattern, "|", $._pattern)), alternative_pattern: $ => prec.left(-1, seq($._pattern, "|", $._pattern)),