001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016
017import org.eclipse.january.DatasetException;
018import org.eclipse.january.IMonitor;
019import org.eclipse.january.metadata.StatisticsMetadata;
020import org.eclipse.january.metadata.internal.StatisticsMetadataImpl;
021
022/**
023 * Generic container class for data that is compound in nature
024 * 
025 * Each subclass has an array of compound types, items of this array are composed of primitive types
026 * 
027 * Data items can be Complex, Vector, etc
028 * 
029 */
030public abstract class AbstractCompoundDataset extends AbstractDataset implements CompoundDataset {
031        // pin UID to base class
032        private static final long serialVersionUID = Dataset.serialVersionUID;
033
034        protected int isize; // number of elements per item
035
036        @Override
037        public int getElementsPerItem() {
038                return isize;
039        }
040
041        @Override
042        protected int get1DIndex(final int i) {
043                int n = super.get1DIndex(i);
044                return stride == null ? isize * n : n;
045        }
046
047        @Override
048        protected int get1DIndex(final int i, final int j) {
049                int n = super.get1DIndex(i, j);
050                return stride == null ? isize * n : n;
051        }
052
053        @Override
054        protected int get1DIndexFromShape(final int[] n) {
055                return isize * super.get1DIndexFromShape(n);
056        }
057
058        @Override
059        public Dataset getUniqueItems() {
060                throw new UnsupportedOperationException("Cannot sort compound datasets");
061        }
062
063        @Override
064        public IndexIterator getIterator(final boolean withPosition) {
065                if (stride != null) {
066                        return base.getSize() == 1 ? 
067                                        (withPosition ? new PositionIterator(offset, shape) :  new SingleItemIterator(offset, size)) : new StrideIterator(isize, shape, stride, offset);
068                }
069                return withPosition ? getSliceIterator(null, null, null) :
070                        new ContiguousIterator(size, isize);
071        }
072
073        /**
074         * Get an iterator that picks out the chosen element from all items
075         * @param element
076         * @return an iterator
077         */
078        public IndexIterator getIterator(int element) {
079                if (element < 0)
080                        element += isize;
081                if (element < 0 || element > isize) {
082                        logger.error("Invalid choice of element: {}/{}", element, isize);
083                        throw new IllegalArgumentException("Invalid choice of element: " + element + "/" + isize);
084                }
085
086                final IndexIterator it;
087                if (stride != null) {
088                        it = base.getSize() == 1 ? new SingleItemIterator(offset + element, size) : new StrideIterator(isize, shape, stride, offset, element);
089                } else {
090                        it = new ContiguousIterator(size, isize, element);
091                }
092
093                return it;
094        }
095
096        @Override
097        public IndexIterator getSliceIterator(SliceND slice) {
098                if (ShapeUtils.calcLongSize(slice.getShape()) == 0) {
099                        return new NullIterator(shape, slice.getShape());
100                }
101                if (stride != null) {
102                        return new StrideIterator(isize, shape, stride, offset, slice);
103                }
104
105                return new SliceIterator(shape, size, isize, slice);
106        }
107
108        /**
109         * Constructor required for serialisation.
110         */
111        public AbstractCompoundDataset() {
112        }
113
114        @Override
115        public boolean equals(Object obj) {
116                if (!super.equals(obj)) {
117                        return false;
118                }
119
120                CompoundDataset other = (CompoundDataset) obj;
121                return isize == other.getElementsPerItem();
122        }
123
124        @Override
125        public int hashCode() {
126                return getCompoundStats().getHash(shape);
127        }
128
129        @Override
130        public CompoundDataset cast(boolean repeat, int dtype, int isize) {
131                return (CompoundDataset) super.cast(repeat, dtype, isize);
132        }
133
134        @Override
135        public CompoundDataset cast(int dtype) {
136                return (CompoundDataset) super.cast(dtype);
137        }
138
139        @Override
140        abstract public AbstractCompoundDataset clone();
141
142        @Override
143        public CompoundDataset flatten() {
144                return (CompoundDataset) super.flatten();
145        }
146
147        @Override
148        public CompoundDataset getBy1DIndex(IntegerDataset index) {
149                return (CompoundDataset) super.getBy1DIndex(index);
150        }
151
152        @Override
153        public CompoundDataset getByBoolean(Dataset selection) {
154                return (CompoundDataset) super.getByBoolean(selection);
155        }
156
157        @Override
158        public CompoundDataset getByIndexes(Object... indexes) {
159                return (CompoundDataset) super.getByIndexes(indexes);
160        }
161
162        @Override
163        public CompoundDataset getSlice(IMonitor mon, int[] start, int[] stop, int[] step) {
164                return (CompoundDataset) super.getSlice(mon, start, stop, step);
165        }
166
167        @Override
168        public CompoundDataset getSlice(IMonitor mon, Slice... slice) {
169                return (CompoundDataset) super.getSlice(mon, slice);
170        }
171
172        @Override
173        public CompoundDataset getSlice(IMonitor mon, SliceND slice) {
174                return (CompoundDataset) super.getSlice(mon, slice);
175        }
176
177        @Override
178        public CompoundDataset getSlice(int[] start, int[] stop, int[] step) {
179                return (CompoundDataset) super.getSlice(start, stop, step);
180        }
181
182        @Override
183        public CompoundDataset getSlice(Slice... slice) {
184                return (CompoundDataset) super.getSlice(slice);
185        }
186
187        @Override
188        public CompoundDataset getSlice(SliceND slice) {
189                return (CompoundDataset) super.getSlice(slice);
190        }
191
192        @Override
193        abstract public AbstractCompoundDataset getSlice(SliceIterator iterator);
194
195        @Override
196        public CompoundDataset getSliceView(int[] start, int[] stop, int[] step) {
197                return (CompoundDataset) super.getSliceView(start, stop, step);
198        }
199
200        @Override
201        public CompoundDataset getSliceView(Slice... slice) {
202                return (CompoundDataset) super.getSliceView(slice);
203        }
204
205        @Override
206        public CompoundDataset getSliceView(SliceND slice) {
207                return (CompoundDataset) super.getSliceView(slice);
208        }
209
210        @Override
211        public CompoundDataset getTransposedView(int... axes) {
212                return (CompoundDataset) super.getTransposedView(axes);
213        }
214
215        @Override
216        abstract public AbstractCompoundDataset getView(boolean deepCopyMetadata);
217
218        @Override
219        public CompoundDataset getBroadcastView(int... broadcastShape) {
220                return (CompoundDataset) super.getBroadcastView(broadcastShape);
221        }
222
223        @Override
224        public CompoundDataset ifloorDivide(Object o) {
225                return (CompoundDataset) super.ifloorDivide(o);
226        }
227
228        @Override
229        public CompoundDataset reshape(int... shape) {
230                return (CompoundDataset) super.reshape(shape);
231        }
232
233        @Override
234        public CompoundDataset setSlice(Object obj, int[] start, int[] stop, int[] step) {
235                return (CompoundDataset) super.setSlice(obj, start, stop, step);
236        }
237
238        @Override
239        public CompoundDataset setSlice(Object object, Slice... slice) {
240                return (CompoundDataset) super.setSlice(object, slice);
241        }
242
243        @Override
244        public CompoundDataset sort(Integer axis) {
245                throw new UnsupportedOperationException("Cannot sort dataset");
246        }
247
248        @Override
249        public CompoundDataset squeezeEnds() {
250                return (CompoundDataset) super.squeezeEnds();
251        }
252
253        @Override
254        public CompoundDataset squeeze() {
255                return (CompoundDataset) super.squeeze();
256        }
257
258        @Override
259        public CompoundDataset squeeze(boolean onlyFromEnd) {
260                return (CompoundDataset) super.squeeze(onlyFromEnd);
261        }
262
263        @Override
264        public CompoundDataset swapAxes(int axis1, int axis2) {
265                return (CompoundDataset) super.swapAxes(axis1, axis2);
266        }
267
268        @Override
269        public synchronized CompoundDataset synchronizedCopy() {
270                return clone();
271        }
272
273        @Override
274        public CompoundDataset transpose(int... axes) {
275                return (CompoundDataset) super.transpose(axes);
276        }
277
278        /**
279         * @since 2.0
280         */
281        abstract protected double getFirstValue();
282
283        abstract protected double getFirstValue(final int i);
284
285        abstract protected double getFirstValue(final int i, final int j);
286
287        abstract protected double getFirstValue(final int...pos);
288
289        @Override
290        public boolean getBoolean() {
291                return getFirstValue() != 0;
292        }
293
294        @Override
295        public boolean getBoolean(final int i) {
296                return getFirstValue(i) != 0;
297        }
298
299        @Override
300        public boolean getBoolean(final int i, final int j) {
301                return getFirstValue(i, j) != 0;
302        }
303
304        @Override
305        public boolean getBoolean(final int... pos) {
306                return getFirstValue(pos) != 0;
307        }
308
309        @Override
310        public byte getByte() {
311                return (byte) getFirstValue();
312        }
313
314        @Override
315        public byte getByte(final int i) {
316                return (byte) getFirstValue(i);
317        }
318
319        @Override
320        public byte getByte(final int i, final int j) {
321                return (byte) getFirstValue(i, j);
322        }
323
324        @Override
325        public byte getByte(final int... pos) {
326                return (byte) getFirstValue(pos);
327        }
328
329        @Override
330        public short getShort() {
331                return (short) getFirstValue();
332        }
333
334        @Override
335        public short getShort(final int i) {
336                return (short) getFirstValue(i);
337        }
338
339        @Override
340        public short getShort(final int i, final int j) {
341                return (short) getFirstValue(i, j);
342        }
343
344        @Override
345        public short getShort(final int... pos) {
346                return (short) getFirstValue(pos);
347        }
348
349        @Override
350        public int getInt() {
351                return (int) getFirstValue();
352        }
353
354        @Override
355        public int getInt(final int i) {
356                return (int) getFirstValue(i);
357        }
358
359        @Override
360        public int getInt(final int i, final int j) {
361                return (int) getFirstValue(i, j);
362        }
363
364        @Override
365        public int getInt(final int... pos) {
366                return (int) getFirstValue(pos);
367        }
368
369        @Override
370        public long getLong() {
371                return (long) getFirstValue();
372        }
373
374        @Override
375        public long getLong(final int i) {
376                return (long) getFirstValue(i);
377        }
378
379        @Override
380        public long getLong(final int i, final int j) {
381                return (long) getFirstValue(i, j);
382        }
383
384        @Override
385        public long getLong(final int... pos) {
386                return (long) getFirstValue(pos);
387        }
388
389        @Override
390        public float getFloat() {
391                return (float) getFirstValue();
392        }
393
394        @Override
395        public float getFloat(final int i) {
396                return (float) getFirstValue(i);
397        }
398
399        @Override
400        public float getFloat(final int i, final int j) {
401                return (float) getFirstValue(i, j);
402        }
403
404        @Override
405        public float getFloat(final int... pos) {
406                return (float) getFirstValue(pos);
407        }
408
409        @Override
410        public double getDouble() {
411                return getFirstValue();
412        }
413
414        @Override
415        public double getDouble(final int i) {
416                return getFirstValue(i);
417        }
418
419        @Override
420        public double getDouble(final int i, final int j) {
421                return getFirstValue(i, j);
422        }
423
424        @Override
425        public double getDouble(final int... pos) {
426                return getFirstValue(pos);
427        }
428
429        @Override
430        public void getDoubleArray(final double[] darray) {
431                getDoubleArrayAbs(getFirst1DIndex(), darray);
432        }
433
434        @Override
435        public void getDoubleArray(final double[] darray, final int i) {
436                getDoubleArrayAbs(get1DIndex(i), darray);
437        }
438
439        @Override
440        public void getDoubleArray(final double[] darray, final int i, final int j) {
441                getDoubleArrayAbs(get1DIndex(i, j), darray);
442        }
443
444        @Override
445        public void getDoubleArray(final double[] darray, final int... pos) {
446                getDoubleArrayAbs(get1DIndex(pos), darray);
447        }
448
449        /**
450         * @since 2.0
451         */
452        @SuppressWarnings("unchecked")
453        protected StatisticsMetadata<double[]> getCompoundStats() {
454                StatisticsMetadata<double[]> md = getFirstMetadata(StatisticsMetadata.class);
455                if (md == null || md.isDirty()) {
456                        md = new StatisticsMetadataImpl<double[]>();
457                        md.initialize(this);
458                        setMetadata(md);
459                }
460                return md;
461        }
462
463        @Override
464        public IntegerDataset argMax(int axis, boolean... ignoreInvalids) {
465                logger.error("Cannot compare compound numbers");
466                throw new UnsupportedOperationException("Cannot compare compound numbers");
467        }
468
469        @Override
470        public IntegerDataset argMin(int axis, boolean... ignoreInvalids) {
471                logger.error("Cannot compare compound numbers");
472                throw new UnsupportedOperationException("Cannot compare compound numbers");
473        }
474
475        @Override
476        public Number max(boolean... ignoreInvalids) {
477                logger.error("Cannot compare compound numbers");
478                throw new UnsupportedOperationException("Cannot compare compound numbers");
479        }
480
481        @Override
482        public CompoundDataset max(int axis, boolean... ignoreInvalids) {
483                logger.error("Cannot compare compound numbers");
484                throw new UnsupportedOperationException("Cannot compare compound numbers");
485        }
486
487        @Override
488        public Number min(boolean... ignoreInvalids) {
489                logger.error("Cannot compare compound numbers");
490                throw new UnsupportedOperationException("Cannot compare compound numbers");
491        }
492
493        @Override
494        public CompoundDataset min(int axis, boolean... ignoreInvalids) {
495                logger.error("Cannot compare compound numbers");
496                throw new UnsupportedOperationException("Cannot compare compound numbers");
497        }
498
499
500        @Override
501        public int[] maxPos(boolean... ignoreNaNs) {
502                logger.error("Cannot compare compound numbers");
503                throw new UnsupportedOperationException("Cannot compare compound numbers");
504        }
505
506        @Override
507        public int[] minPos(boolean... ignoreNaNs) {
508                logger.error("Cannot compare compound numbers");
509                throw new UnsupportedOperationException("Cannot compare compound numbers");
510        }
511
512        @Override
513        public CompoundDataset peakToPeak(int axis, boolean... ignoreInvalids) {
514                logger.error("Cannot compare compound numbers");
515                throw new UnsupportedOperationException("Cannot compare compound numbers");
516        }
517
518        @Override
519        public double[] maxItem() {
520                return getCompoundStats().getMaximum();
521        }
522
523        @Override
524        public double[] minItem() {
525                return getCompoundStats().getMinimum();
526        }
527
528        @Override
529        public Object mean(boolean... ignoreInvalids) {
530                return getCompoundStats().getMean();
531        }
532
533        @Override
534        public CompoundDataset mean(int axis, boolean... ignoreInvalids) {
535                return (CompoundDataset) super.mean(axis, ignoreInvalids);
536        }
537
538        @Override
539        public CompoundDataset product(int axis, boolean... ignoreInvalids) {
540                return (CompoundDataset) super.product(axis, ignoreInvalids);
541        }
542
543        @Override
544        public CompoundDataset rootMeanSquare(int axis, boolean... ignoreInvalids) {
545                return (CompoundDataset) super.rootMeanSquare(axis, ignoreInvalids);
546        }
547
548        @Override
549        public CompoundDataset stdDeviation(int axis) {
550                return (CompoundDataset) super.stdDeviation(axis, false);
551        }
552
553        @Override
554        public CompoundDataset stdDeviation(int axis, boolean isWholePopulation, boolean... ignoreInvalids) {
555                return (CompoundDataset) super.stdDeviation(axis, isWholePopulation, ignoreInvalids);
556        }
557
558        @Override
559        public Object sum(boolean... ignoreInvalids) {
560                return getCompoundStats().getSum();
561        }
562
563        @Override
564        public CompoundDataset sum(int axis, boolean... ignoreInvalids) {
565                return (CompoundDataset) super.sum(axis, ignoreInvalids);
566        }
567
568        @Override
569        public double variance(boolean isWholePopulation, boolean... ignoreInvalids) {
570                return getCompoundStats().getVariance(isWholePopulation, ignoreInvalids);
571        }
572
573        @Override
574        public CompoundDataset variance(int axis) {
575                return (CompoundDataset) super.variance(axis, false);
576        }
577
578        @Override
579        public CompoundDataset variance(int axis, boolean isWholePopulation, boolean... ignoreInvalids) {
580                return (CompoundDataset) super.variance(axis, isWholePopulation, ignoreInvalids);
581        }
582
583        @Override
584        public double rootMeanSquare(boolean... ignoreInvalids) {
585                StatisticsMetadata<double[]> stats = getCompoundStats();
586
587                double[] mean = stats.getMean(ignoreInvalids);
588                double result = 0;
589                for (int i = 0; i < isize; i++) {
590                        double m = mean[i];
591                        result += m * m;
592                }
593                return Math.sqrt(result + stats.getVariance(true));
594        }
595
596        /**
597         * @return error
598         */
599        private CompoundDataset getInternalError() {
600                ILazyDataset led = super.getErrors();
601                if (led == null)
602                        return null;
603
604                Dataset ed = null;
605                try {
606                        ed = DatasetUtils.sliceAndConvertLazyDataset(led);
607                } catch (DatasetException e) {
608                        logger.error("Could not get data from lazy dataset", e);
609                }
610
611                CompoundDataset ced; // ensure it has the same number of elements
612                if (!(ed instanceof CompoundDataset) || ed.getElementsPerItem() != isize) {
613                        ced = new CompoundDoubleDataset(isize, true, ed);
614                } else {
615                        ced = (CompoundDataset) ed;
616                }
617                
618                if (led != ced) {
619                        setErrors(ced); // set back
620                }
621                return ced;
622        }
623
624        @Override
625        public CompoundDataset getErrors() {
626                CompoundDataset ed = getInternalError();
627                if (ed == null)
628                        return null;
629
630                return ed.getBroadcastView(shape);
631        }
632
633        @Override
634        public double getError(final int i) {
635                return calcError(getInternalErrorArray(true, i));
636        }
637
638        @Override
639        public double getError(final int i, final int j) {
640                return calcError(getInternalErrorArray(true, i, j));
641        }
642
643        @Override
644        public double getError(final int... pos) {
645                return calcError(getInternalErrorArray(true, pos));
646        }
647
648        private double calcError(double[] es) {
649                if (es == null)
650                        return 0;
651
652                // assume elements are independent
653                double e = 0;
654                for (int k = 0; k < isize; k++) {
655                        e += es[k];
656                }
657
658                return Math.sqrt(e);
659        }
660
661        @Override
662        public double[] getErrorArray(final int i) {
663                return getInternalErrorArray(false, i);
664        }
665
666        @Override
667        public double[] getErrorArray(final int i, final int j) {
668                return getInternalErrorArray(false, i, j);
669        }
670
671        @Override
672        public double[] getErrorArray(final int... pos) {
673                return getInternalErrorArray(false, pos);
674        }
675
676        private Dataset getInternalError(final boolean squared) {
677                Dataset sed = squared ? getInternalSquaredError() : getInternalError();
678                if (sed == null)
679                        return null;
680
681                return sed.getBroadcastView(shape);
682        }
683
684        private double[] getInternalErrorArray(final boolean squared, final int i) {
685                Dataset sed = getInternalError(squared);
686                if (sed == null)
687                        return null;
688
689                double[] es;
690                if (sed instanceof CompoundDoubleDataset) {
691                        es = ((CompoundDoubleDataset) sed).getDoubleArray(i);
692                        if (sed.getElementsPerItem() != isize) { // ensure error is broadcasted
693                                Arrays.fill(es, es[0]);
694                        }
695                } else {
696                        es = new double[isize];
697                        Arrays.fill(es, ((DoubleDataset) sed).getDouble(i));
698                }
699                return es;
700        }
701
702        private double[] getInternalErrorArray(final boolean squared, final int i, final int j) {
703                Dataset sed = getInternalError(squared);
704                if (sed == null)
705                        return null;
706
707                double[] es;
708                if (sed instanceof CompoundDoubleDataset) {
709                        es = ((CompoundDoubleDataset) sed).getDoubleArray(i, j);
710                        if (sed.getElementsPerItem() != isize) { // ensure error is broadcasted
711                                Arrays.fill(es, es[0]);
712                        }
713                } else {
714                        es = new double[isize];
715                        Arrays.fill(es, ((DoubleDataset) sed).getDouble(i, j));
716                }
717                return es;
718        }
719
720        private double[] getInternalErrorArray(final boolean squared, final int... pos) {
721                Dataset sed = getInternalError(squared);
722                if (sed == null)
723                        return null;
724
725                double[] es = new double[isize];
726                if (sed instanceof CompoundDoubleDataset) {
727                        es = ((CompoundDoubleDataset) sed).getDoubleArray(pos);
728                        if (sed.getElementsPerItem() != isize) { // ensure error is broadcasted
729                                Arrays.fill(es, es[0]);
730                        }
731                } else {
732                        es = new double[isize];
733                        Arrays.fill(es, ((DoubleDataset) sed).getDouble(pos));
734                }
735                return es;
736        }
737}
738