(* Function binding *)

open Util
open Py_types
open Py_mtypes
open Py_builtins

(* [name_binder binder statement] returns a copy of [statement] in which
   every [PyName name] expression reached by the traversal is replaced by
   [binder name].

   What the traversal does NOT touch:
   - names carried as plain strings (the names of [Def] and [Class],
     parameter names, [Dotname] trailers, keyword names in [Argument2]);
   - the contents of [PyMutableList];
   - any expression form not listed in [bind_expr] (returned unchanged);
   - the three import statements (copied as they are).

   [Global] statements are replaced by [Empty] in the result.

   NOTE(review): the same [binder] is applied inside nested [Def], [Class]
   and [Lambda] bodies; whether that is the intended scoping cannot be
   told from this function alone. *)
let name_binder binder statement =
  let rec bind_stmt = function
    (* statements without sub-expressions *)
    | Empty -> Empty
    | Pass -> Pass
    | Break sr -> Break sr
    | Continue sr -> Continue sr
    | Raise0 sr -> Raise0 sr
    | Global _ -> Empty

    (* still need to fix import! *)
    | Import (sr, sll) -> Import (sr, sll)
    | ImportFrom (sr, sl1, sl2) -> ImportFrom (sr, sl1, sl2)
    | ImportAll (sr, sl) -> ImportAll (sr, sl)

    (* compound statements *)
    | Suite ss -> Suite (List.map bind_stmt ss)
    | While (sr, e1, s1, s2) ->
      While (sr, bind_expr e1, bind_stmt s1, bind_stmt s2)
    | For (sr, e1, e2, s1, s2) ->
      For (sr, bind_expr e1, bind_expr e2, bind_stmt s1, bind_stmt s2)
    | If (css, else_part) ->
      If (
        List.map (fun (sr, e, s) -> (sr, bind_expr e, bind_stmt s)) css,
        bind_stmt else_part
      )
    | Def (sr, name, ps, s1) ->
      Def (sr, name, bind_parameters ps, bind_stmt s1)
    | Class (sr, name, bases, body) ->
      Class (sr, name, bind_expr bases, bind_stmt body)
    | TryFinally (body, cleanup) ->
      TryFinally (bind_stmt body, bind_stmt cleanup)
    | TryElse (body, handlers, else_part) ->
      TryElse (
        bind_stmt body,
        List.map bind_handler handlers,
        bind_stmt else_part
      )

    (* simple statements carrying expressions *)
    | Return (sr, e1) -> Return (sr, bind_expr e1)
    | Raise1 (sr, e1) -> Raise1 (sr, bind_expr e1)
    | Raise2 (sr, e1, e2) -> Raise2 (sr, bind_expr e1, bind_expr e2)
    | Raise3 (sr, e1, e2, e3) ->
      Raise3 (sr, bind_expr e1, bind_expr e2, bind_expr e3)
    | Print (sr, es) -> Print (sr, List.map bind_expr es)
    | PrintComma (sr, es) -> PrintComma (sr, List.map bind_expr es)
    | Exec1 (sr, e) -> Exec1 (sr, bind_expr e)
    | Exec2 (sr, e1, e2) -> Exec2 (sr, bind_expr e1, bind_expr e2)
    | Exec3 (sr, e1, e2, e3) ->
      Exec3 (sr, bind_expr e1, bind_expr e2, bind_expr e3)
    | Assert1 (sr, e1) -> Assert1 (sr, bind_expr e1)
    | Assert2 (sr, e1, e2) -> Assert2 (sr, bind_expr e1, bind_expr e2)
    | Expr (sr, e) -> Expr (sr, bind_expr e)
    | Del (sr, e) -> Del (sr, bind_expr e)

    (* assignment family: both sides go through the binder *)
    | Assign (sr, es) -> Assign (sr, List.map bind_expr es)
    | ColonEqual (sr, v, e) -> ColonEqual (sr, bind_expr v, bind_expr e)
    | PlusEqual (sr, v, e) -> PlusEqual (sr, bind_expr v, bind_expr e)
    | MinusEqual (sr, v, e) -> MinusEqual (sr, bind_expr v, bind_expr e)
    | StarEqual (sr, v, e) -> StarEqual (sr, bind_expr v, bind_expr e)
    | SlashEqual (sr, v, e) -> SlashEqual (sr, bind_expr v, bind_expr e)
    | PercentEqual (sr, v, e) ->
      PercentEqual (sr, bind_expr v, bind_expr e)
    | AmperEqual (sr, v, e) -> AmperEqual (sr, bind_expr v, bind_expr e)
    | VbarEqual (sr, v, e) -> VbarEqual (sr, bind_expr v, bind_expr e)
    | CaretEqual (sr, v, e) -> CaretEqual (sr, bind_expr v, bind_expr e)
    | LeftShiftEqual (sr, v, e) ->
      LeftShiftEqual (sr, bind_expr v, bind_expr e)
    | RightShiftEqual (sr, v, e) ->
      RightShiftEqual (sr, bind_expr v, bind_expr e)
    | PlusPlus (sr, v) -> PlusPlus (sr, bind_expr v)
    | MinusMinus (sr, v) -> MinusMinus (sr, bind_expr v)

  and bind_dictent (DictEnt (e1, e2)) =
    DictEnt (bind_expr e1, bind_expr e2)

  and bind_argument = function
    | Argument1 e -> Argument1 (bind_expr e)
    | Argument2 (name, e) -> Argument2 (name, bind_expr e)

  (* only the default value of a parameter is rewritten *)
  and bind_parameter = function
    | Parameter1 pname -> Parameter1 pname
    | Parameter2 (pname, dflt) -> Parameter2 (pname, bind_expr dflt)

  and bind_parameters (ps, star_param, starstar_param) =
    (List.map bind_parameter ps, star_param, starstar_param)

  and bind_subscript_entry = function
    | Defsub -> Defsub
    | Pos e -> Pos (bind_expr e)

  and bind_subscript = function
    | Ellipsis -> Ellipsis
    | Subscript0 se -> Subscript0 (bind_subscript_entry se)
    | Subscript1 (se1, se2) ->
      Subscript1 (bind_subscript_entry se1, bind_subscript_entry se2)
    | Subscript2 (se1, se2, se3) ->
      Subscript2 (
        bind_subscript_entry se1,
        bind_subscript_entry se2,
        bind_subscript_entry se3
      )

  and bind_except = function
    | Except0 -> Except0
    | Except1 (sr, e) -> Except1 (sr, bind_expr e)
    | Except2 (sr, e1, e2) -> Except2 (sr, bind_expr e1, bind_expr e2)

  and bind_handler (ex, s) = (bind_except ex, bind_stmt s)

  and bind_trailer = function
    | Arglist als -> Arglist (List.map bind_argument als)
    | Dotname name -> Dotname name
    | Sublist subs -> Sublist (List.map bind_subscript subs)

  and bind_comparator = function
    | Less e -> Less (bind_expr e)
    | Greater e -> Greater (bind_expr e)
    | LessEqual e -> LessEqual (bind_expr e)
    | GreaterEqual e -> GreaterEqual (bind_expr e)
    | Equal e -> Equal (bind_expr e)
    | NotEqual e -> NotEqual (bind_expr e)
    | Is e -> Is (bind_expr e)
    | IsNot e -> IsNot (bind_expr e)
    | In e -> In (bind_expr e)
    | NotIn e -> NotIn (bind_expr e)

  and bind_binop = function
    | Add e -> Add (bind_expr e)
    | Sub e -> Sub (bind_expr e)
    | Mul e -> Mul (bind_expr e)
    | Div e -> Div (bind_expr e)
    | Mod e -> Mod (bind_expr e)
    | Asl e -> Asl (bind_expr e)
    | Lsr e -> Lsr (bind_expr e)
    | Pow e -> Pow (bind_expr e)

  and bind_expr = function
    (* the one place where the caller-supplied binder is applied *)
    | PyName name -> binder name
    | PyTuple es -> PyTuple (List.map bind_expr es)
    | PyList es -> PyList (List.map bind_expr es)
    | PyMutableList l -> PyMutableList l
    | PyDict des -> PyDict (List.map bind_dictent des)
    | PyRepr e -> PyRepr (bind_expr e)
    | Or es -> Or (List.map bind_expr es)
    | And es -> And (List.map bind_expr es)
    | Not e -> Not (bind_expr e)
    | Neg e -> Neg (bind_expr e)
    | Complement e -> Complement (bind_expr e)
    | BitOr es -> BitOr (List.map bind_expr es)
    | BitAnd es -> BitAnd (List.map bind_expr es)
    | BitXor es -> BitXor (List.map bind_expr es)
    | Compare (e, cs) -> Compare (bind_expr e, List.map bind_comparator cs)
    | Eval (e, binops) -> Eval (bind_expr e, List.map bind_binop binops)
    | AtomWithTrailers (e, tls) ->
      AtomWithTrailers (bind_expr e, List.map bind_trailer tls)
    | Lambda (ps, e) -> Lambda (bind_parameters ps, bind_expr e)
    (* every other expression form is returned untouched *)
    | other -> other
  in
  bind_stmt statement
