open Py_types
open Py_exceptions
open Util
open Big_int
open Num

exception Found
exception NotFound

(* 
  RAW DATUM OPERATIONS
  --------------------

  These operations apply directly to terms representing
    evaluated native objects.

  They do NOT take into account special class methods
    when applied to instance objects.

  They do NOT take into account special/pseudo attributes.

  They do NOT work on unevaluated terms (like PyList).
*)

(* Name of the native type of an evaluated datum.
   Every term without a dedicated entry below is reported as
   "PyExpressionType". *)
let get_typename = function
  (* singletons and interpreter internals *)
  | PyNone -> "PyNoneType"
  | PyTraceback _ -> "PyTracebackType"
  | PyEnv _ -> "PyEnvironmentType"

  (* numbers *)
  | PyInt _ -> "PyIntType"
  | PyLong _ -> "PyLongType"
  | PyRational _ -> "PyRationalType"
  | PyFloat _ -> "PyFloatType"
  | PyComplex _ -> "PyComplexType"

  (* containers *)
  | PyString _ -> "PyStringType"
  | PyMutableList _ -> "PyListType"
  | PyTuple _ -> "PyTupleType"
  | PyModule _ -> "PyModuleType"
  | PyDictionary _ -> "PyDictionaryType"
  | IntRange _ -> "PyXrangeType"
  | PyFile _ -> "PyFileType"

  (* callables and objects *)
  | PyBoundMethod _ -> "PyBoundMethodType"
  | PyClass _ -> "PyClassType"
  | PyInstance _ -> "PyInstanceType"
  | PyFunction _ -> "PyFunctionType"
  | PyNativeFunction _ -> "PyNativeFunctionType"

  (* miscellaneous *)
  | PyRegexp _ -> "PyRegexpType"
  | PyWidget _ -> "PyWidgetType"

  (* everything else *)
  | _ -> "PyExpressionType"
;;

(* Boolean operations *)

(* Truth value of an evaluated datum.
   False: PyNone, PyInitial, every numeric zero (int, long, rational,
   float, complex), the empty string, and any empty tuple, list,
   dictionary or integer range.  Every other value is true. *)
let py_istrue x = match x with
  | PyNone 
  | PyInitial
  | PyInt 0 
  | PyFloat 0.0 
  | PyString ""
  | PyComplex (0.0, 0.0) 
  | PyTuple []
      -> false
  (* big numbers cannot be matched against a literal, so compare with
     zero explicitly: the value is true when it is NOT zero *)
  | PyLong n -> not (eq_big_int n zero_big_int)
  | PyRational q -> not (eq_num q (num_of_int 0))
  | PyMutableList l -> Varray.length l <> 0
  | PyDictionary d -> d#len <> 0
  (* a range is non-empty when start lies before stop in the direction
     of travel; multiplying both sides by step covers both signs *)
  | IntRange (start, stop, step) -> start * step < stop * step
  | _ -> true

(* Boolean connectives on already evaluated operands.
   WARNING: py_or and py_and receive both arguments evaluated, so they
   are NOT short-circuit operations; they only select which operand is
   the result:
     py_or  x y  is x when x is true,  otherwise y
     py_and x y  is x when x is false, otherwise y
   py_not yields PyInt 1 for a false operand and PyInt 0 for a true one.
*)
let py_or x y = if py_istrue x then x else y
and py_and x y = if py_istrue x then y else x
and py_not x = PyInt (if py_istrue x then 0 else 1)

