diff --git a/test/test_ast.ml b/test/test_ast.ml index e72ab4a..4e366f4 100644 --- a/test/test_ast.ml +++ b/test/test_ast.ml @@ -10,13 +10,208 @@ let equal_ast ast1 ast2 = let query_testable = Alcotest.testable Ast.pp_query equal_ast -let test_simple_select() = +let test_simple_select () = let q1 = parse "SELECT a FROM t" in - let ast1 = Query(Select([Column("a")], [Table("t")])) in - Alcotest.(check query_testable) "Ok" q1 ast1 + let ast1 = Query(Select([Column("a")], [Table("t")], None)) in + Alcotest.(check query_testable) "Ok" q1 ast1; + + let q2 = parse "SELECT * FROM t" in + let ast2 = Query(Select([Asterisk], [Table("t")], None)) in + Alcotest.(check query_testable) "Ok2" q2 ast2 + +let test_default_join () = + let q1 = parse "SELECT a FROM t1 JOIN t2 ON a = b" in + let ast1 = Query( + Select( + [Column("a")], + [Join( + Table("t1"), + Left, + Table("t2"), + Some( + Condition( + "a", + Comparison(Equals, "b") + ) + ) + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_left_join () = + let q1 = parse "SELECT a FROM t1 LEFT JOIN t2 ON a = b" in + let ast1 = Query( + Select([Column("a")], + [Join( + Table("t1"), + Left, + Table("t2"), + Some( + Condition( + "a", + Comparison(Equals, "b") + ) + ) + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_right_join () = + let q1 = parse "SELECT a FROM t1 RIGHT JOIN t2 ON a = b" in + let ast1 = Query( + Select([Column("a")], + [Join( + Table("t1"), + Right, + Table("t2"), + Some( + Condition( + "a", + Comparison(Equals, "b") + ) + ) + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_inner_join () = + let q1 = parse "SELECT a FROM t1 INNER JOIN t2 ON a = b" in + let ast1 = Query( + Select([Column("a")], + [Join( + Table("t1"), + Inner, + Table("t2"), + Some( + Condition( + "a", + Comparison(Equals, "b") + ) + ) + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_union_join () = + let q1 = parse "SELECT a FROM t1 UNION JOIN t2" in + let ast1 = Query( + Select([Column("a")], + [Join( + Table("t1"), + Union, + Table("t2"), + None + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_cross_join () = + let q1 = parse "SELECT a FROM t1 CROSS JOIN t2" in + let ast1 = Query( + Select([Column("a")], + [Join( + Table("t1"), + Cross, + Table("t2"), + None + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_natural_join () = + let q1 = parse "SELECT a FROM t1 NATURAL JOIN t2" in + let ast1 = Query( + Select([Column("a")], + [Join( + Table("t1"), + Natural, + Table("t2"), + None + )], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_join_join () = + let q1 = parse "SELECT a FROM t1 JOIN t2 ON a = b JOIN t3 ON a = c" in + let ast1 = Query( + Select([Column("a")], [ + Join( + Join( + Table("t1"), + Left, + Table("t2"), + Some( + Condition( + "a", + Comparison(Equals, "b") + ) + ) + ), + Left, + Table("t3"), + Some( + Condition( + "a", + Comparison(Equals, "c") + ) + ) + ) + ], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 + +let test_where_equals () = + let q1 = parse "SELECT a FROM t1 WHERE a = a OR a = b" in + let ast1 = Query( + Select( + [Column("a")], + [Table("t1")], + None + ) + ) in + Alcotest.(check query_testable) "Ok" q1 ast1 let simple_select_set = [ ("Equals", `Quick, test_simple_select) ] +let simple_join_set = [ + ("Default Join", `Quick, test_default_join); + ("Left Join", `Quick, test_left_join); + ("Right Join", `Quick, test_right_join); + ("Inner Join", `Quick, test_inner_join); + ("Union Join", `Quick, test_union_join); + ("Cross Join", `Quick, test_cross_join); + ("Natural Join", `Quick, test_natural_join) +] + +let multiple_joins_set = [ + ("Join Join", `Quick, test_join_join) +] + +let where_clauses_set = [ + ("Where Equals", `Quick, test_where_equals) +] + let () = Alcotest.run "Ast tests" - [ ("Simple Selects", simple_select_set) ] + [ + ("Simple Selects", simple_select_set); + ("Simple Joins", simple_join_set); + ("Multiple Joins", multiple_joins_set); + ("Where Clauses", where_clauses_set) + ]