;;

(* [bind_function f gn2i] builds a new function object whose body has its
   variable names resolved to indices, and returns it wrapped in
   [PyFunction].

   Steps:
   1. Collect the names that are bound by statements in the body of [f]:
      targets of [Assign] (all expressions but the last), the target of
      [For], and the names of [Def] and [Class]. Nested [Def] and [Class]
      bodies are not entered. The bodies of [While], [For], [If],
      [TryFinally], [TryElse] (including its handlers) and [Suite] are
      searched.
   2. The local set is those names minus the declared globals of [f],
      plus every parameter name (including tuple parameters and the star
      and star-star parameters).
   3. Each local not already present in the name-to-index map of [f] gets
      the next free index, counting on from the length of the existing
      index-to-name array.
   4. The body is rewritten with [name_binder]: a name found in the local
      map becomes [PyVarIndex (0, i)], otherwise a name found in [gn2i]
      becomes [PyVarIndex (1, i)], otherwise it stays a [PyName].

   [gn2i] is a [VarMap] from global names to indices.

   The function prints a trace of every step on standard output.

   NOTE(review): augmented assignments, imports and the second expression
   of [Except2] are not treated as binding occurrences here; confirm
   whether they should be. *)
let bind_function (f:function_t) gn2i =
  print_endline ("BINDING FUNCTION ");

  print_endline "Parameters are";
  let parameters = f#get_parameters in
  Py_print.print_parameters 0 parameters;

  print_endline "Declared globals are";
  let declared_globals = f#get_global_names in
  print_endline (string_of_list "" " " "" (fun x -> x) declared_globals);

  print_endline "BODY is ";
  let body = f#get_code in
  Py_print.print_statement 0 body;

  (* names occurring as binding targets in the body *)
  let lhss = ref StringSet.empty in
  let add_lhs name = lhss := StringSet.add name !lhss in

  (* record the plain names inside an assignment target; tuple and list
     targets are unpacked, anything else (attributes, subscripts, ...)
     binds no local name *)
  let rec find_vars = function
    | PyName name -> add_lhs name
    | PyTuple ts | PyList ts -> List.iter find_vars ts
    | _ -> ()
  in

  (* walk the statements of this function, without entering nested
     function or class bodies *)
  let rec find_variables = function
    | Empty -> ()
    | While (_, _, s1, s2) ->
      find_variables s1;
      find_variables s2
    | For (_, e1, _, s1, s2) ->
      find_vars e1;
      find_variables s1;
      find_variables s2
    | Class (_, name, _, _) -> add_lhs name
    | Def (_, name, _, _) -> add_lhs name
    | TryFinally (s1, s2) ->
      find_variables s1;
      find_variables s2
    | TryElse (s1, handlers, s2) ->
      find_variables s1;
      List.iter (fun (_, s) -> find_variables s) handlers;
      find_variables s2
    | Suite ts -> List.iter find_variables ts
    | If (cs, s1) ->
      List.iter (fun (_, _, s) -> find_variables s) cs;
      find_variables s1
    | Assign (_, es) ->
      (* the last expression is the value; the others are targets *)
      List.iter find_vars (list_all_but_last es)
    | _ -> ()
  in
  find_variables body;

  print_endline "LHS VARIABLES ARE";
  StringSet.iter (fun x -> print_string (x ^ " ")) !lhss;
  print_endline "";

  (* locals = bound names that were not declared global, plus parameters *)
  let locals = ref StringSet.empty in
  let add_local name = locals := StringSet.add name !locals in
  StringSet.iter
    (fun x -> if not (List.mem x declared_globals) then add_local x)
    !lhss;

  let (params, stars, starstars) = parameters in
  let rec add_param_names = function
    | Param name -> add_local name
    | Paramtuple ls -> List.iter add_param_names ls
  in
  List.iter
    (function
      | Parameter1 p -> add_param_names p
      | Parameter2 (p, _) -> add_param_names p)
    params;
  begin match stars with
  | StarParam name -> add_local name
  | NoStarParam -> ()
  end;
  begin match starstars with
  | StarStarParam name -> add_local name
  | NoStarStarParam -> ()
  end;

  print_endline "LOCAL VARIABLES ARE";
  StringSet.iter (fun x -> print_string (x ^ " ")) !locals;
  print_endline "";

  (* now, bind the local variables: extend the existing name-to-index map,
     numbering new names from the end of the existing index-to-name array *)
  let n2i = ref f#get_local_name_to_index in
  let counter = ref (Array.length f#get_index_to_local_name) in
  StringSet.iter
    (fun name ->
      if not (VarMap.mem name !n2i) then begin
        n2i := VarMap.add name !counter !n2i;
        incr counter
      end)
    !locals;

  VarMap.iter
    (fun k v -> print_endline (k ^ " -> " ^ (string_of_int v)))
    !n2i;

  (* locals shadow globals; names known to neither map stay symbolic.
     Membership is tested explicitly so that no unrelated exception can
     be swallowed. *)
  let local_binder name =
    if VarMap.mem name !n2i then PyVarIndex (0, VarMap.find name !n2i)
    else if VarMap.mem name gn2i then PyVarIndex (1, VarMap.find name gn2i)
    else PyName name
  in
  let body' = name_binder local_binder body in
  Py_print.print_statement 0 body';
  flush stdout;

  (* rebuild the index-to-name array from the extended map *)
  let new_i2n = Array.make !counter "" in
  VarMap.iter (fun name index -> new_i2n.(index) <- name) !n2i;
  let f' = new Py_function.py_function
    f#get_name
    parameters
    body'
    declared_globals
    new_i2n
    !n2i
    f#get_environment
  in
  flush stdout;
  PyFunction (f':>function_t)
;;


(* [bind_module m] assigns an index to every string key of the dictionary
   of [m] that does not yet have one, prints the resulting tables, and then
   replaces each [PyFunction] stored in the dictionary by the result of
   [bind_function] applied to it and the extended name-to-index map.
   Returns [PyNone].

   Raises [Failure] if a name present in the index table has no entry in
   the module dictionary.

   NOTE(review): the extended name-to-index map and index-to-name array are
   only used locally; nothing in this function stores them back into [m].
   Confirm whether the caller or [m] is expected to pick them up
   elsewhere. *)
let bind_module (m:module_t) =

  (* pass 1: make module index *)
  let d = m#get_dictionary in
  let n2i (* map *) = ref m#get_global_name_to_index in
  (* new indices continue after the existing index-to-name array *)
  let counter = ref (Array.length m#get_index_to_global_name) in

  (* extend the name to index map *)
  d#iter
    begin fun k _ ->
      match k with
      | PyString name ->
        if not (VarMap.mem name !n2i)
        then begin
          n2i := VarMap.add name !counter !n2i;
          incr counter
        end
      | _ -> print_endline "WARNING: module dictionary contains non-string key"
    end;

  (* create a new array of names *)
  let new_i2n = Array.make !counter "" in
  VarMap.iter (fun name index -> new_i2n.(index) <- name) !n2i;

  (* print the array *)
  print_endline "MODULE INDEX TO NAME MAP";
  Array.iteri
    (fun i name -> print_endline ((string_of_int i) ^ ": " ^ name))
    new_i2n;
  print_endline "";

  (* print the map *)
  print_endline "MODULE NAME TO INDEX MAP";
  VarMap.iter
    (fun name index -> print_endline (name ^ " -> " ^ (string_of_int index)))
    !n2i;
  flush stdout;

  (* pass 2: bind all the functions now; other values are left alone *)
  for i = 0 to Array.length new_i2n - 1 do
    let name = new_i2n.(i) in
    match d#get_item (PyString name) with
    | Some (PyFunction f) ->
      let bound_f = bind_function f !n2i in
      ignore (d#set_item (PyString name) bound_f)
    | Some _ -> ()
    | None -> raise (Failure "Cannot find item in dictionary (bind module)")
  done;
  PyNone
;;