(* Internal numeric coercion.
   Given a pair of values, promote the lower-ranked number so that both
   members share one numeric type, following the ladder
     int -> long -> rational -> float -> complex.
   Pairs whose members already agree, and pairs with a non-numeric
   member, are handed back unchanged.
   WARNING: this is not the same as the builtin function coerce:
   values that cannot be coerced are simply left alone.
*)
let py_coerce (x,y) =
  let unchanged = (x, y) in (* NOTE: NO ERROR HERE *)
  match x with
  | PyInt xi ->
    begin match y with
    | PyLong _ -> (PyLong (big_int_of_int xi), y)
    | PyRational _ -> (PyRational (num_of_int xi), y)
    | PyFloat _ -> (PyFloat (float_of_int xi), y)
    | PyComplex _ -> (PyComplex (float_of_int xi, 0.0), y)
    | _ -> unchanged
    end

  | PyLong xl ->
    begin match y with
    | PyInt yi -> (x, PyLong (big_int_of_int yi))
    | PyRational _ -> (PyRational (num_of_big_int xl), y)
    | PyFloat _ -> (PyFloat (float_of_big_int xl), y)
    | PyComplex _ -> (PyComplex (float_of_big_int xl, 0.0), y)
    | _ -> unchanged
    end

  | PyRational xq ->
    begin match y with
    | PyInt yi -> (x, PyRational (num_of_int yi))
    | PyLong yl -> (x, PyRational (num_of_big_int yl))
    | PyFloat _ -> (PyFloat (float_of_num xq), y)
    | PyComplex _ -> (PyComplex (float_of_num xq, 0.0), y)
    | _ -> unchanged
    end

  | PyFloat xf ->
    begin match y with
    | PyInt yi -> (x, PyFloat (float_of_int yi))
    | PyLong yl -> (x, PyFloat (float_of_big_int yl))
    | PyRational yq -> (x, PyFloat (float_of_num yq))
    | PyComplex _ -> (PyComplex (xf, 0.0), y)
    | _ -> unchanged
    end

  | PyComplex _ ->
    begin match y with
    | PyInt yi -> (x, PyComplex (float_of_int yi, 0.0))
    | PyLong yl -> (x, PyComplex (float_of_big_int yl, 0.0))
    | PyRational yq -> (x, PyComplex (float_of_num yq, 0.0))
    | PyFloat yf -> (x, PyComplex (yf, 0.0))
    | _ -> unchanged
    end

  | _ -> unchanged

(* Comparisons *)

(* Lexicographic strict ordering of two lists under the element
   ordering cmp.  The first position where one element is less than the
   other decides; when one list is a proper prefix of the other, the
   shorter list is the smaller one.  Equal lists are not less. *)
let rec list_less (cmp:'a->'a->bool) (x:'a list) (y:'a list) : bool =
  match x, y with
  | _, [] -> false          (* nothing is less than the empty list *)
  | [], _ :: _ -> true      (* proper prefix *)
  | a :: rest_x, b :: rest_y ->
    (* a < b decides at once; b < a decides the other way;
       otherwise the heads tie and the tails decide *)
    cmp a b || (not (cmp b a) && list_less cmp rest_x rest_y)

(* Strict ordering of two dictionary entries: the keys are compared
   first, and the values only break a tie between keys. *)
let rec dict_entry_less (DictEnt (key_a, val_a)) (DictEnt (key_b, val_b)) =
  py_less key_a key_b
  || (not (py_less key_b key_a) && py_less val_a val_b)

(* Ordering of two dictionaries.  Not implemented: every call raises
   NotImplemented, whatever the arguments are. *)
and dictionary_less x y =
  raise (NotImplemented "Dictionary less")
  
(* Strict less-than on python values, returned as an ocaml bool.
   The operands are first passed through py_coerce so that mixed
   numeric pairs are compared in a common type.  The arms are tried in
   order: PyInitial sorts before every other value, PyTerminal after
   every other value, and PyNone before whatever remains. *)
