(* type inference *)

(*
  Suppose a function, such as Python's +, is called with
  some values: what is the type of the result?

  Let us assume that the result type is determined by the
  argument types: this is not strictly the case for Python
  functions, since Python is dynamically typed.

  We solve this problem by introducing a coproduct type
  combinator, denoting one of a set of types.

  Now, given an overload set:

    int x int -> int
    float x float -> float
    float x int -> float
    string x string -> string,

  for example, and given two argument types, we can fix the
  return type. But if we only know that each argument is one of
  a set of types, we can only establish a set of possible return
  types, each paired with the argument types required to produce it.


  For example, if each argument may be an int or a float, we have

  int x int -> int
  int x float -> error
  float x float -> float
  float x int -> float
 
  Now, suppose the first argument may be a string or an int, but the
  second is known to be a string; we see that

  string x string -> string

  is the only possible binding, so we have learned not only the
  return type, but also the type of the first argument.

*)

(* [third (a, b, c)] is [c]. Used to pick the result-type component out of
   the triple returned by [unzip_binop_list]. *)
let third (_, _, last) = last;;

(* The primitive types the inference distinguishes. *)
type primitive_t = Tint | Tfloat | Tstring | Tlist;;
(* One overload of a binary operator: ((left operand, right operand), result). *)
type binop_t = (primitive_t * primitive_t) * primitive_t;;
(* Presumably one overload of a unary operator: (operand, result).
   NOTE(review): not used anywhere in the visible code. *)
type unop_t = primitive_t * primitive_t;;

(* Overload table for [+], modelled on Python's [+] (see header comment).
   Each entry is ((left operand, right operand), result): numeric addition
   with int/float promotion, plus string and list concatenation. *)
let add_types = [
  ((Tstring, Tstring) , Tstring);
  ((Tlist, Tlist) , Tlist);
  ((Tfloat, Tfloat) , Tfloat);
  ((Tint, Tint) , Tint);
  ((Tfloat, Tint) , Tfloat);
  ((Tint, Tfloat) , Tfloat)
];;

(* Overload table for [-]: numeric only, with int/float promotion.
   Each entry is ((left operand, right operand), result). *)
let sub_types = [
  ((Tfloat, Tfloat) , Tfloat);
  ((Tint, Tint) , Tint);
  ((Tfloat, Tint) , Tfloat);
  ((Tint, Tfloat) , Tfloat)
];;

