(ns thi.ng.ndarray.macros)

(defn- type-hinted
  "Returns `x` with its metadata replaced by `{:tag (name type)}`, or `x`
  untouched when `type` is nil/false."
  [type x]
  (cond-> x
    type (with-meta {:tag (name type)})))

(defn- make-symbols
  "Returns a vector of `n` symbols formed by appending the indices
  0..n-1 to `id`, e.g. \"a\" and 2 give [a0 a1]."
  [id n]
  (vec (for [i (range n)]
         (symbol (str id i)))))

(defn- pair-fn
  "Groups adjacent items of `coll` into pairs, turning each full pair
  `a b` into the form `(f a b)` and keeping an odd trailing item as-is.
  The grouping is repeated until at most two items remain, and that seq
  is returned (always at least one pass). For '+ over [a b c] the result
  is ((+ a b) c)."
  [f coll]
  (loop [items coll]
    (let [merged (map (fn [chunk]
                        (if (next chunk)
                          (cons f chunk)
                          (first chunk)))
                      (partition-all 2 items))]
      (if (<= (count merged) 2)
        merged
        (recur merged)))))

(defn- make-indexer
  "Builds a form computing a flat buffer index:
  `(int (+ _offset (* stride_i (int (p i))) ...))` for each axis i below
  `dim`. `->st` maps an axis number to its stride symbol. `p` is a
  symbol that the generated code calls with the axis number (presumably
  bound to a position vector — TODO confirm at call site). The sum is
  nested into binary `+` calls via `pair-fn`."
  [dim ->st p]
  (let [axis-terms (for [axis (range dim)]
                     (list '* (->st axis) (list `int (list p axis))))]
    `(int (+ ~@(pair-fn '+ (cons '_offset axis-terms))))))

(defn- make-indexer-syms
  "Builds a form computing a flat buffer index from per-axis position
  symbols: `(int (+ _offset (* stride_i (int pos_i)) ...))` for each axis
  i below `dim`. `->st` gives the stride symbol of an axis and `->p` the
  position symbol. The sum is nested into binary `+` calls via
  `pair-fn`."
  [dim ->st ->p]
  (let [axis-terms (for [axis (range dim)]
                     (list '* (->st axis) (list `int (->p axis))))]
    `(int (+ ~@(pair-fn '+ (cons '_offset axis-terms))))))

(defmacro and*
  "Like clojure.core/and, but avoids intermediate let bindings: it
  expands into nested `if` forms, so the result is either the value of
  the last form (when all previous forms are truthy) or nil. Requires at
  least one form."
  [x & more]
  (let [[last-form & preceding] (reverse (cons x more))]
    ;; fold from the innermost test outwards: (if a (if b c))
    (reduce (fn [inner test] `(if ~test ~inner))
            last-form
            preceding)))