and py_less (x:expr_t) (y:expr_t) : bool =
  let coerced = py_coerce (x,y) in
  match coerced with
  (* a sentinel is never less than itself *)
  | (PyInitial, PyInitial) | (PyTerminal, PyTerminal) -> false

  (* sentinels against anything else *)
  | (PyInitial, _) -> true
  | (_, PyInitial) -> false
  | (_, PyTerminal) -> true
  | (PyTerminal, _) -> false

  | (PyNone, PyNone) -> false
  | (PyNone, _) -> true

  (* like against like *)
  | (PyInt a, PyInt b) -> a < b
  | (PyLong a, PyLong b) -> lt_big_int a b
  | (PyRational a, PyRational b) -> lt_num a b
  | (PyFloat a, PyFloat b) -> a < b
  | (PyString a, PyString b) -> a < b
  | (PyTuple a, PyTuple b) -> list_less py_less a b
  | (PyMutableList a, PyMutableList b) -> Varray.less py_less a b
  | (PyDictionary a, PyDictionary b) -> dictionary_less a b

  (* extension to total order for all types: might not work right.
     Note that the ORIGINAL operands are compared here, not the
     coerced pair. *)
  | _ -> compare x y < 0

(* Equality of python values:
   numbers are compared numerically (after py_coerce), structures
   structurally, and objects by object identity [more or less :-]
   Longs, rationals, dictionaries and lists get a dedicated
   comparison; every other pair falls back on the structural
   equality of the coerced values. *)
and py_equal (x:expr_t) (y:expr_t) : bool =
  match py_coerce (x, y) with
  | (PyLong a, PyLong b) -> eq_big_int a b
  | (PyRational a, PyRational b) -> eq_num a b
  | (PyDictionary d1, PyDictionary d2) -> d1#cmp py_equal d2
  | (PyMutableList l1, PyMutableList l2) -> Varray.equal py_equal l1 l2
  | (x', y') -> x' = y'

(* Derived comparisons: each one is expressed through py_less or
   py_equal, with the operands swapped and/or the result negated. *)
and py_greater x y = py_less y x
and py_not_equal x y = not (py_equal x y)
and py_greater_equal x y = not (py_less x y )
and py_less_equal x y = not (py_less y x)
(* identity test: uses physical equality (==), not structural equality *)
and py_is x y  = x == y
(* Membership test: is e an element of the native sequence l?
   - tuple, list: true when some element is py_equal to e, so numbers
     are compared numerically, as everywhere else in this file;
   - string: e must be a one-character string, otherwise TypeError;
   - integer range: true when e is a PyInt lying on the arithmetic
     progression start, start+step, ... strictly before stop.  A
     non-integer e is never a member.
   Raises TypeError when l is not one of these sequence types. *)