(* Overload table for [*]. Each entry is ((left operand, right operand), result).
   Besides numeric multiplication with int/float promotion, an int times a
   string or list (in either order) yields a string or list respectively,
   as in Python's sequence repetition. *)
let mul_types = [
  ((Tfloat, Tfloat) , Tfloat);
  ((Tint, Tint) , Tint);
  ((Tfloat, Tint) , Tfloat);
  ((Tint, Tfloat) , Tfloat);
  ((Tint, Tstring) , Tstring);
  ((Tstring, Tint) , Tstring);
  ((Tint, Tlist) , Tlist);
  ((Tlist, Tint) , Tlist)
];;

(* Overload table for [/]: numeric only, with int/float promotion.
   Each entry is ((left operand, right operand), result).
   NOTE(review): int / int -> int matches Python 2's integer division;
   Python 3's [/] always returns float — confirm which Python is modelled. *)
let div_types = [
  ((Tfloat, Tfloat) , Tfloat);
  ((Tint, Tint) , Tint);
  ((Tfloat, Tint) , Tfloat);
  ((Tint, Tfloat) , Tfloat);
];;


(* Keep the overloads in [fun_types] whose left operand, right operand and
   result are members of [arg1_types], [arg2_types] and [res_types]
   respectively. The survivors come back in the reverse of their order in
   [fun_types]. *)
let unify_binop
  (fun_types:binop_t list)
  (arg1_types:primitive_t list)
  (arg2_types:primitive_t list)
  (res_types:primitive_t list)
: binop_t list
=
  let admissible ((lhs, rhs), result) =
    List.mem lhs arg1_types
    && List.mem rhs arg2_types
    && List.mem result res_types
  in
  (* Consing onto the accumulator while folding left reverses the order. *)
  List.fold_left
    (fun kept entry -> if admissible entry then entry :: kept else kept)
    []
    fun_types
;;

(* Print the lowercase name of a primitive type to stdout, without a newline. *)
let print_primitive x =
  let name =
    match x with
    | Tint -> "int"
    | Tfloat -> "float"
    | Tstring -> "string"
    | Tlist -> "list"
  in
  print_string name
;;

(* Print each type in the list followed by ", " (including after the last
   one); prints nothing for an empty list. *)
let print_primitive_list x =
  let print_item t =
    print_primitive t;
    print_string ", "
  in
  List.iter print_item x
;;

(* Print one overload as "lhs x rhs -> result", without a newline. *)
let print_binop ((lhs, rhs), result) =
  print_primitive lhs;
  print_string " x ";
  print_primitive rhs;
  print_string " -> ";
  print_primitive result
;;

(* Print each overload on its own line. *)
let print_binop_list bs =
  let print_line b = (print_binop b; print_endline "") in
  List.iter print_line bs
;;

(* Cons [v] onto [l] unless it is already a member (structural equality). *)
let add_unique l v =
  match List.mem v l with
  | true -> l
  | false -> v :: l
;;

(* Elements of [a] that also occur in [b], in [a]'s order; duplicates in
   [a] are kept. *)
let list_intersect a b =
  let in_b x = List.mem x b in
  List.filter in_b a
;;

(* Split a list of overloads into the distinct left-operand types, the
   distinct right-operand types and the distinct result types. Each list
   holds every type once, most recently first-seen at the head. *)
let unzip_binop_list
  (bs:binop_t list)
: (primitive_t list * primitive_t list * primitive_t list)
=
  List.fold_left
    (fun (lhs_seen, rhs_seen, res_seen) ((lhs, rhs), result) ->
       ( add_unique lhs_seen lhs,
         add_unique rhs_seen rhs,
         add_unique res_seen result ))
    ([], [], [])
    bs
;;

(* Untyped arithmetic expressions over named variables. *)
type expr_t = 
  | Var of string 
  | Add of expr_t * expr_t | Mul of expr_t * expr_t
  | Sub of expr_t * expr_t | Div of expr_t * expr_t
;;

(* Print an untyped expression, parenthesising every binary operator. *)
let rec print_expr e =
  let binary lhs op rhs =
    print_string "(";
    print_expr lhs;
    print_string op;
    print_expr rhs;
    print_string ")"
  in
  match e with
  | Var name -> print_string name
  | Add (lhs, rhs) -> binary lhs " + " rhs
  | Mul (lhs, rhs) -> binary lhs " * " rhs
  | Sub (lhs, rhs) -> binary lhs " - " rhs
  | Div (lhs, rhs) -> binary lhs " / " rhs
;;

(* Expressions annotated for inference: every operator node carries the
   overloads of that operator still considered possible for it. Variables
   carry no annotation; their candidate types live in a separate env. *)
type expr_t' = 
  | Var' of string 
  | Add' of expr_t' * expr_t' * binop_t list
  | Mul' of expr_t' * expr_t' * binop_t list 
  | Sub' of expr_t' * expr_t'  * binop_t list
  | Div' of expr_t' * expr_t' * binop_t list
;;

(* Print an annotated expression. Each variable is followed by ":" and its
   current candidate types from [env]. Each operator node is followed by ":"
   and the result types of its remaining overloads.
   Raises [Not_found] if a variable is not bound in [env]. *)
let rec print_expr' env e =
  let show_result_types ts =
    print_primitive_list (third (unzip_binop_list ts))
  in
  let binary lhs op rhs ts =
    print_string "(";
    print_expr' env lhs;
    print_string op;
    print_expr' env rhs;
    print_string "):";
    show_result_types ts
  in
  match e with
  | Var' name ->
    print_string name;
    print_string ":";
    print_primitive_list (Hashtbl.find env name)
  | Add' (lhs, rhs, ts) -> binary lhs " + " rhs ts
  | Mul' (lhs, rhs, ts) -> binary lhs " * " rhs ts
  | Sub' (lhs, rhs, ts) -> binary lhs " - " rhs ts
  | Div' (lhs, rhs, ts) -> binary lhs " / " rhs ts
;;

(* Convert an untyped expression into an annotated one, attaching to every
   operator node that operator's full overload table. This is the starting
   point for [unify]. [env] is only passed along to the recursive calls and
   is never consulted. *)
let rec prime_expr env e =
  match e with
  | Var name -> Var' name
  | Add (lhs, rhs) -> Add' (prime_expr env lhs, prime_expr env rhs, add_types)
  | Sub (lhs, rhs) -> Sub' (prime_expr env lhs, prime_expr env rhs, sub_types)
  | Mul (lhs, rhs) -> Mul' (prime_expr env lhs, prime_expr env rhs, mul_types)
  | Div (lhs, rhs) -> Div' (prime_expr env lhs, prime_expr env rhs, div_types)
;;

(* Print every binding of the variable table as
   "name -> t1, t2, ... , " followed by a newline, in [Hashtbl.iter] order. *)
let print_vars vs =
  let print_binding name types =
    print_string name;
    print_string " -> ";
    print_primitive_list types;
    print_endline ""
  in
  Hashtbl.iter print_binding vs
;;
  

(* Current candidate types of an annotated expression. A variable's types
   come from [env]; an operator node's types are the distinct result types
   of its remaining overloads.
   Raises [Not_found] if a variable is not bound in [env]. *)
let get_types env expr =
  match expr with
  | Var' name -> Hashtbl.find env name
  | Mul' (_, _, ts) | Add' (_, _, ts) | Sub' (_, _, ts) | Div' (_, _, ts) ->
    third (unzip_binop_list ts)
;;

(* One top-down narrowing pass over an annotated expression.

   [restrict] is the set of types the enclosing context allows for [expr].
   - For a variable, its entry in [env] is narrowed in place to the types
     that also appear in [restrict].
   - For an operator node, its overloads are filtered against the current
     types of both operands and against [restrict]. The surviving overloads
     then bound each operand, which is unified recursively.

   Returns the expression rebuilt with the surviving overloads. A single
   pass need not reach a fixpoint, so the driver re-runs it until the tree
   stops changing.
   Raises [Not_found] if a variable is not bound in [env]. *)
let rec unify env expr restrict =
  (* Shared logic for every binary-operator node. [rebuild] re-wraps the
     unified operands and surviving overloads in the node's constructor. *)
  let unify_operator x y ts rebuild =
    let tx = get_types env x in
    let ty = get_types env y in
    let u = unify_binop ts tx ty restrict in
    let (lhs_types, rhs_types, _) = unzip_binop_list u in
    (* Sequential lets fix left-before-right order. Both calls may narrow
       the same variables in [env], and [let ... and ...] leaves the
       evaluation order unspecified. *)
    let x' = unify env x lhs_types in
    let y' = unify env y rhs_types in
    rebuild x' y' u
  in
  match expr with
  | Var' name ->
    let ts' = list_intersect restrict (Hashtbl.find env name) in
    Hashtbl.replace env name ts';
    Var' name
  | Mul' (x, y, ts) -> unify_operator x y ts (fun x' y' u -> Mul' (x', y', u))
  | Add' (x, y, ts) -> unify_operator x y ts (fun x' y' u -> Add' (x', y', u))
  | Sub' (x, y, ts) -> unify_operator x y ts (fun x' y' u -> Sub' (x', y', u))
  | Div' (x, y, ts) -> unify_operator x y ts (fun x' y' u -> Div' (x', y', u))