(defn- with-bounds-check
  "Wraps `body` in a runtime bounds check. For every axis i below `dim`
  the generated code checks 0 <= (psyms i) < (shapes i). If all checks
  pass, `body` is evaluated inside a `do`. Otherwise it throws an
  IndexOutOfBoundsException (when `clj?` is truthy) or a js/Error, with
  a message listing all position symbols' values."
  [dim psyms shapes clj? & body]
  (let [checks    (mapcat (fn [axis]
                            (let [pos   (symbol (psyms axis))
                                  limit (symbol (shapes axis))]
                              [`(>= ~pos 0) `(< ~pos ~limit)]))
                          (range dim))
        err-class (if clj? 'IndexOutOfBoundsException 'js/Error)]
    `(if (and* ~@checks)
       (do ~@body)
       (throw (new ~err-class (str "Invalid index: " (pr-str [~@psyms])))))))

(defn- for*
  "Emits a `for` comprehension with one binding per axis in `rdim`. Each
  loop symbol `(->a i)` ranges over `(range (->sh i))`, and `body` becomes
  the comprehension body."
  [->a ->sh rdim body]
  (let [bindings (vec (mapcat (fn [axis] [(->a axis) `(range ~(->sh axis))])
                              rdim))]
    `(for ~bindings ~body)))

(defn- inject-clj-protos
  "Returns the protocol/interface symbol and the seq method
  implementation to splice into an ndarray deftype:
  `clojure.lang.Seqable` / `seq` when `clj?` is truthy, otherwise
  `ISeqable` / `-seq` (ClojureScript).

  The method yields `(get data idx)` for every position produced by
  `for*` over the axes in `rdim`. `->a` gives an axis' loop symbol, `->sh`
  its shape symbol, and `idx` is a form computing the flat index from
  those loop symbols.

  The element seq is passed through `seq` so that an array with a
  zero-length axis returns nil. A bare lazy `for` seq would be empty but
  truthy, which breaks the seq contract (e.g. `(when (seq arr) ...)`)."
  [clj? get data ->a ->sh idx rdim]
  (let [[proto method] (if clj?
                         ['clojure.lang.Seqable 'seq]
                         ['ISeqable '-seq])]
    (list
     proto
     `(~method
       [_#]
       (seq ~(for* ->a ->sh rdim `(~get ~data ~idx)))))))

(defn- do-cast
  "Returns the form `(cast body)`, or `body` itself when `cast` is
  nil/false."
  [cast body]
  (if-not cast
    body
    (list cast body)))

(defmacro def-ndarray
  "Defines an ndarray implementation for `dim` dimensions over a single
  flat backing buffer, and registers its raw constructor.

  Arguments:
  - `dim`       number of dimensions (used at macro-expansion time)
  - `cast`      symbol of a fn wrapped around every value before it is
                written via `set`, or nil for no cast
  - `type-hint` tag attached to the `_data` field, or nil for none
  - `type-id`   identifier used (via `name`) in the generated names,
                returned by `data-type` and used as registry key
                (presumably a keyword — TODO confirm at call sites)
  - `data-ctor` symbol of a fn called with an element count to allocate
                a new backing buffer (used by `extract`)
  - `get`/`set` symbols called as `(get data idx)` / `(set data idx val)`
  - `clj?`      truthy to emit Clojure (JVM) interop, otherwise
                ClojureScript

  Expands into:
  - a deftype `NDArray<dim><type-id>` with fields `_data`, `_offset`,
    `_stride0..` and `_shape0..`. It implements seq support,
    `thi.ng.ndarray.core/PNDArray` and `Object/toString`.
  - a fn `make-raw-ndarray<dim>-<type-id>` taking
    `[data offset [strides] [shapes]]`.
  - an entry at `[dim type-id]` in `thi.ng.ndarray.core/ctor-registry`
    holding `{:ctor .. :data-ctor ..}`."
  [dim cast type-hint type-id data-ctor get set & [clj?]]
  (let [type-name (symbol (str "NDArray" dim (name type-id)))
        raw-name  (symbol (str "make-raw-ndarray" dim "-" (name type-id)))
        strides   (make-symbols "_stride" dim)
        shapes    (make-symbols "_shape" dim)
        asyms     (make-symbols "a" dim)
        bsyms     (make-symbols "b" dim)
        psyms     (make-symbols "p" dim)
        ;; axis number -> generated field/local symbol
        [->st ->sh ->a ->b ->p] (map #(comp symbol %) [strides shapes asyms bsyms psyms])
        ;; unique scratch symbols for bindings inside generated method bodies
        [c d f p o] (repeatedly gensym)
        ;; flat buffer index computed from the position args p0..pN
        idx-syms  (make-indexer-syms dim ->st ->p)
        data      (type-hinted type-hint '_data)
        rdim      (range dim)]
    `(do
       (deftype ~type-name
           [~data ~'_offset ~@strides ~@shapes]
         ~@(inject-clj-protos clj? get data ->a ->sh (make-indexer-syms dim ->st ->a) rdim)
         ~'thi.ng.ndarray.core/PNDArray
         (~'data
           [_#] ~data)
         (~'data-type
           [_#] ~type-id)
         (~'dimension
           [_#] ~dim)
         (~'stride
           [_#] [~@strides])
         (~'shape
           [_#] [~@shapes])
         (~'offset
           [_#] ~'_offset)
         (~'size
           [_#] (* ~@(pair-fn '* shapes)))
         (~'extract
           [_#]
           (let [buf#      ~(type-hinted type-hint `(~data-ctor (* ~@(pair-fn '* shapes))))
                 [~@asyms] (thi.ng.ndarray.core/shape->stride [~@shapes])
                 arr#      (new ~type-name buf# 0 ~@asyms ~@shapes)]
             ;; copy element-wise into the fresh buffer. `seq` is required
             ;; because index-seq returns a lazy `for` seq, which is truthy
             ;; even when empty (i.e. when any axis has length 0)
             (loop [~c (seq (thi.ng.ndarray.core/index-seq _#))
                    ~d (seq (thi.ng.ndarray.core/index-seq arr#))]
               (when ~c
                 (~set buf# (int (first ~d)) ~(do-cast cast `(~get ~data (int (first ~c)))))
                 (recur (next ~c) (next ~d))))
             arr#))
         (~'index-at
           [_# ~@psyms] ~idx-syms)
         (~'index-pos
           [_# ~p]
           ;; decompose (p - offset) axis by axis: a_i = c / stride_i, then
           ;; c -= a_i * stride_i (the final remainder update is dropped).
           ;; NOTE(review): only exact when strides are positive and in
           ;; decreasing order — confirm for transposed/stepped views
           (let [~p (int ~p)
                 ~c (- ~p ~'_offset)
                 ~@(drop-last
                    2 (mapcat
                       #(let [a (->a %) s (->st %)]
                          (list a `(int (/ ~c ~s))
                                c `(- ~c (* ~a ~s))))
                       rdim))]
             [~@asyms]))
         (~'index-seq
           [_#]
           ~(let [idx (make-indexer-syms dim ->st ->a)]
              (for* ->a ->sh rdim idx)))
         (~'position-seq
           [_#] ~(for* ->a ->sh rdim `[~@asyms]))
         (~'get-at
           [_# ~@psyms] (~get ~data ~idx-syms))
         (~'get-at-safe
           [_# ~@psyms]
           ~(with-bounds-check dim psyms shapes clj?
              `(~get ~data ~idx-syms)))
         (~'get-at-index
           [_# i#] (~get ~data (int i#)))
         (~'set-at
           [_# ~@psyms ~c] (~set ~data ~idx-syms ~(do-cast cast c)) _#)
         (~'set-at-safe
           [_# ~@psyms ~c]
           ~(with-bounds-check dim psyms shapes clj?
              `(~set ~data ~idx-syms ~(do-cast cast c)))
           _#)
         (~'set-at-index
           [_# i# ~c] (~set ~data (int i#) ~(do-cast cast c)) _#)
         (~'update-at
           [_# ~@psyms ~f]
           (let [~c ~idx-syms]
             (~set ~data ~c ~(do-cast cast `(~f ~@psyms (~get ~data ~c)))))
           _#)
         (~'update-at-safe
           [_# ~@psyms ~f]
           ~(with-bounds-check dim psyms shapes clj?
              `(let [~c ~idx-syms]
                 (~set ~data ~c ~(do-cast cast `(~f ~@psyms (~get ~data ~c))))))
           _#)
         (~'update-at-index
           [_# ~c ~f]
           ;; cast the index once so the read and the write address the
           ;; same slot (as get-at-index / set-at-index do); f still
           ;; receives the index as given
           (let [~d (int ~c)]
             (~set ~data ~d ~(do-cast cast `(~f ~c (~get ~data ~d)))))
           _#)
         (~'truncate-h
           [_# ~@psyms]
           (new ~type-name ~data ~'_offset ~@strides
                ~@(map
                   #(let [p (->p %) s (->sh %)]
                      `(if (number? ~p)
                         (if (neg? ~p)
                           (+ ~s (int ~p))
                           (int ~p))
                         ~s))
                   rdim)))
         (~'truncate-l
           [_# ~@psyms]
           (let [~@(mapcat
                    #(let [p (->p %) sh (->sh %) st (->st %)]
                       (list
                        [(->a %) (->b %)]
                        ;; only positive numbers trim an axis; nil (or any
                        ;; non-number) leaves it unchanged, as in truncate-h
                        `(if (and (number? ~p) (pos? ~p))
                           [(- ~sh (int ~p))
                            (* ~st (int ~p))]
                           [~sh 0])))
                    rdim)
                 ~o (+ ~@(->> rdim (map ->b) (cons '_offset) (pair-fn '+)))]
             (new ~type-name ~data ~o ~@strides ~@asyms)))
         (~'transpose
           [_# ~@psyms]
           (let [~@(mapcat #(let [p (->p %)] (list p `(if ~p (int ~p) ~%))) rdim)
                 ~c [~@strides]
                 ~d [~@shapes]]
             (new ~type-name ~data ~'_offset
                  ~@(map #(list c (->p %)) rdim)
                  ~@(map #(list d (->p %)) rdim))))
         (~'step
           [_# ~@psyms]
           (let [~o ~'_offset
                 ~@(mapcat
                    #(let [p (->p %) sh (->sh %) st (->st %)
                           stride' `(* ~st (int ~p))]
                       (list
                        [(->a %) (->b %) o]
                        `(if (number? ~p)
                           (if (neg? ~p)
                             [(int (~'Math/ceil (/ (- ~sh) (int ~p))))
                              ~stride'
                              (+ ~o (* ~st (dec ~sh)))]
                             [(int (~'Math/ceil (/ ~sh (int ~p))))
                              ~stride'
                              ~o])
                           [~sh ~st ~o])))
                    rdim)]
             (new ~type-name ~data ~o ~@bsyms ~@asyms)))
         (~'pick
           [_# ~@psyms]
           (let [~o ~'_offset, ~c [], ~d []
                 ~@(mapcat
                    #(let [p (->p %) sh (->sh %) st (->st %)]
                       (list
                        [c d o]
                        `(if (and (number? ~p) (>= ~p 0))
                           [~c ~d (+ ~o (* ~st (int ~p)))]
                           [(conj ~c ~sh) (conj ~d ~st) ~o])))
                    rdim)
                 cnt# (count ~c)]
             ;; remaining (unpicked) axes -> view of lower dimension via the
             ;; registry; no remaining axes -> the single element itself
             (if (pos? cnt#)
               ((get-in @~'thi.ng.ndarray.core/ctor-registry [cnt# ~type-id :ctor]) ~data ~o ~d ~c)
               (~get ~data (int ~o)))))
         ~'Object
         (~'toString
           [_#]
           (pr-str
            {:data ~data :type ~type-id
             :size (* ~@(pair-fn '* shapes)) :total (count (seq _#)) :offset ~'_offset
             :shape [~@shapes] :stride [~@strides]})))

       (defn ~(with-meta raw-name {:export true})
         [data# o# [~@strides] [~@shapes]]
         (new ~type-name data# o# ~@strides ~@shapes))

       (swap!
        ~'thi.ng.ndarray.core/ctor-registry
        assoc-in [~dim ~type-id]
        {:ctor ~raw-name
         :data-ctor ~data-ctor}))))
