open Py_types
open Py_exceptions
open Py_builtins_util
open Num

(* Python marshal format. *)

(* One-character type codes of the marshal format; every marshalled
   object is introduced by exactly one of these. *)
let mTYPE_NULL     = "0"
and mTYPE_NONE     = "N"
and mTYPE_ELLIPSIS = "."
and mTYPE_INT      = "i"
and mTYPE_INT64    = "I"
and mTYPE_FLOAT    = "f"
and mTYPE_COMPLEX  = "x"
and mTYPE_LONG     = "l"
and mTYPE_STRING   = "s"
and mTYPE_TUPLE    = "("
and mTYPE_LIST     = "["
and mTYPE_DICT     = "{"
and mTYPE_CODE     = "c"
and mTYPE_UNKNOWN  = "?"
;;

(* Viper extensions *)
let mTYPE_RATIONAL = "r";;

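(*
  format notes:
    int: four-byte little-endian two's-complement integer
    int64: eight-byte little-endian two's-complement integer
    float: n,s where s is the value as decimal text and n its length as one byte
    complex: two floats (real part, then imaginary part)
    string: n,s where n is the length of s as a four-byte int
    long: n,d where d holds the magnitude in 2*|n| bytes, least-significant
          byte first, and n is a four-byte int whose sign is the sign of the
          value (n counts 16-bit words, not bytes).
          NOTE: CPython's marshal stores 15-bit digits here, so this long
          encoding is Viper-specific.
    tuple, list: n, o1, o2 .. where n is the element count as a four-byte int
    dictionary: key1,value1,key2,value2 .. terminated by a NULL type code
    rational: two longs (numerator and denominator)

    code: NOT SUPPORTED! [even though this is the main purpose of marshal :-]
          Viper doesn't have code objects, only ASTs!
          [Saving an AST is possible -- but would not be Python compatible]
          [Reading bytecode is possible, but would require implementing a
           Python VM to execute it]

*)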

(* Encode [i] as a 4-byte little-endian two's-complement integer.
   Only the low 32 bits of [i] are kept, so callers must range-check
   values that may not fit. Masking with [land 0xff] (rather than
   [mod 256], which is negative for negative [i]) keeps every byte in
   0..255, so negative values encode instead of raising. *)
let s_int i =
  String.init 4 (fun k -> char_of_int ((i lsr (8 * k)) land 0xff))
;;

(* Encode the absolute value of the big integer [i] as bytes, least
   significant byte first, padded with a trailing zero byte to an even
   length (the caller counts the result in 16-bit words). Zero encodes
   as the empty string. The sign is NOT encoded here. *)
let s_long i =
  let base = Big_int.big_int_of_int 256 in
  let buf = Buffer.create 64 in
  (* Peel off base-256 digits until the remaining quotient is zero. *)
  let rec emit n =
    if Big_int.sign_big_int n <> 0 then begin
      let q, r = Big_int.quomod_big_int n base in
      Buffer.add_char buf (char_of_int (Big_int.int_of_big_int r));
      emit q
    end
  in
  emit (Big_int.mult_int_big_int (Big_int.sign_big_int i) i);
  if Buffer.length buf mod 2 <> 0 then Buffer.add_char buf (char_of_int 0);
  Buffer.contents buf
;;

(* One-character string whose byte value is [i]; raises Invalid_argument
   (from char_of_int) when [i] is outside 0..255. *)
let sbyte i = char_of_int i |> String.make 1;;

(* Note: Ellipsis isn't supported, because it isn't currently an expr_t *)

(* Marshal the single positional argument into a string.
   [empty_dict d] and [exactly e 1] presumably reject keyword arguments
   and any argument count other than one -- NOTE(review): confirm in
   Py_builtins_util.
   Raises NotImplemented for values with no marshal representation.
   All output is accumulated in one Buffer; the original built it with
   repeated [^], which was quadratic for dictionaries. *)
let py_pmarshal_dump
  (interp:interpreter_t)
  (e:expr_t list)
  (d:dictionary_t): expr_t =
  empty_dict d;
  exactly e 1;
  let buf = Buffer.create 256 in
  let add = Buffer.add_string buf in
  (* Append the low [nbytes] bytes of [n], least significant first.
     [asr] plus [land 0xff] yields two's complement for negative [n]. *)
  let add_le n nbytes =
    for k = 0 to nbytes - 1 do
      Buffer.add_char buf (Char.chr ((n asr (8 * k)) land 0xff))
    done
  in
  (* Length-prefixed decimal text. %.17g round-trips every double;
     string_of_float keeps only 12 significant digits and lost precision. *)
  let add_float f =
    let text = Printf.sprintf "%.17g" f in
    add_le (String.length text) 1;
    add text
  in
  (* Signed 16-bit-word count, then the magnitude bytes from s_long. *)
  let add_long b =
    let digits = s_long b in
    add_le (String.length digits / 2 * Big_int.sign_big_int b) 4;
    add digits
  in
  let rec dump x = match x with
    | PyNone -> add mTYPE_NONE
    | PyInt i ->
      (* Same rule as CPython: values whose bits above bit 31 are not a
         pure sign extension need the 8-byte form instead of truncating. *)
      let high = i asr 31 in
      if high = 0 || high = -1 then begin
        add mTYPE_INT; add_le i 4
      end else begin
        add mTYPE_INT64; add_le i 8
      end
    | PyLong l ->
      add mTYPE_LONG;
      add_long l
    | PyRational r ->
      let num, denom = match r with
        | Int i -> Big_int.big_int_of_int i, Big_int.unit_big_int
        | Big_int i -> i, Big_int.unit_big_int
        | Ratio q -> Ratio.numerator_ratio q, Ratio.denominator_ratio q
      in
      add mTYPE_RATIONAL;
      add_long num;
      add_long denom
    | PyFloat f ->
      add mTYPE_FLOAT;
      add_float f
    | PyComplex (re, im) ->
      add mTYPE_COMPLEX;
      add_float re;
      add_float im
    | PyString s ->
      add mTYPE_STRING;
      add_le (String.length s) 4;
      add s
    | PyMutableList a ->
      let items = Varray.to_list a in
      add mTYPE_LIST;
      add_le (List.length items) 4;
      List.iter dump items
    | PyTuple items ->
      add mTYPE_TUPLE;
      add_le (List.length items) 4;
      List.iter dump items
    | PyDictionary table ->
      (* key/value pairs, terminated by the NULL type code *)
      add mTYPE_DICT;
      table#iter (fun k v -> dump k; dump v);
      add mTYPE_NULL
    | other -> raise (NotImplemented ("Cannot marshal " ^ (Py_functions.repr other)))
  in
  dump (arg e 0);
  PyString (Buffer.contents buf)
;;

(* Raise the loader's standard "truncated or corrupted input" error. *)
let marshal_corrupted () =
  raise (ValueError "py_marshal_load: string too short, corrupted")
;;

(* Check that [n] bytes starting at index [i] lie inside [s]; raise
   ValueError otherwise (also for a negative [i] or [n]). *)
let marshal_check s i n =
  if i < 0 || n < 0 || String.length s < i + n then marshal_corrupted ()
;;

(* Read the 4-byte little-endian two's-complement integer stored at
   index [i] of [s] (the inverse of s_int). Negative values are
   sign-extended. Raises ValueError if fewer than 4 bytes remain. *)
let p_int s i =
  marshal_check s i 4;
  let byte k = int_of_char s.[i + k] in
  (* Reading the most significant byte as signed sign-extends the value. *)
  let top = let b = byte 3 in if b >= 128 then b - 256 else b in
  ((top * 256 + byte 2) * 256 + byte 1) * 256 + byte 0
;;

(* Read [nbytes] bytes at index [i] of [s] as an unsigned integer stored
   least-significant byte first (the byte order s_long writes). *)
let p_ubig_le s i nbytes =
  marshal_check s i nbytes;
  let acc = ref Big_int.zero_big_int in
  (* Horner's rule, starting from the most significant (last) byte. *)
  for j = i + nbytes - 1 downto i do
    acc := Big_int.add_int_big_int
      (int_of_char s.[j])
      (Big_int.mult_int_big_int 256 !acc)
  done;
  !acc
;;

(* Read the digits of a long starting at index [i] of [s].
   [n'] is the header stored before the digits. Its absolute value is the
   number of 16-bit words (2 * |n'| bytes), and its sign is the sign of
   the result. *)
let p_long s i n' =
  let magnitude = p_ubig_le s i (2 * abs n') in
  if n' < 0 then Big_int.minus_big_int magnitude else magnitude
;;

(* Unmarshal the object starting at the index given as the second
   argument of the string given as the first argument. Returns a tuple
   (object, unparsed remainder of the string).
   Raises TypeError for badly-typed arguments. Raises ValueError for an
   empty string, an out-of-range start index, truncated or malformed
   data, and type codes with no loadable representation. *)
let py_pmarshal_load
  (interp:interpreter_t)
  (e:expr_t list)
  (d:dictionary_t): expr_t =
  empty_dict d;
  exactly e 2;
  (* One-byte length followed by that many bytes of decimal float text;
     returns the float and the index just past it. *)
  let p_float s i =
    marshal_check s i 1;
    let n = int_of_char s.[i] in
    marshal_check s (i + 1) n;
    let value =
      try float_of_string (String.sub s (i + 1) n)
      with Failure _ -> raise (ValueError "py_marshal_load: malformed float")
    in
    value, i + 1 + n
  in
  (* A 4-byte element/byte count, which must not be negative. *)
  let p_count s i =
    let n = p_int s i in
    if n < 0 then marshal_corrupted ();
    n
  in
  (* A long header at [i] plus its digits; returns the value and the
     index just past the 2 * |n| digit bytes. *)
  let p_long_at s i =
    let n = p_int s i in
    p_long s (i + 4) n, i + 4 + 2 * abs n
  in
  (* Given a string and a start position, return a tuple consisting of a
     parsed object and the index of the first unparsed byte. *)
  let rec load s k =
    marshal_check s k 1;
    let code = String.make 1 s.[k] in
    if code = mTYPE_INT then
      PyInt (p_int s (k + 1)), k + 5

    else if code = mTYPE_INT64 then begin
      (* Full 64-bit two's-complement value; becomes a long when it does
         not fit in a native int. *)
      let unsigned = p_ubig_le s (k + 1) 8 in
      let value =
        if int_of_char s.[k + 8] >= 128
        then Big_int.sub_big_int unsigned (Big_int.power_int_positive_int 2 64)
        else unsigned
      in
      (if Big_int.is_int_big_int value
       then PyInt (Big_int.int_of_big_int value)
       else PyLong value), k + 9
    end

    else if code = mTYPE_NONE then
      PyNone, k + 1

    else if code = mTYPE_FLOAT then begin
      let f, next = p_float s (k + 1) in
      PyFloat f, next
    end

    else if code = mTYPE_COMPLEX then begin
      let re, pos = p_float s (k + 1) in
      let im, next = p_float s pos in
      PyComplex (re, im), next
    end

    else if code = mTYPE_LONG then begin
      let l, next = p_long_at s (k + 1) in
      PyLong l, next
    end

    else if code = mTYPE_STRING then begin
      let n = p_count s (k + 1) in
      marshal_check s (k + 5) n;
      PyString (String.sub s (k + 5) n), k + 5 + n
    end

    else if code = mTYPE_TUPLE then begin
      let n = p_count s (k + 1) in
      let items, next = load_items s (k + 5) n in
      PyTuple items, next
    end

    else if code = mTYPE_LIST then begin
      let n = p_count s (k + 1) in
      let items, next = load_items s (k + 5) n in
      PyMutableList (Varray.of_list items), next
    end

    else if code = mTYPE_DICT then begin
      let result = new Py_dict.py_dictionary in
      let next = load_pairs s (k + 1) result in
      PyDictionary result, next
    end

    else if code = mTYPE_RATIONAL then begin
      let num, pos = p_long_at s (k + 1) in
      let den, next = p_long_at s pos in
      if Big_int.sign_big_int den = 0 then
        raise (ValueError "py_marshal_load: zero denominator in rational");
      PyRational (
        Num.div_num (Num.num_of_big_int num) (Num.num_of_big_int den)
      ), next
    end

    else if code = mTYPE_UNKNOWN then
      raise (ValueError "py_marshal_load: Unexpected UNKNOWN code in string")

    else if code = mTYPE_CODE then
      raise (ValueError "py_marshal_load: Unexpected CODE code in string")

    else if code = mTYPE_ELLIPSIS then
      raise (ValueError "py_marshal_load: Unexpected ELLIPSIS code in string")

    else if code = mTYPE_NULL then
      raise (ValueError "py_marshal_load: Unexpected NULL code in string")

    else
      raise (ValueError "py_marshal_load: Unknown marshal code in string")

  (* Load [n] consecutive objects starting at [pos]; returns them in
     order together with the index just past the last one. *)
  and load_items s pos n =
    let rec go acc pos remaining =
      if remaining = 0 then List.rev acc, pos
      else begin
        let v, pos' = load s pos in
        go (v :: acc) pos' (remaining - 1)
      end
    in
    go [] pos n

  (* Store key/value pairs into [table] until the NULL terminator;
     returns the index just past the terminator. *)
  and load_pairs s pos table =
    marshal_check s pos 1;
    if String.make 1 s.[pos] = mTYPE_NULL then pos + 1
    else begin
      let key, vpos = load s pos in
      let value, next = load s vpos in
      (* The original wrote [d#set_item key, value], which builds a tuple
         and never stores the pair. *)
      ignore (table#set_item key value);
      load_pairs s next table
    end

  in
  let start =
    match arg e 1 with
    | PyInt i -> i
    | _ -> raise (TypeError "py_marshal_load requires integer index as second argument")
  in
  match arg e 0 with
  | PyString s ->
    let len = String.length s in
    if len = 0 then raise (ValueError "Attempt to py_marshal_load empty string");
    if start < 0 || start >= len then
      raise (ValueError "py_marshal_load: start index out of range");
    let value, index = load s start in
    PyTuple [value; PyString (String.sub s index (len - index))]
  | _ -> raise (TypeError "py_marshal_load requires string first argument")
;;

