From e5b15bf93f36a974500a470847e611e019205cde Mon Sep 17 00:00:00 2001
From: simon petit <nomisp96@hotmail.fr>
Date: Mon, 25 Nov 2024 10:36:33 +0000
Subject: [PATCH] Allowing cascading joins

---
 lib/ast.ml         | 19 +++++++++++--
 parser/parser.mly  | 70 +++++++++++++++++++++++-----------------------
 test/SQL_parser.ml | 46 ++++++++++++++++++++++++++++--
 3 files changed, 96 insertions(+), 39 deletions(-)

diff --git a/lib/ast.ml b/lib/ast.ml
index 3f37dc3..5debe2f 100644
--- a/lib/ast.ml
+++ b/lib/ast.ml
@@ -6,7 +6,7 @@ and column =
   | Column of string
 and table = 
   | Table of string
-  | Join of table * join_type * table
+  | Join of table * join_type * table * condition option
 and join_type =
   | Inner
   | Left
@@ -15,4 +15,19 @@ and join_type =
   | Cross
   | Union
   | Natural
-
+and condition = 
+  | Condition of string * comparison
+  | And of condition * condition
+  | Or of condition * condition
+  | Not of condition
+and comparison = 
+  | Comparison of operator * string
+and operator =
+  | Equals
+  | NotEquals
+  | LessThan
+  | GreaterThan
+  | LessEquals
+  | GreaterEquals
+and search_condition =
+  | Search of string
diff --git a/parser/parser.mly b/parser/parser.mly
index 2f2afc9..184c70b 100644
--- a/parser/parser.mly
+++ b/parser/parser.mly
@@ -53,7 +53,6 @@ table_reference :
 (*  | table_primary_or_joined_table sample_clause { $1 } *)
 
 table_primary_or_joined_table:
- (* | table_primary { Table($1) }*)
   | table_primary { $1 }
   | joined_table { $1 }
 
@@ -70,24 +69,24 @@ joined_table :
   | union_join { $1 }
 
 cross_join:
-  | table_reference CROSS JOIN table_primary { Join($1, Cross, $4) }
+  | table_reference CROSS JOIN table_primary { Join($1, Cross, $4, None) }
 
 qualified_join:
-  | table_reference JOIN table_reference join_specification { Join($1, Left, $3) }
-  | table_reference join_type JOIN table_reference join_specification { Join($1, $2, $4) }
+  | table_reference JOIN table_reference join_specification { Join($1, Left, $3, $4) }
+  | table_reference join_type JOIN table_reference join_specification { Join($1, $2, $4, $5) }
 
 join_specification:
-  | join_condition {}
+  | join_condition { $1 }
 
 join_condition:
-  | ON search_condition {}
+  | ON search_condition { Some($2) }
 
 natural_join:
-  | table_reference NATURAL JOIN table_primary { Join($1, Natural, $4) }
-  | table_reference NATURAL join_type JOIN table_primary { Join($1, Natural, $5) }
+  | table_reference NATURAL JOIN table_primary { Join($1, Natural, $4, None) }
+  | table_reference NATURAL join_type JOIN table_primary { Join($1, Natural, $5, None) }
 
 union_join:
-  | table_reference UNION JOIN table_primary { Join($1,Union, $4) }
+  | table_reference UNION JOIN table_primary { Join($1, Union, $4, None) }
 
 table_name :
   | IDENT { Table($1) }
@@ -107,43 +106,44 @@ where_clause :
   | WHERE search_condition { }
 
 search_condition:
-  | IDENT EQUALS_OPERATOR IDENT {}
+  (*| IDENT EQUALS_OPERATOR IDENT {}*)
+  | boolean_value_expression { $1 }
 
 boolean_value_expression:
-  | boolean_term {}
-  | boolean_value_expression OR boolean_term {}
+  | boolean_term { $1 }
+  | boolean_value_expression OR boolean_term { Or($1, $3) }
  
 boolean_term:
-  | boolean_factor {}
-  | boolean_term AND boolean_factor {}
+  | boolean_factor { $1 }
+  | boolean_term AND boolean_factor { And($1, $3) }
 
 boolean_factor:
-  | boolean_test {}
-  | NOT boolean_test {}
+  | boolean_test { $1 }
+  | NOT boolean_test { Not($2) }
 
 boolean_test:
-  | boolean_primary {}
+  | boolean_primary { $1 }
 
 boolean_primary :
-  | predicate {}
-  | boolean_predicand {}
+  | predicate { $1 }
+  (*| boolean_predicand {}*)
 
 predicate :
-  | comparison_predicate {}
+  | comparison_predicate { $1 }
 
 comparison_predicate :
-  | row_value_predicand comparison_predicate_part2 {}
+  | row_value_predicand comparison_predicate_part2 { Condition($1, $2) }
 
 comparison_predicate_part2:
-  | comp_op row_value_predicand {}
+  | comp_op row_value_predicand { Comparison($1, $2) }
 
 comp_op :
-  | EQUALS_OPERATOR {}
-  | not_equals_operator {}
-  | LESS_THAN_OPERATOR {}
-  | GREATER_THAN_OPERATOR {}
-  | less_than_or_equals_operator {}
-  | greater_than_or_equals_operator {}
+  | EQUALS_OPERATOR { Equals }
+  | not_equals_operator { NotEquals }
+  | LESS_THAN_OPERATOR { LessThan }
+  | GREATER_THAN_OPERATOR { GreaterThan }
+  | less_than_or_equals_operator { LessEquals }
+  | greater_than_or_equals_operator { GreaterEquals }
 
 not_equals_operator :
   | LESS_THAN_OPERATOR GREATER_THAN_OPERATOR {}
@@ -155,23 +155,23 @@ greater_than_or_equals_operator:
   | GREATER_THAN_OPERATOR EQUALS_OPERATOR {}
 
 row_value_predicand: 
-  | row_value_special_case {}
+  | row_value_special_case { $1 }
 
 row_value_special_case :
-  | nonparenthesized_value_expression_primary {}
+  | nonparenthesized_value_expression_primary { $1 }
 
 nonparenthesized_value_expression_primary:
-  | column_reference {}
+  | column_reference { $1 }
 
 column_reference:
-  | basic_identifier_chain {}
+  | basic_identifier_chain { $1 }
 
 basic_identifier_chain:
-  | identifier_chain {}
+  | identifier_chain { $1 }
 
 identifier_chain:
-  | IDENT {}
-  | identifier_chain DOT IDENT  {}
+  | IDENT { $1 }
+  (*| identifier_chain DOT IDENT  {}*)
 
 boolean_predicand:
   | nonparenthesized_value_expression_primary {}
diff --git a/test/SQL_parser.ml b/test/SQL_parser.ml
index a09f19c..eb1d4c3 100644
--- a/test/SQL_parser.ml
+++ b/test/SQL_parser.ml
@@ -7,5 +7,47 @@ let parse query =
 let () = 
   assert(parse "SELECT ab FROM b1"  = Query(Select([Column("ab")], [Table "b1"])));
   assert(parse "SELECT * FROM b1"  = Query(Select([Asterisk], [Table "b1"])));
-  assert(parse "SELECT * FROM t1 CROSS JOIN t2"  = Query(Select([Asterisk], [Join(Table("t1"), Cross, Table("t2"))])));
-  assert(parse "SELECT * FROM t1 JOIN t2 ON a = b"  = Query(Select([Asterisk], [Join(Table("t1"), Left, Table("t2"))])));
+  assert(parse "SELECT * FROM t1 CROSS JOIN t2"  = Query(Select([Asterisk], [Join(Table("t1"), Cross, Table("t2"), None)])));
+  assert(parse "SELECT * FROM t1 JOIN t2 ON a = b"  = Query(
+    Select([Asterisk], [
+      Join(
+        Table("t1"),
+        Left,
+        Table("t2"), 
+        Some(
+          Condition(
+            "a",
+            Comparison(Equals, "b")
+            )
+          )
+        )
+    ]
+      )
+    ));
+  assert(parse "SELECT * FROM t1 JOIN t2 ON a = b JOIN t3 ON c = d"  = Query(
+    Select([Asterisk], [
+      Join(
+        Join(
+          Table("t1"), 
+          Left, 
+          Table("t2"),
+          Some(
+            Condition(
+              "a",
+              Comparison(Equals, "b")
+            )
+          )
+        ),
+        Left,
+        Table("t3"),
+        Some(
+          Condition(
+            "c",
+            Comparison(Equals, "d")
+          )
+        )
+      )
+    ]
+      )
+    )
+  );