;;


(* Global variable environment: maps each variable name to the list of
   types it may still have. [unify] narrows these entries in place. *)
let vars = Hashtbl.create 97;;

(* Initial candidate types for every variable.
   NOTE(review): despite the name, this omits [Tlist]; confirm whether
   list-typed variables should be allowed. *)
let any = [Tint; Tfloat; Tstring];;

let () =
  List.iter
    (fun (name, types) -> Hashtbl.add vars name types)
    [ ("x", any); ("y", any); ("z", any) ]
;;

(* Sample expression ((x - y) * z) + x. [x] occurs twice, so the constraints
   from both occurrences meet in the shared variable environment. *)
let expr_1 = Add (Mul (Sub (Var "x", Var "y"), Var "z"), Var "x");;

(* Sample expression x * x. *)
let expr_2 = Mul (Var "x", Var "x");;

(* Expressions still to be processed by the driver loop. The elements must
   be separated by [;]: with [,] this would be a one-element list holding
   the tuple (expr_1, expr_2), which does not type-check where a single
   expression is expected. *)
let samples = ref [expr_1; expr_2];;

(* Driver. For each sample expression:
   - print the variable environment and the expression;
   - repeatedly run [unify] until the annotated tree stops changing,
     printing the tree and environment after every pass.
   Each pass can only drop overloads (unify_binop filters a node's current
   list), so the inner loop terminates.
   NOTE(review): [vars] is shared across samples, so narrowing done while
   unifying one sample carries over to the next; confirm this is intended. *)
let () =
  while !samples <> [] do
    let sample = List.hd !samples in

    print_endline "Variables";
    print_vars vars;
    print_endline "";

    print_endline "Example";
    print_expr sample;
    print_endline "";

    print_endline "Unification";
    let u1 = ref (prime_expr vars sample) in
    let u2 = ref (unify vars !u1 any) in
    print_expr' vars !u2;
    print_endline "";

    print_endline "Variables";
    print_vars vars;
    print_endline "";

    (* Passes are numbered from 2 because the first pass ran above. *)
    let counter = ref 2 in
    while !u1 <> !u2 do
      u1 := !u2;

      print_endline ("Unification " ^ string_of_int !counter);
      u2 := unify vars !u1 any;
      print_expr' vars !u2;
      print_endline "";

      print_endline "Variables";
      print_vars vars;
      print_endline "";
      incr counter
    done;

    samples := List.tl !samples
  done
;;