and py_in e l = 
  match l with
  | PyTuple l' -> List.exists (fun x -> py_equal x e) l'
  | PyMutableList l' -> 
    (* Varray.iter cannot stop early, so a hit escapes via Found *)
    begin try  
      Varray.iter (fun x -> if py_equal x e then raise Found) l';
      false
    with Found -> true
    end
  | PyString s -> 
    begin match e with 
    | PyString s' when String.length s' = 1 -> 
      String.contains s (String.get s' 0)
    | _ -> raise (TypeError "string member test needs char left operand")
    end
  | IntRange (start, stop, step) ->
    begin match e with 
    | PyInt i when step <> 0 ->
      (* i must sit on the step grid and between the bounds, where
         the direction of the bounds follows the sign of step *)
      (i - start) mod step = 0
      && (if step > 0 then start <= i && i < stop
          else stop < i && i <= start)
    | _ -> false
    end
  | _ -> raise (TypeError "'in' or 'not in' requires sequence right argument")

(* negated membership test; raises whatever py_in raises *)
and py_not_in e l = not (py_in e l)

(* Three-way comparison: -1 when x is less than y, 0 when the two are
   equal, 1 otherwise.  py_equal is consulted only when py_less has
   answered false. *)
and py_compare x y =
  match py_less x y with
  | true -> -1
  | false -> if py_equal x y then 0 else 1
;;


(* bitwise operations on integers
   The binary operations take the left operand as a raw int and the
   right operand as a term, which must be a PyInt; the result is a raw
   int.  Any other right operand raises TypeError. *)
let py_bit_or (x:int) (y:expr_t) : int =
  match y with 
  | PyInt y' -> x lor y'
  | _ -> raise (TypeError "bad operand types(s) for |")

and py_bit_and (x:int) (y:expr_t) : int =
  match y with 
  | PyInt y' -> x land y'
  | _ -> raise (TypeError "bad operand types(s) for &")

and py_bit_xor (x:int) (y:expr_t) : int =
  match y with 
  | PyInt y' -> x lxor y'
  | _ -> raise (TypeError "bad operand types(s) for ^")

(* unary ~ : flips every bit of the operand, so ~x = -x - 1 *)
and py_complement (e:expr_t):int = 
  match e with
  | PyInt x' -> lnot x'
  | _ -> raise (TypeError "Bad operand type for unary ~")
;;

(* arithmetic operations *)
(* Unary minus on any of the five numeric kinds; a complex number has
   both of its parts negated.  Raises TypeError for a non number. *)
let py_neg = function
  | PyInt n -> PyInt (- n)
  | PyLong n -> PyLong (minus_big_int n)
  | PyRational q -> PyRational (minus_num q)
  | PyFloat f -> PyFloat (-. f)
  | PyComplex (re, im) -> PyComplex (-. re, -. im)
  | _ -> raise (TypeError "Attempt to negate non number")

(* Binary + : numeric addition after py_coerce, or concatenation of
   two strings, two tuples or two lists.
   Any other pair of operands is reported on stdout (both operands are
   printed) and then TypeError is raised. *)
let py_add x y = 
  (* diagnostic path for operands that cannot be added *)
  let report_failure () =
    print_string "Add: Error coercing ";
    Py_print.print_expression 0 x;
    print_string " and ";
    Py_print.print_expression 0 y;
    print_endline "";
    flush stdout;
    raise (TypeError "Coerce: number coercion failed")
  in
  match py_coerce (x,y) with 
  | (PyInt a, PyInt b) -> PyInt (a + b)
  | (PyLong a, PyLong b) -> PyLong (add_big_int a b)
  | (PyRational a, PyRational b) -> PyRational (add_num a b)
  | (PyFloat a, PyFloat b) -> PyFloat (a +. b)
  | (PyComplex (re_a, im_a), PyComplex (re_b, im_b)) -> 
    PyComplex (re_a +. re_b, im_a +. im_b)
  | (PyString a, PyString b) -> PyString (a ^ b)
  | (PyTuple a, PyTuple b) -> PyTuple (a @ b)
  | (PyMutableList a, PyMutableList b) -> PyMutableList (Varray.concat [a; b])
  | _ -> report_failure ()

(* Binary - : numeric subtraction in the common type chosen by
   py_coerce.  Raises TypeError for any non-numeric operand pair. *)
and py_sub x y =
  match py_coerce (x,y) with 
  | (PyInt a, PyInt b) -> PyInt (a - b)
  | (PyLong a, PyLong b) -> PyLong (sub_big_int a b)
  | (PyRational a, PyRational b) -> PyRational (sub_num a b)
  | (PyFloat a, PyFloat b) -> PyFloat (a -. b)
  | (PyComplex (re_a, im_a), PyComplex (re_b, im_b)) -> 
    PyComplex (re_a -. re_b, im_a -. im_b)
  | _ -> raise (TypeError "Sub: number coercion failed")

(* Binary * : numeric multiplication in the common type chosen by
   py_coerce, or repetition of a sequence (left operand) by an int
   (right operand).
   A string or tuple repeated zero or fewer times is empty; a list
   repeated a negative number of times raises ValueError.
   Any other operand pair raises TypeError. *)
and py_mul x y =
  (* [copies item n acc] conses n references to item onto acc;
     nothing is added when n <= 0.  Building the list of pieces first
     and joining once keeps repetition linear in the result size. *)
  let rec copies item n acc =
    if n <= 0 then acc else copies item (n - 1) (item :: acc)
  in
  match py_coerce (x,y) with 
  | (PyInt x' , PyInt y') -> PyInt (x' * y')
  | (PyLong x' , PyLong y') -> PyLong (mult_big_int x' y')
  | (PyRational x' , PyRational y') -> PyRational (mult_num x' y')
  | (PyFloat x' , PyFloat y') -> PyFloat (x' *. y')
  | (PyComplex (rx', ix') , PyComplex (ry', iy')) -> 
    PyComplex ((rx' *. ry' -. ix' *. iy'), (rx' *. iy' +. ix' *. ry'))
  | (PyString x', PyInt y') -> 
    PyString (String.concat "" (copies x' y' []))
  | (PyTuple x', PyInt y') -> 
    PyTuple (List.concat (copies x' y' []))
  | (PyMutableList x', PyInt y') ->
    if y' >= 0 then PyMutableList (Varray.concat (copies x' y' []))
    else raise (ValueError "List * int requires int >=0")
  | _ -> raise (TypeError "Multiply: number coercion failed")

(* Binary / : numeric division in the common type chosen by py_coerce.
   Int division uses the ocaml / operator, which truncates toward zero.
   Complex division multiplies through by the conjugate of the divisor:
     (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c*c + d*d)
   Raises TypeError for any non-numeric operand pair. *)
and py_div x y =
  match py_coerce (x,y) with
  | (PyInt x' , PyInt y') -> PyInt (x' / y')
  | (PyLong x' , PyLong y') -> PyLong (div_big_int x' y')
  | (PyRational x' , PyRational y') -> PyRational (div_num x' y')
  | (PyFloat x' , PyFloat y') -> PyFloat (x' /. y')
  | (PyComplex (rx', ix') , PyComplex (ry', iy')) -> 
    (* squared modulus of the DIVISOR *)
    let d = ry' *. ry' +. iy' *. iy' in
    PyComplex (
      (rx' *. ry' +. ix' *. iy') /. d,
      (ix' *. ry' -. rx' *. iy') /. d
    )
  | _ -> raise (TypeError "Divide: number coercion failed")


(* Binary % : numeric remainder in the common type chosen by
   py_coerce, or string formatting when the left operand is a string
   (delegated to Py_printf.py_format).
   The float remainder is computed as a - floor(a/b) * b.
   Complex remainder raises NotImplemented; any other operand pair
   raises TypeError. *)
and py_mod x y =
  match py_coerce (x,y) with
  | (PyInt a, PyInt b) -> PyInt (a mod b)
  | (PyLong a, PyLong b) -> PyLong (mod_big_int a b)
  | (PyRational a, PyRational b) -> PyRational (mod_num a b)
  | (PyFloat a, PyFloat b) ->
    let whole = floor (a /. b) in
    PyFloat (a -. whole *. b)
  | (PyComplex _, PyComplex _) -> raise (NotImplemented "Complex remainder")
  | (PyString fmt, args) -> PyString (Py_printf.py_format fmt args) 
  | _ -> raise (TypeError "Mod: number coercion failed")

(* Binary << on two ints; no coercion is attempted.
   Raises TypeError unless both operands are PyInt. *)
and py_lsl x y =
  match x, y with
  | PyInt bits, PyInt count -> PyInt (bits lsl count)
  | _, _ -> raise (TypeError "Bad operand type(s) for <<")

(* Binary >> on two ints; no coercion is attempted.
   The shift is arithmetic (asr): the sign bit is replicated, so a
   negative operand stays negative, as Python requires of >>.
   Raises TypeError unless both operands are PyInt. *)
and py_lsr x y =
  match (x,y) with 
  | (PyInt x' , PyInt y') -> PyInt (x' asr y')
  | _ -> raise (TypeError "Bad operand type(s) for >>")

(* Binary ** in the common type chosen by py_coerce.
   The int case goes through floating point and truncates the result.
   NOTE(review): that detour presumably loses precision for results
   too large for a float mantissa, and there is no PyLong arm, so two
   longs end in TypeError - confirm whether this is intended.
   Complex exponent raises NotImplemented; any other operand pair
   raises TypeError. *)
and py_pow x y =
  match py_coerce (x,y) with
  | (PyInt base, PyInt expo) -> PyInt (truncate (float base ** float expo))
  | (PyRational base, PyRational expo) -> PyRational (power_num base expo)
  | (PyFloat base, PyFloat expo) -> PyFloat (base ** expo)
  | (PyComplex _, PyComplex _) -> raise (NotImplemented "Complex exponent")
  | _ -> raise (TypeError "Bad operand type(s) for **")

(* Attribute lookup on a traceback, given as a list of
   (lineno, filename) entries.  The result is always Some value, so a
   traceback never gives an error on getattr:
     tb_filename, tb_lineno  - taken from the head entry
     tb_next                 - the remaining entries, or PyNone when
                               there are none
   An unknown attribute, or an empty traceback, gives PyNone. *)
let get_tb_attr tb attr =
  let value =
    match tb, attr with
    | [], _ -> PyNone
    | (_, filename) :: _, PyString "tb_filename" -> PyString filename
    | (lineno, _) :: _, PyString "tb_lineno" -> PyInt lineno
    | _ :: [], PyString "tb_next" -> PyNone
    | _ :: next, PyString "tb_next" -> PyTraceback next
    | _ :: _, _ -> PyNone
  in
  Some value

(* Look up attribute attr on obj, returning an option.
   Instances, classes, modules and functions answer through their own
   get_attr method; a traceback is reversed and handed to get_tb_attr.
   Every other kind of value has no attributes here: None. *)
let py_get_attribute obj attr =
  match obj with
  | PyTraceback entries -> get_tb_attr (List.rev entries) attr
  | PyInstance inst -> inst#get_attr attr
  | PyClass cls -> cls#get_attr attr
  | PyModule md -> md#get_attr attr
  | PyFunction fn -> fn#get_attr attr
  | _ -> None

(* True exactly for the native sequence kinds: list, tuple, string
   and integer range. *)
let py_is_sequence = function
  | PyMutableList _ | PyTuple _ | PyString _ | IntRange _ -> true
  | _ -> false

(* True for dictionaries and for every native sequence
   (sequences are mappings too!). *)
let py_is_mapping = function
  | PyDictionary _ -> true
  | other -> py_is_sequence other

(* Number of elements of a native sequence.
   For an integer range this is the number of values start,
   start+step, ... that lie strictly before stop, i.e. the CEILING of
   (stop - start) / step, or 0 when the range is empty.
   Raises NonSequence for anything that is not a native sequence. *)
let py_seq_len s = 
  match s with 
  | PyMutableList l -> Varray.length l
  | PyTuple l -> List.length l
  | PyString s' -> String.length s'
  | IntRange (start, stop, step) ->
    (* integer division truncates toward zero, so bias the numerator
       by |step| - 1 in the direction of travel to round up; a span
       that is an exact multiple of step is unaffected *)
    let bias = if step > 0 then step - 1 else step + 1 in
    max ((stop - start + bias) / step) 0
  | _ -> raise NonSequence


(* Get the k-th item of a native sequence, with negative index
   conversion: a negative k counts back from the end of the sequence.
   Raises IndexError (message, converted index, length) when the
   converted index falls outside the sequence, and ViperError when obj
   is not a native sequence.
   Bounds are checked explicitly before the element is fetched, so an
   out-of-range subscript never escapes as a library exception. *)
let py_get_seq_elem obj (k:int) =
  (* converted index for a sequence of length n *)
  let convert n = if k >= 0 then k else n + k in
  match obj with
  | PyTuple li -> 
    let n = List.length li in
    let k' = convert n in
    if k' < 0 || k' >= n
    then raise (IndexError ("Tuple Subscript out of range", k',n))
    else List.nth li k'

  | PyMutableList li -> 
    let n = Varray.length li in
    let k' = convert n in
    if k' < 0 || k' >= n
    then raise (IndexError ("List Subscript out of range",k',n))
    else Varray.get li k'

  | PyString s -> 
    let n = String.length s in
    let k' = convert n in
    if k' < 0 || k' >= n
    then raise (IndexError ("String Subscript out of range",k',n))
    else PyString (String.sub s k' 1)

  | IntRange (start, stop, step) -> 
    let n = py_seq_len obj in
    let k' = convert n in
    let j = start + k' * step in
    (* the candidate value must lie between start (inclusive) and
       stop (exclusive), in the direction given by the sign of step *)
    let inside =
      if step > 0 then j >= start && j < stop
      else j <= start && j > stop
    in
    if inside then PyInt j
    else raise (IndexError ("Xrange subscript out of range",k',n))

  | _ -> raise (ViperError ("native sequence required for get_subscript"))

(* Normalise the bounds of a slice over a sequence of length n.
   A negative bound counts back from the end; each bound is then
   clamped into 0..n, and last is raised to first when the slice would
   otherwise be reversed.
   result is 0 <= first <= last <= n *)
let normalise_range n first last = 
  let clamp i = max 0 (min n (if i < 0 then i + n else i)) in
  let first = clamp first in
  let last = max first (clamp last) in
  first, last
 
(* Slice of a native sequence from first up to, but excluding, last.
   The bounds go through normalise_range, so negative and out of range
   bounds are accepted.  The result has the same kind as obj; for an
   integer range the slice is again a range with the same step.
   py_seq_len raises NonSequence for a non-sequence before the final
   arm below can be reached. *)
let py_get_seq_slice obj first last =
  let n = py_seq_len obj in
  let first, last = normalise_range n first last in
  let amt = last - first in 
  match obj with
  | PyTuple li -> PyTuple (Util.list_sub li first amt)
  | PyMutableList li -> PyMutableList (Varray.sub li first amt)
  | PyString s -> PyString (String.sub s first amt)
  | IntRange (start, _, step) -> 
    (* the new stop is measured from the NEW start: amt steps on *)
    let start' = start + first * step in
    IntRange (start', start' + amt * step, step)
  | _ -> raise (ViperError ("native sequence required for get_seq_slice"))


(* Extended slice of a native sequence: every stride-th element from
   first up to, but excluding, last.  Handles strides > 0 only:
   a stride <= 0 raises ValueError, and a stride of 1 is delegated to
   py_get_seq_slice.
   The bounds go through normalise_range.  The result holds
   ceiling (amt / stride) elements, where amt = last - first, namely
   those at positions first, first+stride, ... below last.
   Raises ViperError for a kind of obj not handled below. *)
let py_get_seq_slice_extended obj first last stride =
  if stride <= 0 then raise (ValueError "Stride must be positive");
  if stride = 1 then py_get_seq_slice obj first last
  else
    let n = py_seq_len obj in
    let first, last = normalise_range n first last in
    let amt = last - first in
    (* number of selected elements; the last selected position is
       first + (count-1)*stride, which is below last *)
    let count = (amt + stride - 1) / stride in
    match obj with
    | IntRange (start, _, step) ->
      let start' = start + step * first
      and step' = step * stride in
      IntRange (start', start' + step' * count, step')

    | PyString s -> 
      let res = String.create count in
      for i = 0 to count - 1 do
        res.[i] <- s.[first + i * stride]
      done;
      PyString res
  
    | PyMutableList s -> 
      if count = 0 then PyMutableList (Varray.empty ())
      else
        let pick i = Varray.get s (first + i * stride) in
        PyMutableList (Varray.of_array (Array.init count pick))
  
    | PyTuple s ->
      (* index an array copy instead of walking the list each time *)
      let items = Array.of_list s in
      let res = ref [] in
      for i = count - 1 downto 0 do
        res := items.(first + i * stride) :: !res
      done;
      PyTuple !res

    | _ -> raise (ViperError "In Py_datum.py_get_seq_slice_extended")
    
