001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.io.Serializable; 016import java.lang.reflect.Array; 017import java.util.ArrayList; 018import java.util.Arrays; 019import java.util.Comparator; 020import java.util.List; 021 022import org.apache.commons.math3.util.MathArrays; 023import org.eclipse.january.DatasetException; 024import org.slf4j.Logger; 025import org.slf4j.LoggerFactory; 026 027/** 028 * Utilities for manipulating datasets 029 */ 030@SuppressWarnings("unchecked") 031public class DatasetUtils { 032 033 /** 034 * Setup the logging facilities 035 */ 036 transient protected static final Logger utilsLogger = LoggerFactory.getLogger(DatasetUtils.class); 037 038 /** 039 * Append copy of dataset with another dataset along n-th axis 040 * 041 * @param a 042 * @param b 043 * @param axis 044 * number of axis (negative number counts from last) 045 * @return appended dataset 046 */ 047 public static Dataset append(IDataset a, IDataset b, int axis) { 048 final int[] ashape = a.getShape(); 049 final int rank = ashape.length; 050 final int[] bshape = b.getShape(); 051 if (rank != bshape.length) { 052 throw new IllegalArgumentException("Incompatible number of dimensions"); 053 } 054 axis = ShapeUtils.checkAxis(rank, axis); 055 056 for (int i = 0; i < rank; i++) { 057 if (i != axis && ashape[i] != bshape[i]) { 058 throw new IllegalArgumentException("Incompatible dimensions"); 059 } 060 } 061 final int[] nshape = new int[rank]; 062 for (int i = 0; i < rank; i++) { 063 nshape[i] = ashape[i]; 064 } 065 nshape[axis] += bshape[axis]; 066 final int ot = DTypeUtils.getDType(b); 067 final int dt = DTypeUtils.getDType(a); 068 @SuppressWarnings("deprecation") 069 Dataset ds = DatasetFactory.zeros(a.getElementsPerItem(), nshape, dt > ot ? dt : ot); 070 IndexIterator iter = ds.getIterator(true); 071 int[] pos = iter.getPos(); 072 while (iter.hasNext()) { 073 int d = ashape[axis]; 074 if (pos[axis] < d) { 075 ds.setObjectAbs(iter.index, a.getObject(pos)); 076 } else { 077 pos[axis] -= d; 078 ds.setObjectAbs(iter.index, b.getObject(pos)); 079 pos[axis] += d; 080 } 081 } 082 083 return ds; 084 } 085 086 /** 087 * Changes specific items of dataset by replacing them with other array 088 * @param a 089 * @param indices dataset interpreted as integers 090 * @param values 091 * @return changed dataset 092 */ 093 public static <T extends Dataset> T put(final T a, final Dataset indices, Object values) { 094 IndexIterator it = indices.getIterator(); 095 Dataset vd = DatasetFactory.createFromObject(values).flatten(); 096 int vlen = vd.getSize(); 097 int v = 0; 098 while (it.hasNext()) { 099 if (v >= vlen) v -= vlen; 100 101 a.setObjectAbs((int) indices.getElementLongAbs(it.index), vd.getObjectAbs(v++)); 102 } 103 return a; 104 } 105 106 /** 107 * Changes specific items of dataset by replacing them with other array 108 * @param a 109 * @param indices 110 * @param values 111 * @return changed dataset 112 */ 113 public static <T extends Dataset> T put(final T a, final int[] indices, Object values) { 114 int ilen = indices.length; 115 Dataset vd = DatasetFactory.createFromObject(values).flatten(); 116 int vlen = vd.getSize(); 117 for (int i = 0, v= 0; i < ilen; i++) { 118 if (v >= vlen) v -= vlen; 119 120 a.setObjectAbs(indices[i], vd.getObjectAbs(v++)); 121 } 122 return a; 123 } 124 125 /** 126 * Take items from dataset along an axis 127 * @param indices dataset interpreted as integers 128 * @param axis if null, then use flattened view 129 * @return a sub-array 130 */ 131 public static <T extends Dataset> T take(final T a, final Dataset indices, Integer axis) { 132 IntegerDataset indexes = (IntegerDataset) indices.flatten().cast(Dataset.INT32); 133 return take(a, indexes.getData(), axis); 134 } 135 136 /** 137 * Take items from dataset along an axis 138 * @param indices 139 * @param axis if null, then use flattened view 140 * @return a sub-array 141 */ 142 @SuppressWarnings("deprecation") 143 public static <T extends Dataset> T take(final T a, final int[] indices, Integer axis) { 144 if (indices == null || indices.length == 0) { 145 utilsLogger.error("No indices given"); 146 throw new IllegalArgumentException("No indices given"); 147 } 148 int[] ashape = a.getShape(); 149 final int rank = ashape.length; 150 final int at = a.getDType(); 151 final int ilen = indices.length; 152 final int is = a.getElementsPerItem(); 153 154 Dataset result; 155 if (axis == null) { 156 ashape = new int[1]; 157 ashape[0] = ilen; 158 result = DatasetFactory.zeros(is, ashape, at); 159 Serializable src = a.getBuffer(); 160 for (int i = 0; i < ilen; i++) { 161 ((AbstractDataset) result).setItemDirect(i, indices[i], src); 162 } 163 } else { 164 axis = a.checkAxis(axis); 165 ashape[axis] = ilen; 166 result = DatasetFactory.zeros(is, ashape, at); 167 168 int[] dpos = new int[rank]; 169 int[] spos = new int[rank]; 170 boolean[] axes = new boolean[rank]; 171 Arrays.fill(axes, true); 172 axes[axis] = false; 173 Serializable src = a.getBuffer(); 174 for (int i = 0; i < ilen; i++) { 175 spos[axis] = indices[i]; 176 dpos[axis] = i; 177 SliceIterator siter = a.getSliceIteratorFromAxes(spos, axes); 178 SliceIterator diter = result.getSliceIteratorFromAxes(dpos, axes); 179 180 while (siter.hasNext() && diter.hasNext()) { 181 ((AbstractDataset) result).setItemDirect(diter.index, siter.index, src); 182 } 183 } 184 } 185 result.setDirty(); 186 return (T) result; 187 } 188 189 /** 190 * Construct a dataset that contains the original dataset repeated the number 191 * of times in each axis given by corresponding entries in the reps array 192 * 193 * @param a 194 * @param reps 195 * @return tiled dataset 196 */ 197 public static Dataset tile(final IDataset a, int... reps) { 198 int[] shape = a.getShape(); 199 int rank = shape.length; 200 final int rlen = reps.length; 201 202 // expand shape 203 if (rank < rlen) { 204 int[] newShape = new int[rlen]; 205 int extraRank = rlen - rank; 206 for (int i = 0; i < extraRank; i++) { 207 newShape[i] = 1; 208 } 209 for (int i = 0; i < rank; i++) { 210 newShape[i+extraRank] = shape[i]; 211 } 212 213 shape = newShape; 214 rank = rlen; 215 } else if (rank > rlen) { 216 int[] newReps = new int[rank]; 217 int extraRank = rank - rlen; 218 for (int i = 0; i < extraRank; i++) { 219 newReps[i] = 1; 220 } 221 for (int i = 0; i < rlen; i++) { 222 newReps[i+extraRank] = reps[i]; 223 } 224 reps = newReps; 225 } 226 227 // calculate new shape 228 int[] newShape = new int[rank]; 229 for (int i = 0; i < rank; i++) { 230 newShape[i] = shape[i]*reps[i]; 231 } 232 233 @SuppressWarnings("deprecation") 234 Dataset tdata = DatasetFactory.zeros(a.getElementsPerItem(), newShape, DTypeUtils.getDType(a)); 235 236 // decide which way to put slices 237 boolean manyColumns; 238 if (rank == 1) 239 manyColumns = true; 240 else 241 manyColumns = shape[rank-1] > 64; 242 243 if (manyColumns) { 244 // generate each start point and put a slice in 245 IndexIterator iter = tdata.getSliceIterator(null, null, shape); 246 SliceIterator siter = (SliceIterator) tdata.getSliceIterator(null, shape, null); 247 final int[] pos = iter.getPos(); 248 while (iter.hasNext()) { 249 siter.setStart(pos); 250 tdata.setSlice(a, siter); 251 } 252 253 } else { 254 // for each value, set slice given by repeats 255 final int[] skip = new int[rank]; 256 for (int i = 0; i < rank; i++) { 257 if (reps[i] == 1) { 258 skip[i] = newShape[i]; 259 } else { 260 skip[i] = shape[i]; 261 } 262 } 263 264 Dataset aa = convertToDataset(a); 265 IndexIterator ita = aa.getIterator(true); 266 final int[] pos = ita.getPos(); 267 268 final int[] sstart = new int[rank]; 269 final int extra = rank - pos.length; 270 for (int i = 0; i < extra; i++) { 271 sstart[i] = 0; 272 } 273 SliceIterator siter = (SliceIterator) tdata.getSliceIterator(sstart, null, skip); 274 while (ita.hasNext()) { 275 for (int i = 0; i < pos.length; i++) { 276 sstart[i + extra] = pos[i]; 277 } 278 siter.setStart(sstart); 279 tdata.setSlice(aa.getObjectAbs(ita.index), siter); 280 } 281 } 282 283 return tdata; 284 } 285 286 /** 287 * Permute copy of dataset's axes so that given order is old order: 288 * <pre> 289 * axisPerm = (p(0), p(1),...) => newdata(n(0), n(1),...) = olddata(o(0), o(1), ...) 290 * such that n(i) = o(p(i)) for all i 291 * </pre> 292 * I.e. for a 3D dataset (1,0,2) implies the new dataset has its 1st dimension 293 * running along the old dataset's 2nd dimension and the new 2nd is the old 1st. 294 * The 3rd dimension is left unchanged. 295 * 296 * @param a 297 * @param axes if null or zero length then axes order reversed 298 * @return remapped copy of data 299 */ 300 public static Dataset transpose(final IDataset a, int... axes) { 301 return convertToDataset(a).transpose(axes); 302 } 303 304 /** 305 * Swap two axes in dataset 306 * @param a 307 * @param axis1 308 * @param axis2 309 * @return swapped dataset 310 */ 311 public static Dataset swapAxes(final IDataset a, int axis1, int axis2) { 312 return convertToDataset(a).swapAxes(axis1, axis2); 313 } 314 315 /** 316 * @param a 317 * @return sorted flattened copy of dataset 318 */ 319 public static <T extends Dataset> T sort(final T a) { 320 return sort(a, (Integer) null); 321 } 322 323 /** 324 * @param a 325 * @param axis to sort along, if null then dataset is first flattened 326 * @return dataset sorted along axis 327 */ 328 public static <T extends Dataset> T sort(final T a, final Integer axis) { 329 Dataset s = a.clone(); 330 return (T) s.sort(axis); 331 } 332 333 /** 334 * Sort in place given dataset and reorder ancillary datasets too 335 * @param a dataset to be sorted 336 * @param b ancillary datasets 337 */ 338 public static void sort(Dataset a, Dataset... b) { 339 if (!DTypeUtils.isDTypeNumerical(a.getDType())) { 340 throw new UnsupportedOperationException("Sorting non-numerical datasets not supported yet"); 341 } 342 343 // gather all datasets as double dataset copies 344 DoubleDataset s = copy(DoubleDataset.class, a); 345 int l = b == null ? 0 : b.length; 346 DoubleDataset[] t = new DoubleDataset[l]; 347 int n = 0; 348 for (int i = 0; i < l; i++) { 349 if (b[i] != null) { 350 if (!DTypeUtils.isDTypeNumerical(b[i].getDType())) { 351 throw new UnsupportedOperationException("Sorting non-numerical datasets not supported yet"); 352 } 353 t[i] = copy(DoubleDataset.class, b[i]); 354 n++; 355 } 356 } 357 358 double[][] y = new double[n][]; 359 for (int i = 0, j = 0; i < l; i++) { 360 if (t[i] != null) { 361 y[j++] = t[i].getData(); 362 } 363 } 364 365 MathArrays.sortInPlace(s.getData(), y); 366 367 a.setSlice(s); 368 for (int i = 0; i < l; i++) { 369 if (b[i] != null) { 370 b[i].setSlice(t[i]); 371 } 372 } 373 } 374 375 /** 376 * Indirectly sort along given axis 377 * @param a dataset whose indexes will be sorted 378 * @param axis to sort along, if null then dataset is first flattened 379 * @return indexes 380 * @since 2.1 381 */ 382 public static IntegerDataset indexSort(Dataset a, Integer axis) { 383 if (axis == null) { 384 int size = a.getSize(); 385 Integer[] index = new Integer[size]; 386 for (int i = 0; i < size; i++) { 387 index[i] = i; 388 } 389 final Dataset f = a.flatten(); // is this correct for views??? Check with NumPy 390 Comparator<Integer> cmp = new Comparator<Integer>() { 391 392 @Override 393 public int compare(Integer o1, Integer o2) { 394 395 return Double.compare(f.getElementDoubleAbs(o1), f.getElementDoubleAbs(o2)); 396 } 397 }; 398 Arrays.sort(index, cmp); 399 return DatasetFactory.createFromObject(IntegerDataset.class, index); 400 } 401 402 axis = a.checkAxis(axis); 403 final int[] shape = a.getShapeRef(); 404 IntegerDataset id = DatasetFactory.zeros(IntegerDataset.class, shape); 405 int size = shape[axis]; 406 Integer[] index = new Integer[size]; 407 408 int[] dShape = new int[shape.length]; 409 Arrays.fill(dShape, 1); 410 dShape[axis] = size; 411 final DoubleDataset dd = DatasetFactory.zeros(DoubleDataset.class, dShape); 412 final Comparator<Integer> cmp = new Comparator<Integer>() { 413 @Override 414 public int compare(Integer o1, Integer o2) { 415 416 return Double.compare(dd.getElementDoubleAbs(o1), dd.getElementDoubleAbs(o2)); 417 } 418 }; 419 420 SliceND ds = new SliceND(dShape); 421 SliceNDIterator it = new SliceNDIterator(new SliceND(shape), axis); 422 int[] pos = it.getPos(); 423 int[] ipos = pos.clone(); 424 while (it.hasNext()) { 425 dd.setSlice(a.getSliceView(it.getCurrentSlice()), ds); 426 for (int i = 0; i < size; i++) { 427 index[i] = i; 428 } 429 Arrays.sort(index, cmp); 430 431 System.arraycopy(pos, 0, ipos, 0, pos.length); 432 for (int i = 0; i < size; i++) { 433 ipos[axis] = i; 434 id.set(index[i], ipos); 435 } 436 } 437 438 return id; 439 } 440 441 /** 442 * Concatenate the set of datasets along given axis 443 * @param as 444 * @param axis 445 * @return concatenated dataset 446 */ 447 public static Dataset concatenate(final IDataset[] as, int axis) { 448 if (as == null || as.length == 0) { 449 utilsLogger.error("No datasets given"); 450 throw new IllegalArgumentException("No datasets given"); 451 } 452 IDataset a = as[0]; 453 if (as.length == 1) { 454 return convertToDataset(a.clone()); 455 } 456 457 int[] ashape = a.getShape(); 458 axis = ShapeUtils.checkAxis(ashape.length, axis); 459 int at = DTypeUtils.getDType(a); 460 int anum = as.length; 461 int isize = a.getElementsPerItem(); 462 463 int i = 1; 464 for (; i < anum; i++) { 465 if (at != DTypeUtils.getDType(as[i])) { 466 utilsLogger.error("Datasets are not of same type"); 467 break; 468 } 469 if (!ShapeUtils.areShapesCompatible(ashape, as[i].getShape(), axis)) { 470 utilsLogger.error("Datasets' shapes are not equal"); 471 break; 472 } 473 final int is = as[i].getElementsPerItem(); 474 if (isize < is) 475 isize = is; 476 } 477 if (i < anum) { 478 utilsLogger.error("Dataset are not compatible"); 479 throw new IllegalArgumentException("Datasets are not compatible"); 480 } 481 482 for (i = 1; i < anum; i++) { 483 ashape[axis] += as[i].getShape()[axis]; 484 } 485 486 @SuppressWarnings("deprecation") 487 Dataset result = DatasetFactory.zeros(isize, ashape, at); 488 489 int[] start = new int[ashape.length]; 490 int[] stop = ashape; 491 stop[axis] = 0; 492 for (i = 0; i < anum; i++) { 493 IDataset b = as[i]; 494 int[] bshape = b.getShape(); 495 stop[axis] += bshape[axis]; 496 result.setSlice(b, start, stop, null); 497 start[axis] += bshape[axis]; 498 } 499 500 return result; 501 } 502 503 /** 504 * Split a dataset into equal sections along given axis 505 * @param a 506 * @param sections 507 * @param axis 508 * @param checkEqual makes sure the division is into equal parts 509 * @return list of split datasets 510 */ 511 public static List<Dataset> split(final Dataset a, int sections, int axis, final boolean checkEqual) { 512 int[] ashape = a.getShapeRef(); 513 axis = a.checkAxis(axis); 514 int imax = ashape[axis]; 515 if (checkEqual && (imax%sections) != 0) { 516 utilsLogger.error("Number of sections does not divide axis into equal parts"); 517 throw new IllegalArgumentException("Number of sections does not divide axis into equal parts"); 518 } 519 int n = (imax + sections - 1) / sections; 520 int[] indices = new int[sections-1]; 521 for (int i = 1; i < sections; i++) 522 indices[i-1] = n*i; 523 return split(a, indices, axis); 524 } 525 526 /** 527 * Split a dataset into parts along given axis 528 * @param a 529 * @param indices 530 * @param axis 531 * @return list of split datasets 532 */ 533 public static List<Dataset> split(final Dataset a, int[] indices, int axis) { 534 final int[] ashape = a.getShapeRef(); 535 axis = a.checkAxis(axis); 536 final int rank = ashape.length; 537 final int imax = ashape[axis]; 538 539 final List<Dataset> result = new ArrayList<Dataset>(); 540 541 final int[] nshape = ashape.clone(); 542 final int is = a.getElementsPerItem(); 543 544 int oind = 0; 545 final int[] start = new int[rank]; 546 final int[] stop = new int[rank]; 547 final int[] step = new int[rank]; 548 for (int i = 0; i < rank; i++) { 549 start[i] = 0; 550 stop[i] = ashape[i]; 551 step[i] = 1; 552 } 553 for (int ind : indices) { 554 if (ind > imax) { 555 result.add(DatasetFactory.zeros(is, a.getClass(), 0)); 556 } else { 557 nshape[axis] = ind - oind; 558 start[axis] = oind; 559 stop[axis] = ind; 560 Dataset n = DatasetFactory.zeros(is, a.getClass(), nshape); 561 IndexIterator iter = a.getSliceIterator(start, stop, step); 562 563 a.fillDataset(n, iter); 564 result.add(n); 565 oind = ind; 566 } 567 } 568 569 if (imax > oind) { 570 nshape[axis] = imax - oind; 571 start[axis] = oind; 572 stop[axis] = imax; 573 Dataset n = DatasetFactory.zeros(is, a.getClass(), nshape); 574 IndexIterator iter = a.getSliceIterator(start, stop, step); 575 576 a.fillDataset(n, iter); 577 result.add(n); 578 } 579 580 return result; 581 } 582 583 /** 584 * Constructs a dataset which has its elements along an axis replicated from 585 * the original dataset by the number of times given in the repeats array. 586 * 587 * By default, axis=-1 implies using a flattened version of the input dataset 588 * 589 * @param a 590 * @param repeats 591 * @param axis 592 * @return dataset 593 */ 594 public static <T extends Dataset> T repeat(T a, int[] repeats, int axis) { 595 Serializable buf = a.getBuffer(); 596 int[] shape = a.getShape(); 597 int rank = shape.length; 598 final int is = a.getElementsPerItem(); 599 600 if (axis >= rank) { 601 utilsLogger.warn("Axis value is out of bounds"); 602 throw new IllegalArgumentException("Axis value is out of bounds"); 603 } 604 605 int alen; 606 if (axis < 0) { 607 alen = a.getSize(); 608 axis = 0; 609 rank = 1; 610 shape[0] = alen; 611 } else { 612 alen = shape[axis]; 613 } 614 int rlen = repeats.length; 615 if (rlen != 1 && rlen != alen) { 616 utilsLogger.warn("Repeats array should have length of 1 or match chosen axis"); 617 throw new IllegalArgumentException("Repeats array should have length of 1 or match chosen axis"); 618 } 619 620 for (int i = 0; i < rlen; i++) { 621 if (repeats[i] < 0) { 622 utilsLogger.warn("Negative repeat value is not allowed"); 623 throw new IllegalArgumentException("Negative repeat value is not allowed"); 624 } 625 } 626 627 int[] newShape = new int[rank]; 628 for (int i = 0; i < rank; i ++) 629 newShape[i] = shape[i]; 630 631 // do single repeat separately 632 if (repeats.length == 1) { 633 newShape[axis] *= repeats[0]; 634 } else { 635 int nlen = 0; 636 for (int i = 0; i < alen; i++) { 637 nlen += repeats[i]; 638 } 639 newShape[axis] = nlen; 640 } 641 642 @SuppressWarnings("deprecation") 643 Dataset rdata = DatasetFactory.zeros(is, newShape, a.getDType()); 644 Serializable nbuf = rdata.getBuffer(); 645 646 int csize = is; // chunk size 647 for (int i = axis+1; i < rank; i++) { 648 csize *= newShape[i]; 649 } 650 int nout = 1; 651 for (int i = 0; i < axis; i++) { 652 nout *= newShape[i]; 653 } 654 655 int oi = 0; 656 int ni = 0; 657 if (rlen == 1) { // do single repeat separately 658 for (int i = 0; i < nout; i++) { 659 for (int j = 0; j < shape[axis]; j++) { 660 for (int k = 0; k < repeats[0]; k++) { 661 System.arraycopy(buf, oi, nbuf, ni, csize); 662 ni += csize; 663 } 664 oi += csize; 665 } 666 } 667 } else { 668 for (int i = 0; i < nout; i++) { 669 for (int j = 0; j < shape[axis]; j++) { 670 for (int k = 0; k < repeats[j]; k++) { 671 System.arraycopy(buf, oi, nbuf, ni, csize); 672 ni += csize; 673 } 674 oi += csize; 675 } 676 } 677 } 678 679 return (T) rdata; 680 } 681 682 /** 683 * Resize a dataset 684 * @param a 685 * @param shape 686 * @return new dataset with new shape and items that are truncated or repeated, as necessary 687 */ 688 public static <T extends Dataset> T resize(final T a, final int... shape) { 689 int size = a.getSize(); 690 @SuppressWarnings("deprecation") 691 Dataset rdata = DatasetFactory.zeros(a.getElementsPerItem(), shape, a.getDType()); 692 IndexIterator it = rdata.getIterator(); 693 while (it.hasNext()) { 694 rdata.setObjectAbs(it.index, a.getObjectAbs(it.index % size)); 695 } 696 697 return (T) rdata; 698 } 699 700 /** 701 * Copy and cast a dataset 702 * 703 * @param d 704 * The dataset to be copied 705 * @param dtype dataset type 706 * @return copied dataset of given type 707 */ 708 public static Dataset copy(final IDataset d, final int dtype) { 709 Dataset a = convertToDataset(d); 710 711 Dataset c = null; 712 try { 713 // copy across the data 714 switch (dtype) { 715 case Dataset.STRING: 716 c = new StringDataset(a); 717 break; 718 case Dataset.BOOL: 719 c = new BooleanDataset(a); 720 break; 721 case Dataset.INT8: 722 if (a instanceof CompoundDataset) 723 c = new CompoundByteDataset(a); 724 else 725 c = new ByteDataset(a); 726 break; 727 case Dataset.INT16: 728 if (a instanceof CompoundDataset) 729 c = new CompoundShortDataset(a); 730 else 731 c = new ShortDataset(a); 732 break; 733 case Dataset.INT32: 734 if (a instanceof CompoundDataset) 735 c = new CompoundIntegerDataset(a); 736 else 737 c = new IntegerDataset(a); 738 break; 739 case Dataset.INT64: 740 if (a instanceof CompoundDataset) 741 c = new CompoundLongDataset(a); 742 else 743 c = new LongDataset(a); 744 break; 745 case Dataset.ARRAYINT8: 746 if (a instanceof CompoundDataset) 747 c = new CompoundByteDataset((CompoundDataset) a); 748 else 749 c = new CompoundByteDataset(a); 750 break; 751 case Dataset.ARRAYINT16: 752 if (a instanceof CompoundDataset) 753 c = new CompoundShortDataset((CompoundDataset) a); 754 else 755 c = new CompoundShortDataset(a); 756 break; 757 case Dataset.ARRAYINT32: 758 if (a instanceof CompoundDataset) 759 c = new CompoundIntegerDataset((CompoundDataset) a); 760 else 761 c = new CompoundIntegerDataset(a); 762 break; 763 case Dataset.ARRAYINT64: 764 if (a instanceof CompoundDataset) 765 c = new CompoundLongDataset((CompoundDataset) a); 766 else 767 c = new CompoundLongDataset(a); 768 break; 769 case Dataset.FLOAT32: 770 c = new FloatDataset(a); 771 break; 772 case Dataset.FLOAT64: 773 c = new DoubleDataset(a); 774 break; 775 case Dataset.ARRAYFLOAT32: 776 if (a instanceof CompoundDataset) 777 c = new CompoundFloatDataset((CompoundDataset) a); 778 else 779 c = new CompoundFloatDataset(a); 780 break; 781 case Dataset.ARRAYFLOAT64: 782 if (a instanceof CompoundDataset) 783 c = new CompoundDoubleDataset((CompoundDataset) a); 784 else 785 c = new CompoundDoubleDataset(a); 786 break; 787 case Dataset.COMPLEX64: 788 c = new ComplexFloatDataset(a); 789 break; 790 case Dataset.COMPLEX128: 791 c = new ComplexDoubleDataset(a); 792 break; 793 case Dataset.RGB: 794 if (a instanceof CompoundDataset) 795 c = RGBDataset.createFromCompoundDataset((CompoundDataset) a); 796 else 797 c = new RGBDataset(a); 798 break; 799 default: 800 utilsLogger.error("Dataset of unknown type!"); 801 break; 802 } 803 } catch (OutOfMemoryError e) { 804 utilsLogger.error("Not enough memory available to create dataset"); 805 throw new OutOfMemoryError("Not enough memory available to create dataset"); 806 } 807 808 return c; 809 } 810 811 /** 812 * Copy and cast a dataset 813 * 814 * @param clazz dataset class 815 * @param d 816 * The dataset to be copied 817 * @return copied dataset of given type 818 */ 819 public static <T extends Dataset> T copy(Class<T> clazz, final IDataset d) { 820 return (T) copy(d, DTypeUtils.getDType(clazz)); 821 } 822 823 824 /** 825 * Cast a dataset 826 * 827 * @param d 828 * The dataset to be cast. 829 * @param dtype dataset type 830 * @return dataset of given type (or same dataset if already of the right type) 831 */ 832 public static Dataset cast(final IDataset d, final int dtype) { 833 Dataset a = convertToDataset(d); 834 835 if (a.getDType() == dtype) { 836 return a; 837 } 838 return copy(d, dtype); 839 } 840 841 /** 842 * Cast a dataset 843 * 844 * @param clazz dataset class 845 * @param d 846 * The dataset to be cast. 847 * @return dataset of given type (or same dataset if already of the right type) 848 */ 849 public static <T extends Dataset> T cast(Class<T> clazz, final IDataset d) { 850 return (T) cast(d, DTypeUtils.getDType(clazz)); 851 } 852 853 /** 854 * Cast a dataset 855 * 856 * @param d 857 * The dataset to be cast. 858 * @param repeat repeat elements over item 859 * @param dtype dataset type 860 * @param isize item size 861 */ 862 public static Dataset cast(final IDataset d, final boolean repeat, final int dtype, final int isize) { 863 Dataset a = convertToDataset(d); 864 865 if (a.getDType() == dtype && a.getElementsPerItem() == isize) { 866 return a; 867 } 868 if (isize <= 0) { 869 utilsLogger.error("Item size is invalid (>0)"); 870 throw new IllegalArgumentException("Item size is invalid (>0)"); 871 } 872 if (isize > 1 && dtype <= Dataset.FLOAT64) { 873 utilsLogger.error("Item size is inconsistent with dataset type"); 874 throw new IllegalArgumentException("Item size is inconsistent with dataset type"); 875 } 876 877 Dataset c = null; 878 879 try { 880 // copy across the data 881 switch (dtype) { 882 case Dataset.BOOL: 883 c = new BooleanDataset(a); 884 break; 885 case Dataset.INT8: 886 c = new ByteDataset(a); 887 break; 888 case Dataset.INT16: 889 c = new ShortDataset(a); 890 break; 891 case Dataset.INT32: 892 c = new IntegerDataset(a); 893 break; 894 case Dataset.INT64: 895 c = new LongDataset(a); 896 break; 897 case Dataset.ARRAYINT8: 898 c = new CompoundByteDataset(isize, repeat, a); 899 break; 900 case Dataset.ARRAYINT16: 901 c = new CompoundShortDataset(isize, repeat, a); 902 break; 903 case Dataset.ARRAYINT32: 904 c = new CompoundIntegerDataset(isize, repeat, a); 905 break; 906 case Dataset.ARRAYINT64: 907 c = new CompoundLongDataset(isize, repeat, a); 908 break; 909 case Dataset.FLOAT32: 910 c = new FloatDataset(a); 911 break; 912 case Dataset.FLOAT64: 913 c = new DoubleDataset(a); 914 break; 915 case Dataset.ARRAYFLOAT32: 916 c = new CompoundFloatDataset(isize, repeat, a); 917 break; 918 case Dataset.ARRAYFLOAT64: 919 c = new CompoundDoubleDataset(isize, repeat, a); 920 break; 921 case Dataset.COMPLEX64: 922 c = new ComplexFloatDataset(a); 923 break; 924 case Dataset.COMPLEX128: 925 c = new ComplexDoubleDataset(a); 926 break; 927 default: 928 utilsLogger.error("Dataset of unknown type!"); 929 break; 930 } 931 } catch (OutOfMemoryError e) { 932 utilsLogger.error("Not enough memory available to create dataset"); 933 throw new OutOfMemoryError("Not enough memory available to create dataset"); 934 } 935 936 return c; 937 } 938 939 /** 940 * Cast array of datasets to a compound dataset 941 * 942 * @param a 943 * The datasets to be cast. 944 */ 945 public static CompoundDataset cast(final Dataset[] a, final int dtype) { 946 CompoundDataset c = null; 947 948 switch (dtype) { 949 case Dataset.INT8: 950 case Dataset.ARRAYINT8: 951 c = new CompoundByteDataset(a); 952 break; 953 case Dataset.INT16: 954 case Dataset.ARRAYINT16: 955 c = new CompoundShortDataset(a); 956 break; 957 case Dataset.INT32: 958 case Dataset.ARRAYINT32: 959 c = new CompoundIntegerDataset(a); 960 break; 961 case Dataset.INT64: 962 case Dataset.ARRAYINT64: 963 c = new CompoundLongDataset(a); 964 break; 965 case Dataset.FLOAT32: 966 case Dataset.ARRAYFLOAT32: 967 c = new CompoundFloatDataset(a); 968 break; 969 case Dataset.FLOAT64: 970 case Dataset.ARRAYFLOAT64: 971 c = new CompoundDoubleDataset(a); 972 break; 973 case Dataset.COMPLEX64: 974 if (a.length != 2) { 975 throw new IllegalArgumentException("Need two datasets for complex dataset type"); 976 } 977 c = new ComplexFloatDataset(a[0], a[1]); 978 break; 979 case Dataset.COMPLEX128: 980 if (a.length != 2) { 981 throw new IllegalArgumentException("Need two datasets for complex dataset type"); 982 } 983 c = new ComplexDoubleDataset(a[0], a[1]); 984 break; 985 default: 986 utilsLogger.error("Dataset of unsupported type!"); 987 break; 988 } 989 990 return c; 991 } 992 993 /** 994 * Make a dataset unsigned by promoting it to a wider dataset type and unwrapping the signs 995 * of its contents 996 * @param a 997 * @return unsigned dataset or original if it is not an integer dataset 998 */ 999 public static Dataset makeUnsigned(IDataset a) { 1000 return makeUnsigned(a, false); 1001 } 1002 1003 /** 1004 * Make a dataset unsigned by promoting it to a wider dataset type and unwrapping the signs 1005 * of its contents 1006 * @param a 1007 * @param check if true, then check for negative values 1008 * @return unsigned dataset or original if it is not an integer dataset or it has been check for negative numbers 1009 * @since 2.1 1010 */ 1011 public static Dataset makeUnsigned(IDataset a, boolean check) { 1012 Dataset d = convertToDataset(a); 1013 1014 if (d.hasFloatingPointElements()) { 1015 return d; 1016 } 1017 if (check && d.min(true).longValue() >= 0) { 1018 return d; 1019 } 1020 1021 int dtype = d.getDType(); 1022 switch (dtype) { 1023 case Dataset.INT32: 1024 d = new LongDataset(d); 1025 unwrapUnsigned(d, 32); 1026 break; 1027 case Dataset.INT16: 1028 d = new IntegerDataset(d); 1029 unwrapUnsigned(d, 16); 1030 break; 1031 case Dataset.INT8: 1032 d = new ShortDataset(d); 1033 unwrapUnsigned(d, 8); 1034 break; 1035 case Dataset.ARRAYINT32: 1036 d = new CompoundLongDataset(d); 1037 unwrapUnsigned(d, 32); 1038 break; 1039 case Dataset.ARRAYINT16: 1040 d = new CompoundIntegerDataset(d); 1041 unwrapUnsigned(d, 16); 1042 break; 1043 case Dataset.ARRAYINT8: 1044 d = new CompoundShortDataset(d); 1045 unwrapUnsigned(d, 8); 1046 break; 1047 } 1048 return d; 1049 } 1050 1051 /** 1052 * Unwrap dataset elements so that all elements are unsigned 1053 * @param a dataset 1054 * @param bitWidth width of original primitive in bits 1055 */ 1056 public static void unwrapUnsigned(Dataset a, final int bitWidth) { 1057 final int dtype = a.getDType(); 1058 final double dv = 1L << bitWidth; 1059 final int isize = a.getElementsPerItem(); 1060 IndexIterator it = a.getIterator(); 1061 1062 switch (dtype) { 1063 case Dataset.BOOL: 1064 break; 1065 case Dataset.INT8: 1066 break; 1067 case Dataset.INT16: 1068 ShortDataset sds = (ShortDataset) a; 1069 final short soffset = (short) dv; 1070 while (it.hasNext()) { 1071 final short x = sds.getAbs(it.index); 1072 if (x < 0) 1073 sds.setAbs(it.index, (short) (x + soffset)); 1074 } 1075 break; 1076 case Dataset.INT32: 1077 IntegerDataset ids = (IntegerDataset) a; 1078 final int ioffset = (int) dv; 1079 while (it.hasNext()) { 1080 final int x = ids.getAbs(it.index); 1081 if (x < 0) 1082 ids.setAbs(it.index, x + ioffset); 1083 } 1084 break; 1085 case Dataset.INT64: 1086 LongDataset lds = (LongDataset) a; 1087 final long loffset = (long) dv; 1088 while (it.hasNext()) { 1089 final long x = lds.getAbs(it.index); 1090 if (x < 0) 1091 lds.setAbs(it.index, x + loffset); 1092 } 1093 break; 1094 case Dataset.FLOAT32: 1095 FloatDataset fds = (FloatDataset) a; 1096 final float foffset = (float) dv; 1097 while (it.hasNext()) { 1098 final float x = fds.getAbs(it.index); 1099 if (x < 0) 1100 fds.setAbs(it.index, x + foffset); 1101 } 1102 break; 1103 case Dataset.FLOAT64: 1104 DoubleDataset dds = (DoubleDataset) a; 1105 final double doffset = dv; 1106 while (it.hasNext()) { 1107 final double x = dds.getAbs(it.index); 1108 if (x < 0) 1109 dds.setAbs(it.index, x + doffset); 1110 } 1111 break; 1112 case Dataset.ARRAYINT8: 1113 break; 1114 case Dataset.ARRAYINT16: 1115 CompoundShortDataset csds = (CompoundShortDataset) a; 1116 final short csoffset = (short) dv; 1117 final short[] csa = new short[isize]; 1118 while (it.hasNext()) { 1119 csds.getAbs(it.index, csa); 1120 boolean dirty = false; 1121 for (int i = 0; i < isize; i++) { 1122 short x = csa[i]; 1123 if (x < 0) { 1124 csa[i] = (short) (x + csoffset); 1125 dirty = true; 1126 } 1127 } 1128 if (dirty) 1129 csds.setAbs(it.index, csa); 1130 } 1131 break; 1132 case Dataset.ARRAYINT32: 1133 CompoundIntegerDataset cids = (CompoundIntegerDataset) a; 1134 final int cioffset = (int) dv; 1135 final int[] cia = new int[isize]; 1136 while (it.hasNext()) { 1137 cids.getAbs(it.index, cia); 1138 boolean dirty = false; 1139 for (int i = 0; i < isize; i++) { 1140 int x = cia[i]; 1141 if (x < 0) { 1142 cia[i] = x + cioffset; 1143 dirty = true; 1144 } 1145 } 1146 if (dirty) 1147 cids.setAbs(it.index, cia); 1148 } 1149 break; 1150 case Dataset.ARRAYINT64: 1151 CompoundLongDataset clds = (CompoundLongDataset) a; 1152 final long cloffset = (long) dv; 1153 final long[] cla = new long[isize]; 1154 while (it.hasNext()) { 1155 clds.getAbs(it.index, cla); 1156 boolean dirty = false; 1157 for (int i = 0; i < isize; i++) { 1158 long x = cla[i]; 1159 if (x < 0) { 1160 cla[i] = x + cloffset; 1161 dirty = true; 1162 } 1163 } 1164 if (dirty) 1165 clds.setAbs(it.index, cla); 1166 } 1167 break; 1168 default: 1169 utilsLogger.error("Dataset of unsupported type for this method"); 1170 break; 1171 } 1172 } 1173 1174 /** 1175 * @param rows 1176 * @param cols 1177 * @param offset 1178 * @param dtype 1179 * @return a new 2d dataset of given shape and type, filled with ones on the (offset) diagonal 1180 */ 1181 public static Dataset eye(final int rows, final int cols, final int offset, final int dtype) { 1182 int[] shape = new int[] {rows, cols}; 1183 @SuppressWarnings("deprecation") 1184 Dataset a = DatasetFactory.zeros(shape, dtype); 1185 1186 int[] pos = new int[] {0, offset}; 1187 while (pos[1] < 0) { 1188 pos[0]++; 1189 pos[1]++; 1190 } 1191 while (pos[0] < rows && pos[1] < cols) { 1192 a.set(1, pos); 1193 pos[0]++; 1194 pos[1]++; 1195 } 1196 1197 return a; 1198 } 1199 1200 /** 1201 * Create a (off-)diagonal matrix from items in dataset 1202 * @param a 1203 * @param offset 1204 * @return diagonal matrix 1205 */ 1206 @SuppressWarnings("deprecation") 1207 public static <T extends Dataset> T diag(final T a, final int offset) { 1208 final int dtype = a.getDType(); 1209 final int rank = a.getRank(); 1210 final int is = a.getElementsPerItem(); 1211 1212 if (rank == 0 || rank > 2) { 1213 utilsLogger.error("Rank of dataset should be one or two"); 1214 throw new IllegalArgumentException("Rank of dataset should be one or two"); 1215 } 1216 1217 Dataset result; 1218 final int[] shape = a.getShapeRef(); 1219 if (rank == 1) { 1220 int side = shape[0] + Math.abs(offset); 1221 int[] pos = new int[] {side, side}; 1222 result = DatasetFactory.zeros(is, pos, dtype); 1223 if (offset >= 0) { 1224 pos[0] = 0; 1225 pos[1] = offset; 1226 } else { 1227 pos[0] = -offset; 1228 pos[1] = 0; 1229 } 1230 int i = 0; 1231 while (pos[0] < side && pos[1] < side) { 1232 result.set(a.getObject(i++), pos); 1233 pos[0]++; 1234 pos[1]++; 1235 } 1236 } else { 1237 int side = offset >= 0 ? Math.min(shape[0], shape[1]-offset) : Math.min(shape[0]+offset, shape[1]); 1238 if (side < 0) 1239 side = 0; 1240 result = DatasetFactory.zeros(is, new int[] {side}, dtype); 1241 1242 if (side > 0) { 1243 int[] pos = offset >= 0 ? new int[] { 0, offset } : new int[] { -offset, 0 }; 1244 int i = 0; 1245 while (pos[0] < shape[0] && pos[1] < shape[1]) { 1246 result.set(a.getObject(pos), i++); 1247 pos[0]++; 1248 pos[1]++; 1249 } 1250 } 1251 } 1252 1253 return (T) result; 1254 } 1255 1256 /** 1257 * Slice (or fully load), if necessary, a lazy dataset, otherwise take a slice view and 1258 * convert to our dataset implementation. If a slice is necessary, this may cause resource 1259 * problems when used on large datasets and throw runtime exceptions 1260 * @param lazy can be null 1261 * @return Converted dataset or null 1262 * @throws DatasetException 1263 */ 1264 public static Dataset sliceAndConvertLazyDataset(ILazyDataset lazy) throws DatasetException { 1265 if (lazy == null) 1266 return null; 1267 1268 IDataset data = lazy instanceof IDataset ? (IDataset) lazy.getSliceView() : lazy.getSlice(); 1269 1270 return convertToDataset(data); 1271 } 1272 1273 /** 1274 * Convert (if necessary) a dataset obeying the interface to our implementation 1275 * @param data can be null 1276 * @return Converted dataset or null 1277 */ 1278 public static Dataset convertToDataset(IDataset data) { 1279 if (data == null) 1280 return null; 1281 1282 if (data instanceof Dataset) { 1283 return (Dataset) data; 1284 } 1285 1286 int dtype = DTypeUtils.getDType(data); 1287 1288 final int isize = data.getElementsPerItem(); 1289 if (isize <= 0) { 1290 throw new IllegalArgumentException("Datasets with " + isize + " elements per item not supported"); 1291 } 1292 1293 @SuppressWarnings("deprecation") 1294 final Dataset result = DatasetFactory.zeros(isize, data.getShape(), dtype); 1295 result.setName(data.getName()); 1296 1297 final IndexIterator it = result.getIterator(true); 1298 final int[] pos = it.getPos(); 1299 switch (dtype) { 1300 case Dataset.BOOL: 1301 while (it.hasNext()) { 1302 result.setObjectAbs(it.index, data.getBoolean(pos)); 1303 } 1304 break; 1305 case Dataset.INT8: 1306 while (it.hasNext()) { 1307 result.setObjectAbs(it.index, data.getByte(pos)); 1308 } 1309 break; 1310 case Dataset.INT16: 1311 while (it.hasNext()) { 1312 result.setObjectAbs(it.index, data.getShort(pos)); 1313 } 1314 break; 1315 case Dataset.INT32: 1316 while (it.hasNext()) { 1317 result.setObjectAbs(it.index, data.getInt(pos)); 1318 } 1319 break; 1320 case Dataset.INT64: 1321 while (it.hasNext()) { 1322 result.setObjectAbs(it.index, data.getLong(pos)); 1323 } 1324 break; 1325 case Dataset.FLOAT32: 1326 while (it.hasNext()) { 1327 result.setObjectAbs(it.index, data.getFloat(pos)); 1328 } 1329 break; 1330 case Dataset.FLOAT64: 1331 while (it.hasNext()) { 1332 result.setObjectAbs(it.index, data.getDouble(pos)); 1333 } 1334 break; 1335 default: 1336 while (it.hasNext()) { 1337 result.setObjectAbs(it.index, data.getObject(pos)); 1338 } 1339 break; 1340 } 1341 1342 result.setErrors(data.getErrors()); 1343 return result; 1344 } 1345 1346 /** 1347 * Create a compound dataset from given datasets 1348 * @param datasets 1349 * @return compound dataset or null if none given 1350 */ 1351 public static CompoundDataset createCompoundDataset(final Dataset... datasets) { 1352 if (datasets == null || datasets.length == 0) 1353 return null; 1354 1355 return createCompoundDataset(datasets[0].getDType(), datasets); 1356 } 1357 1358 /** 1359 * Create a compound dataset from given datasets 1360 * @param dtype 1361 * @param datasets 1362 * @return compound dataset or null if none given 1363 */ 1364 public static CompoundDataset createCompoundDataset(final int dtype, final Dataset... datasets) { 1365 if (datasets == null || datasets.length == 0) 1366 return null; 1367 1368 switch (dtype) { 1369 case Dataset.INT8: 1370 case Dataset.ARRAYINT8: 1371 return new CompoundByteDataset(datasets); 1372 case Dataset.INT16: 1373 case Dataset.ARRAYINT16: 1374 return new CompoundShortDataset(datasets); 1375 case Dataset.INT32: 1376 case Dataset.ARRAYINT32: 1377 return new CompoundIntegerDataset(datasets); 1378 case Dataset.INT64: 1379 case Dataset.ARRAYINT64: 1380 return new CompoundLongDataset(datasets); 1381 case Dataset.FLOAT32: 1382 case Dataset.ARRAYFLOAT32: 1383 return new CompoundFloatDataset(datasets); 1384 case Dataset.FLOAT64: 1385 case Dataset.ARRAYFLOAT64: 1386 return new CompoundDoubleDataset(datasets); 1387 case Dataset.COMPLEX64: 1388 case Dataset.COMPLEX128: 1389 if (datasets.length > 2) { 1390 utilsLogger.error("At most two datasets are allowed"); 1391 throw new IllegalArgumentException("At most two datasets are allowed"); 1392 } else if (datasets.length == 2) { 1393 return dtype == Dataset.COMPLEX64 ? new ComplexFloatDataset(datasets[0], datasets[1]) : new ComplexDoubleDataset(datasets[0], datasets[1]); 1394 } 1395 return dtype == Dataset.COMPLEX64 ? new ComplexFloatDataset(datasets[0]) : new ComplexDoubleDataset(datasets[0]); 1396 case Dataset.RGB: 1397 if (datasets.length == 1) { 1398 return new RGBDataset(datasets[0]); 1399 } else if (datasets.length == 3) { 1400 return new RGBDataset(datasets[0], datasets[1], datasets[2]); 1401 } else { 1402 utilsLogger.error("Only one or three datasets are allowed to create a RGB dataset"); 1403 throw new IllegalArgumentException("Only one or three datasets are allowed to create a RGB dataset"); 1404 } 1405 default: 1406 utilsLogger.error("Dataset type not supported for this operation"); 1407 throw new UnsupportedOperationException("Dataset type not supported"); 1408 } 1409 } 1410 1411 /** 1412 * Create a compound dataset from given datasets 1413 * @param clazz dataset class 1414 * @param datasets 1415 * @return compound dataset or null if none given 1416 */ 1417 public static <T extends CompoundDataset> T createCompoundDataset(Class<T> clazz, final Dataset... datasets) { 1418 return (T) createCompoundDataset(DTypeUtils.getDType(clazz), datasets); 1419 } 1420 1421 /** 1422 * Create a compound dataset from given dataset 1423 * @param dataset 1424 * @param itemSize 1425 * @return compound dataset 1426 */ 1427 public static CompoundDataset createCompoundDataset(final Dataset dataset, final int itemSize) { 1428 int[] shape = dataset.getShapeRef(); 1429 int[] nshape = shape; 1430 if (shape != null && itemSize > 1) { 1431 int size = ShapeUtils.calcSize(shape); 1432 if (size % itemSize != 0) { 1433 throw new IllegalArgumentException("Input dataset has number of items that is not a multiple of itemSize"); 1434 } 1435 int d = shape.length; 1436 int l = 1; 1437 while (--d >= 0) { 1438 l *= shape[d]; 1439 if (l % itemSize == 0) { 1440 break; 1441 } 1442 } 1443 assert d >= 0; 1444 nshape = new int[d + 1]; 1445 for (int i = 0; i < d; i++) { 1446 nshape[i] = shape[i]; 1447 } 1448 nshape[d] = l / itemSize; 1449 } 1450 switch (dataset.getDType()) { 1451 case Dataset.INT8: 1452 return new CompoundByteDataset(itemSize, (byte[]) dataset.getBuffer(), nshape); 1453 case Dataset.INT16: 1454 return new CompoundShortDataset(itemSize, (short[]) dataset.getBuffer(), nshape); 1455 case Dataset.INT32: 1456 return new CompoundIntegerDataset(itemSize, (int[]) dataset.getBuffer(), nshape); 1457 case Dataset.INT64: 1458 return new CompoundLongDataset(itemSize, (long[]) dataset.getBuffer(), nshape); 1459 case Dataset.FLOAT32: 1460 return new CompoundFloatDataset(itemSize, (float[]) dataset.getBuffer(), nshape); 1461 case Dataset.FLOAT64: 1462 return new CompoundDoubleDataset(itemSize, (double[]) dataset.getBuffer(), nshape); 1463 default: 1464 utilsLogger.error("Dataset type not supported for this operation"); 1465 throw new UnsupportedOperationException("Dataset type not supported"); 1466 } 1467 } 1468 1469 1470 /** 1471 * Create a compound dataset by using last axis as elements of an item 1472 * @param a 1473 * @param shareData if true, then share data 1474 * @return compound dataset 1475 */ 1476 public static CompoundDataset createCompoundDatasetFromLastAxis(final Dataset a, final boolean shareData) { 1477 switch (a.getDType()) { 1478 case Dataset.INT8: 1479 return CompoundByteDataset.createCompoundDatasetWithLastDimension(a, shareData); 1480 case Dataset.INT16: 1481 return CompoundShortDataset.createCompoundDatasetWithLastDimension(a, shareData); 1482 case Dataset.INT32: 1483 return CompoundIntegerDataset.createCompoundDatasetWithLastDimension(a, shareData); 1484 case Dataset.INT64: 1485 return CompoundLongDataset.createCompoundDatasetWithLastDimension(a, shareData); 1486 case Dataset.FLOAT32: 1487 return CompoundFloatDataset.createCompoundDatasetWithLastDimension(a, shareData); 1488 case Dataset.FLOAT64: 1489 return CompoundDoubleDataset.createCompoundDatasetWithLastDimension(a, shareData); 1490 default: 1491 utilsLogger.error("Dataset type not supported for this operation"); 1492 throw new UnsupportedOperationException("Dataset type not supported"); 1493 } 1494 } 1495 1496 /** 1497 * Create a dataset from a compound dataset by using elements of an item as last axis 1498 * <p> 1499 * In the case where the number of elements is one, the last axis is squeezed out. 1500 * @param a 1501 * @param shareData if true, then share data 1502 * @return non-compound dataset 1503 */ 1504 public static Dataset createDatasetFromCompoundDataset(final CompoundDataset a, final boolean shareData) { 1505 return a.asNonCompoundDataset(shareData); 1506 } 1507 1508 /** 1509 * Create a copy that has been coerced to an appropriate dataset type 1510 * depending on the input object's class 1511 * 1512 * @param a 1513 * @param obj 1514 * @return coerced copy of dataset 1515 */ 1516 public static Dataset coerce(Dataset a, Object obj) { 1517 final int dt = a.getDType(); 1518 final int ot = DTypeUtils.getDTypeFromClass(obj.getClass()); 1519 1520 return cast(a.clone(), DTypeUtils.getBestDType(dt, ot)); 1521 } 1522 1523 /** 1524 * Function that returns a normalised dataset which is bounded between 0 and 1 1525 * @param a dataset 1526 * @return normalised dataset 1527 */ 1528 public static Dataset norm(Dataset a) { 1529 double amin = a.min().doubleValue(); 1530 double aptp = a.max().doubleValue() - amin; 1531 Dataset temp = Maths.subtract(a, amin); 1532 temp.idivide(aptp); 1533 return temp; 1534 } 1535 1536 /** 1537 * Function that returns a normalised compound dataset which is bounded between 0 and 1. There 1538 * are (at least) two ways to normalise a compound dataset: per element - extrema for each element 1539 * in a compound item is used, i.e. many min/max pairs; over all elements - extrema for all elements 1540 * is used, i.e. one min/max pair. 1541 * @param a dataset 1542 * @param overAllElements if true, then normalise over all elements in each item 1543 * @return normalised dataset 1544 */ 1545 public static CompoundDataset norm(CompoundDataset a, boolean overAllElements) { 1546 double[] amin = a.minItem(); 1547 double[] amax = a.maxItem(); 1548 final int is = a.getElementsPerItem(); 1549 Dataset result; 1550 1551 if (overAllElements) { 1552 Arrays.sort(amin); 1553 Arrays.sort(amax); 1554 double aptp = amax[0] - amin[0]; 1555 1556 result = Maths.subtract(a, amin[0]); 1557 result.idivide(aptp); 1558 } else { 1559 double[] aptp = new double[is]; 1560 for (int j = 0; j < is; j++) { 1561 aptp[j] = amax[j] - amin[j]; 1562 } 1563 1564 result = Maths.subtract(a, amin); 1565 result.idivide(aptp); 1566 } 1567 return (CompoundDataset) result; 1568 } 1569 1570 /** 1571 * Function that returns a normalised dataset which is bounded between 0 and 1 1572 * and has been distributed on a log10 scale 1573 * @param a dataset 1574 * @return normalised dataset 1575 */ 1576 public static Dataset lognorm(Dataset a) { 1577 double amin = a.min().doubleValue(); 1578 double aptp = Math.log10(a.max().doubleValue() - amin + 1.); 1579 Dataset temp = Maths.subtract(a, amin - 1.); 1580 temp = Maths.log10(temp); 1581 temp = Maths.divide(temp, aptp); 1582 return temp; 1583 } 1584 1585 /** 1586 * Function that returns a normalised dataset which is bounded between 0 and 1 1587 * and has been distributed on a natural log scale 1588 * @param a dataset 1589 * @return normalised dataset 1590 */ 1591 public static Dataset lnnorm(Dataset a) { 1592 double amin = a.min().doubleValue(); 1593 double aptp = Math.log(a.max().doubleValue() - amin + 1.); 1594 Dataset temp = Maths.subtract(a, amin - 1.); 1595 temp = Maths.log(temp); 1596 temp = Maths.divide(temp, aptp); 1597 return temp; 1598 } 1599 1600 /** 1601 * Construct a list of datasets where each represents a coordinate varying over the hypergrid 1602 * formed by the input list of axes 1603 * 1604 * @param axes an array of 1D datasets representing axes 1605 * @return a list of coordinate datasets 1606 */ 1607 public static List<Dataset> meshGrid(final Dataset... axes) { 1608 List<Dataset> result = new ArrayList<Dataset>(); 1609 int rank = axes.length; 1610 1611 if (rank < 2) { 1612 utilsLogger.error("Two or more axes datasets are required"); 1613 throw new IllegalArgumentException("Two or more axes datasets are required"); 1614 } 1615 1616 int[] nshape = new int[rank]; 1617 1618 for (int i = 0; i < rank; i++) { 1619 Dataset axis = axes[i]; 1620 if (axis.getRank() != 1) { 1621 utilsLogger.error("Given axis is not 1D"); 1622 throw new IllegalArgumentException("Given axis is not 1D"); 1623 } 1624 nshape[i] = axis.getSize(); 1625 } 1626 1627 for (int i = 0; i < rank; i++) { 1628 Dataset axis = axes[i]; 1629 @SuppressWarnings("deprecation") 1630 Dataset coord = DatasetFactory.zeros(nshape, axis.getDType()); 1631 result.add(coord); 1632 1633 final int alen = axis.getSize(); 1634 for (int j = 0; j < alen; j++) { 1635 final Object obj = axis.getObjectAbs(j); 1636 PositionIterator pi = coord.getPositionIterator(i); 1637 final int[] pos = pi.getPos(); 1638 1639 pos[i] = j; 1640 while (pi.hasNext()) { 1641 coord.set(obj, pos); 1642 } 1643 } 1644 } 1645 1646 return result; 1647 } 1648 1649 /** 1650 * Generate an index dataset for given dataset where sub-datasets contain index values 1651 * 1652 * @return an index dataset 1653 */ 1654 public static IntegerDataset indices(int... shape) { 1655 // now create another dataset to plot against 1656 final int rank = shape.length; 1657 int[] nshape = new int[rank+1]; 1658 nshape[0] = rank; 1659 for (int i = 0; i < rank; i++) { 1660 nshape[i+1] = shape[i]; 1661 } 1662 1663 IntegerDataset index = new IntegerDataset(nshape); 1664 1665 if (rank == 1) { 1666 final int alen = shape[0]; 1667 int[] pos = new int[2]; 1668 for (int j = 0; j < alen; j++) { 1669 pos[1] = j; 1670 index.set(j, pos); 1671 } 1672 } else { 1673 for (int i = 1; i <= rank; i++) { 1674 final int alen = nshape[i]; 1675 for (int j = 0; j < alen; j++) { 1676 PositionIterator pi = index.getPositionIterator(0, i); 1677 final int[] pos = pi.getPos(); 1678 1679 pos[0] = i-1; 1680 pos[i] = j; 1681 while (pi.hasNext()) { 1682 index.set(j, pos); 1683 } 1684 } 1685 } 1686 } 1687 return index; 1688 } 1689 1690 /** 1691 * Get the centroid value of a dataset, this function works out the centroid in every direction 1692 * 1693 * @param a 1694 * the dataset to be analysed 1695 * @param bases the optional array of base coordinates to use as weights. 1696 * This defaults to the mid-point of indices 1697 * @return a double array containing the centroid for each dimension 1698 */ 1699 public static double[] centroid(Dataset a, Dataset... bases) { 1700 int rank = a.getRank(); 1701 if (bases.length > 0 && bases.length != rank) { 1702 throw new IllegalArgumentException("Number of bases must be zero or match rank of dataset"); 1703 } 1704 1705 int[] shape = a.getShapeRef(); 1706 if (bases.length == rank) { 1707 for (int i = 0; i < rank; i++) { 1708 Dataset b = bases[i]; 1709 if (b.getRank() != 1 && b.getSize() != shape[i]) { 1710 throw new IllegalArgumentException("A base does not have shape to match given dataset"); 1711 } 1712 } 1713 } 1714 1715 double[] dc = new double[rank]; 1716 if (rank == 0) 1717 return dc; 1718 1719 final PositionIterator iter = new PositionIterator(shape); 1720 final int[] pos = iter.getPos(); 1721 1722 double tsum = 0.0; 1723 while (iter.hasNext()) { 1724 double val = a.getDouble(pos); 1725 tsum += val; 1726 for (int d = 0; d < rank; d++) { 1727 Dataset b = bases.length == 0 ? null : bases[d]; 1728 if (b == null) { 1729 dc[d] += (pos[d] + 0.5) * val; 1730 } else { 1731 dc[d] += b.getElementDoubleAbs(pos[d]) * val; 1732 } 1733 } 1734 } 1735 1736 for (int d = 0; d < rank; d++) { 1737 dc[d] /= tsum; 1738 } 1739 return dc; 1740 } 1741 1742 /** 1743 * Find linearly-interpolated crossing points where the given dataset crosses the given value 1744 * 1745 * @param d 1746 * @param value 1747 * @return list of interpolated indices 1748 */ 1749 public static List<Double> crossings(Dataset d, double value) { 1750 if (d.getRank() != 1) { 1751 utilsLogger.error("Only 1d datasets supported"); 1752 throw new UnsupportedOperationException("Only 1d datasets supported"); 1753 } 1754 List<Double> results = new ArrayList<Double>(); 1755 1756 // run through all pairs of points on the line and see if value lies within 1757 IndexIterator it = d.getIterator(); 1758 double y1, y2; 1759 1760 y2 = it.hasNext() ? d.getElementDoubleAbs(it.index) : 0; 1761 double x = 1; 1762 while (it.hasNext()) { 1763 y1 = y2; 1764 y2 = d.getElementDoubleAbs(it.index); 1765 // check if value lies within pair [y1, y2] 1766 if ((y1 <= value && value < y2) || (y1 > value && y2 <= value)) { 1767 final double f = (value - y2)/(y2 - y1); // negative distance from right to left 1768 results.add(x + f); 1769 } 1770 x++; 1771 } 1772 if (y2 == value) { // add end point of it intersects 1773 results.add(x); 1774 } 1775 1776 return results; 1777 } 1778 1779 /** 1780 * Find x values of all the crossing points of the dataset with the given y value 1781 * 1782 * @param xAxis 1783 * Dataset of the X axis that needs to be looked at 1784 * @param yAxis 1785 * Dataset of the Y axis that needs to be looked at 1786 * @param yValue 1787 * The y value the X values are required for 1788 * @return An list of doubles containing all the X coordinates of where the line crosses 1789 */ 1790 public static List<Double> crossings(Dataset xAxis, Dataset yAxis, double yValue) { 1791 if (xAxis.getSize() > yAxis.getSize()) { 1792 throw new IllegalArgumentException( 1793 "Number of values of yAxis must as least be equal to the number of values of xAxis"); 1794 } 1795 1796 List<Double> results = new ArrayList<Double>(); 1797 1798 List<Double> indices = crossings(yAxis, yValue); 1799 1800 for (double xi : indices) { 1801 results.add(Maths.interpolate(xAxis, xi)); 1802 } 1803 return results; 1804 } 1805 1806 /** 1807 * Function that uses the crossings function but prunes the result, so that multiple crossings within a 1808 * certain proportion of the overall range of the x values 1809 * 1810 * @param xAxis 1811 * Dataset of the X axis 1812 * @param yAxis 1813 * Dataset of the Y axis 1814 * @param yValue 1815 * The y value the x values are required for 1816 * @param xRangeProportion 1817 * The proportion of the overall x spread used to prune result 1818 * @return A list containing all the unique crossing points 1819 */ 1820 public static List<Double> crossings(Dataset xAxis, Dataset yAxis, double yValue, double xRangeProportion) { 1821 // get the values found 1822 List<Double> vals = crossings(xAxis, yAxis, yValue); 1823 1824 // use the proportion to calculate the error spacing 1825 double error = xRangeProportion * xAxis.peakToPeak().doubleValue(); 1826 1827 int i = 0; 1828 // now go through and check for groups of three crossings which are all 1829 // within the boundaries 1830 while (i <= vals.size() - 3) { 1831 double v1 = Math.abs(vals.get(i) - vals.get(i + 2)); 1832 if (v1 < error) { 1833 // these 3 points should be treated as one 1834 // make the first point equal to the average of them all 1835 vals.set(i + 2, ((vals.get(i) + vals.get(i + 1) + vals.get(i + 2)) / 3.0)); 1836 // remove the other offending points 1837 vals.remove(i); 1838 vals.remove(i); 1839 } else { 1840 i++; 1841 } 1842 } 1843 1844 // once the thinning process has been completed, return the pruned list 1845 return vals; 1846 } 1847 1848 // recursive function 1849 private static void setRow(Object row, Dataset a, int... pos) { 1850 final int l = Array.getLength(row); 1851 final int rank = pos.length; 1852 final int[] npos = Arrays.copyOf(pos, rank+1); 1853 Object r; 1854 if (rank+1 < a.getRank()) { 1855 for (int i = 0; i < l; i++) { 1856 npos[rank] = i; 1857 r = Array.get(row, i); 1858 setRow(r, a, npos); 1859 } 1860 } else { 1861 for (int i = 0; i < l; i++) { 1862 npos[rank] = i; 1863 r = a.getObject(npos); 1864 Array.set(row, i, r); 1865 } 1866 } 1867 } 1868 1869 /** 1870 * Create Java array (of arrays) from dataset 1871 * @param a dataset 1872 * @return Java array (of arrays...) 1873 */ 1874 public static Object createJavaArray(Dataset a) { 1875 if (a.getElementsPerItem() > 1) { 1876 a = createDatasetFromCompoundDataset((CompoundDataset) a, true); 1877 } 1878 Object matrix; 1879 1880 switch (a.getDType()) { 1881 case Dataset.BOOL: 1882 matrix = Array.newInstance(boolean.class, a.getShape()); 1883 break; 1884 case Dataset.INT8: 1885 matrix = Array.newInstance(byte.class, a.getShape()); 1886 break; 1887 case Dataset.INT16: 1888 matrix = Array.newInstance(short.class, a.getShape()); 1889 break; 1890 case Dataset.INT32: 1891 matrix = Array.newInstance(int.class, a.getShape()); 1892 break; 1893 case Dataset.INT64: 1894 matrix = Array.newInstance(long.class, a.getShape()); 1895 break; 1896 case Dataset.FLOAT32: 1897 matrix = Array.newInstance(float.class, a.getShape()); 1898 break; 1899 case Dataset.FLOAT64: 1900 matrix = Array.newInstance(double.class, a.getShape()); 1901 break; 1902 default: 1903 utilsLogger.error("Dataset type not supported"); 1904 throw new IllegalArgumentException("Dataset type not supported"); 1905 } 1906 1907 // populate matrix 1908 setRow(matrix, a); 1909 return matrix; 1910 } 1911 1912 /** 1913 * Removes NaNs and infinities from floating point datasets. 1914 * All other dataset types are ignored. 1915 * 1916 * @param a dataset 1917 * @param value replacement value 1918 */ 1919 public static void removeNansAndInfinities(Dataset a, final Number value) { 1920 if (a instanceof DoubleDataset) { 1921 final double dvalue = DTypeUtils.toReal(value); 1922 final DoubleDataset set = (DoubleDataset) a; 1923 final IndexIterator it = set.getIterator(); 1924 final double[] data = set.getData(); 1925 while (it.hasNext()) { 1926 double x = data[it.index]; 1927 if (Double.isNaN(x) || Double.isInfinite(x)) 1928 data[it.index] = dvalue; 1929 } 1930 } else if (a instanceof FloatDataset) { 1931 final float fvalue = (float) DTypeUtils.toReal(value); 1932 final FloatDataset set = (FloatDataset) a; 1933 final IndexIterator it = set.getIterator(); 1934 final float[] data = set.getData(); 1935 while (it.hasNext()) { 1936 float x = data[it.index]; 1937 if (Float.isNaN(x) || Float.isInfinite(x)) 1938 data[it.index] = fvalue; 1939 } 1940 } else if (a instanceof CompoundDoubleDataset) { 1941 final double dvalue = DTypeUtils.toReal(value); 1942 final CompoundDoubleDataset set = (CompoundDoubleDataset) a; 1943 final int is = set.getElementsPerItem(); 1944 final IndexIterator it = set.getIterator(); 1945 final double[] data = set.getData(); 1946 while (it.hasNext()) { 1947 for (int j = 0; j < is; j++) { 1948 double x = data[it.index + j]; 1949 if (Double.isNaN(x) || Double.isInfinite(x)) 1950 data[it.index + j] = dvalue; 1951 } 1952 } 1953 } else if (a instanceof CompoundFloatDataset) { 1954 final float fvalue = (float) DTypeUtils.toReal(value); 1955 final CompoundFloatDataset set = (CompoundFloatDataset) a; 1956 final int is = set.getElementsPerItem(); 1957 final IndexIterator it = set.getIterator(); 1958 final float[] data = set.getData(); 1959 while (it.hasNext()) { 1960 for (int j = 0; j < is; j++) { 1961 float x = data[it.index + j]; 1962 if (Float.isNaN(x) || Float.isInfinite(x)) 1963 data[it.index + j] = fvalue; 1964 } 1965 } 1966 } 1967 } 1968 1969 /** 1970 * Make floating point datasets contain only finite values. Infinities and NaNs are replaced 1971 * by +/- MAX_VALUE and 0, respectively. 1972 * All other dataset types are ignored. 1973 * 1974 * @param a dataset 1975 */ 1976 public static void makeFinite(Dataset a) { 1977 if (a instanceof DoubleDataset) { 1978 final DoubleDataset set = (DoubleDataset) a; 1979 final IndexIterator it = set.getIterator(); 1980 final double[] data = set.getData(); 1981 while (it.hasNext()) { 1982 final double x = data[it.index]; 1983 if (Double.isNaN(x)) 1984 data[it.index] = 0; 1985 else if (Double.isInfinite(x)) 1986 data[it.index] = x > 0 ? Double.MAX_VALUE : -Double.MAX_VALUE; 1987 } 1988 } else if (a instanceof FloatDataset) { 1989 final FloatDataset set = (FloatDataset) a; 1990 final IndexIterator it = set.getIterator(); 1991 final float[] data = set.getData(); 1992 while (it.hasNext()) { 1993 final float x = data[it.index]; 1994 if (Float.isNaN(x)) 1995 data[it.index] = 0; 1996 else if (Float.isInfinite(x)) 1997 data[it.index] = x > 0 ? Float.MAX_VALUE : -Float.MAX_VALUE; 1998 } 1999 } else if (a instanceof CompoundDoubleDataset) { 2000 final CompoundDoubleDataset set = (CompoundDoubleDataset) a; 2001 final int is = set.getElementsPerItem(); 2002 final IndexIterator it = set.getIterator(); 2003 final double[] data = set.getData(); 2004 while (it.hasNext()) { 2005 for (int j = 0; j < is; j++) { 2006 final double x = data[it.index + j]; 2007 if (Double.isNaN(x)) 2008 data[it.index + j] = 0; 2009 else if (Double.isInfinite(x)) 2010 data[it.index + j] = x > 0 ? Double.MAX_VALUE : -Double.MAX_VALUE; 2011 } 2012 } 2013 } else if (a instanceof CompoundFloatDataset) { 2014 final CompoundFloatDataset set = (CompoundFloatDataset) a; 2015 final int is = set.getElementsPerItem(); 2016 final IndexIterator it = set.getIterator(); 2017 final float[] data = set.getData(); 2018 while (it.hasNext()) { 2019 for (int j = 0; j < is; j++) { 2020 final float x = data[it.index + j]; 2021 if (Float.isNaN(x)) 2022 data[it.index + j] = 0; 2023 else if (Float.isInfinite(x)) 2024 data[it.index + j] = x > 0 ? Float.MAX_VALUE : -Float.MAX_VALUE; 2025 } 2026 } 2027 } 2028 } 2029 2030 /** 2031 * Find absolute index of first value in dataset that is equal to given number 2032 * @param a 2033 * @param n 2034 * @return absolute index (if greater than a.getSize() then no value found) 2035 */ 2036 public static int findIndexEqualTo(final Dataset a, final double n) { 2037 IndexIterator iter = a.getIterator(); 2038 while (iter.hasNext()) { 2039 if (a.getElementDoubleAbs(iter.index) == n) 2040 break; 2041 } 2042 2043 return iter.index; 2044 } 2045 2046 /** 2047 * Find absolute index of first value in dataset that is greater than given number 2048 * @param a 2049 * @param n 2050 * @return absolute index (if greater than a.getSize() then no value found) 2051 */ 2052 public static int findIndexGreaterThan(final Dataset a, final double n) { 2053 IndexIterator iter = a.getIterator(); 2054 while (iter.hasNext()) { 2055 if (a.getElementDoubleAbs(iter.index) > n) 2056 break; 2057 } 2058 2059 return iter.index; 2060 } 2061 2062 /** 2063 * Find absolute index of first value in dataset that is greater than or equal to given number 2064 * @param a 2065 * @param n 2066 * @return absolute index (if greater than a.getSize() then no value found) 2067 */ 2068 public static int findIndexGreaterThanOrEqualTo(final Dataset a, final double n) { 2069 IndexIterator iter = a.getIterator(); 2070 while (iter.hasNext()) { 2071 if (a.getElementDoubleAbs(iter.index) >= n) 2072 break; 2073 } 2074 2075 return iter.index; 2076 } 2077 2078 /** 2079 * Find absolute index of first value in dataset that is less than given number 2080 * @param a 2081 * @param n 2082 * @return absolute index (if greater than a.getSize() then no value found) 2083 */ 2084 public static int findIndexLessThan(final Dataset a, final double n) { 2085 IndexIterator iter = a.getIterator(); 2086 while (iter.hasNext()) { 2087 if (a.getElementDoubleAbs(iter.index) < n) 2088 break; 2089 } 2090 2091 return iter.index; 2092 } 2093 2094 /** 2095 * Find absolute index of first value in dataset that is less than or equal to given number 2096 * @param a 2097 * @param n 2098 * @return absolute index (if greater than a.getSize() then no value found) 2099 */ 2100 public static int findIndexLessThanOrEqualTo(final Dataset a, final double n) { 2101 IndexIterator iter = a.getIterator(); 2102 while (iter.hasNext()) { 2103 if (a.getElementDoubleAbs(iter.index) <= n) 2104 break; 2105 } 2106 2107 return iter.index; 2108 } 2109 2110 /** 2111 * Find first occurrences in one dataset of values given in another sorted dataset 2112 * @param a 2113 * @param values sorted 1D dataset of values to find 2114 * @return absolute indexes of those first occurrences (-1 is used to indicate value not found) 2115 */ 2116 public static IntegerDataset findFirstOccurrences(final Dataset a, final Dataset values) { 2117 if (values.getRank() != 1) { 2118 throw new IllegalArgumentException("Values dataset must be 1D"); 2119 } 2120 IntegerDataset indexes = new IntegerDataset(values.getSize()); 2121 indexes.fill(-1); 2122 2123 IndexIterator it = a.getIterator(); 2124 final int n = values.getSize(); 2125 if (values.getDType() == Dataset.INT64) { 2126 while (it.hasNext()) { 2127 long x = a.getElementLongAbs(it.index); 2128 2129 int l = 0; // binary search to find value in sorted dataset 2130 long vl = values.getLong(l); 2131 if (x <= vl) { 2132 if (x == vl && indexes.getAbs(l) < 0) 2133 indexes.setAbs(l, it.index); 2134 continue; 2135 } 2136 int h = n - 1; 2137 long vh = values.getLong(h); 2138 if (x >= vh) { 2139 if (x == vh && indexes.getAbs(h) < 0) 2140 indexes.setAbs(h, it.index); 2141 continue; 2142 } 2143 while (h - l > 1) { 2144 int m = (l + h) / 2; 2145 long vm = values.getLong(m); 2146 if (x < vm) { 2147 h = m; 2148 } else if (x > vm) { 2149 l = m; 2150 } else { 2151 if (indexes.getAbs(m) < 0) 2152 indexes.setAbs(m, it.index); 2153 break; 2154 } 2155 } 2156 } 2157 } else { 2158 while (it.hasNext()) { 2159 double x = a.getElementDoubleAbs(it.index); 2160 2161 int l = 0; // binary search to find value in sorted dataset 2162 double vl = values.getDouble(l); 2163 if (x <= vl) { 2164 if (x == vl && indexes.getAbs(l) < 0) 2165 indexes.setAbs(l, it.index); 2166 continue; 2167 } 2168 int h = n - 1; 2169 double vh = values.getDouble(h); 2170 if (x >= vh) { 2171 if (x == vh && indexes.getAbs(h) < 0) 2172 indexes.setAbs(h, it.index); 2173 continue; 2174 } 2175 while (h - l > 1) { 2176 int m = (l + h) / 2; 2177 double vm = values.getDouble(m); 2178 if (x < vm) { 2179 h = m; 2180 } else if (x > vm) { 2181 l = m; 2182 } else { 2183 if (indexes.getAbs(m) < 0) 2184 indexes.setAbs(m, it.index); 2185 break; 2186 } 2187 } 2188 } 2189 } 2190 return indexes; 2191 } 2192 2193 /** 2194 * Find indexes in sorted dataset of values for each value in other dataset 2195 * @param a 2196 * @param values sorted 1D dataset of values to find 2197 * @return absolute indexes of values (-1 is used to indicate value not found) 2198 */ 2199 public static IntegerDataset findIndexesForValues(final Dataset a, final Dataset values) { 2200 if (values.getRank() != 1) { 2201 throw new IllegalArgumentException("Values dataset must be 1D"); 2202 } 2203 IntegerDataset indexes = new IntegerDataset(a.getSize()); 2204 indexes.fill(-1); 2205 2206 IndexIterator it = a.getIterator(); 2207 int i = -1; 2208 final int n = values.getSize(); 2209 if (values.getDType() == Dataset.INT64) { 2210 while (it.hasNext()) { 2211 i++; 2212 long x = a.getElementLongAbs(it.index); 2213 2214 int l = 0; // binary search to find value in sorted dataset 2215 long vl = values.getLong(l); 2216 if (x <= vl) { 2217 if (x == vl) 2218 indexes.setAbs(i, l); 2219 continue; 2220 } 2221 int h = n - 1; 2222 long vh = values.getLong(h); 2223 if (x >= vh) { 2224 if (x == vh) 2225 indexes.setAbs(i, h); 2226 continue; 2227 } 2228 while (h - l > 1) { 2229 int m = (l + h) / 2; 2230 long vm = values.getLong(m); 2231 if (x < vm) { 2232 h = m; 2233 } else if (x > vm) { 2234 l = m; 2235 } else { 2236 indexes.setAbs(i, m); 2237 break; 2238 } 2239 } 2240 } 2241 } else { 2242 while (it.hasNext()) { 2243 i++; 2244 double x = a.getElementDoubleAbs(it.index); 2245 2246 int l = 0; // binary search to find value in sorted dataset 2247 double vl = values.getDouble(l); 2248 if (x <= vl) { 2249 if (x == vl) 2250 indexes.setAbs(i, l); 2251 continue; 2252 } 2253 int h = n - 1; 2254 double vh = values.getDouble(h); 2255 if (x >= vh) { 2256 if (x == vh) 2257 indexes.setAbs(i, h); 2258 continue; 2259 } 2260 while (h - l > 1) { 2261 int m = (l + h) / 2; 2262 double vm = values.getDouble(m); 2263 if (x < vm) { 2264 h = m; 2265 } else if (x > vm) { 2266 l = m; 2267 } else { 2268 indexes.setAbs(i, m); 2269 break; 2270 } 2271 } 2272 } 2273 } 2274 2275 return indexes; 2276 } 2277 2278 /** 2279 * Roll items over given axis by given amount 2280 * @param a 2281 * @param shift 2282 * @param axis if null, then roll flattened dataset 2283 * @return rolled dataset 2284 */ 2285 public static <T extends Dataset> T roll(final T a, final int shift, Integer axis) { 2286 Dataset r = DatasetFactory.zeros(a); 2287 int is = a.getElementsPerItem(); 2288 if (axis == null) { 2289 IndexIterator it = a.getIterator(); 2290 int s = r.getSize(); 2291 int i = shift % s; 2292 if (i < 0) 2293 i += s; 2294 while (it.hasNext()) { 2295 r.setObjectAbs(i, a.getObjectAbs(it.index)); 2296 i += is; 2297 if (i >= s) { 2298 i %= s; 2299 } 2300 } 2301 } else { 2302 axis = a.checkAxis(axis); 2303 PositionIterator pi = a.getPositionIterator(axis); 2304 int s = a.getShapeRef()[axis]; 2305 Dataset u = DatasetFactory.zeros(is, a.getClass(), new int[] {s}); 2306 Dataset v = DatasetFactory.zeros(u); 2307 int[] pos = pi.getPos(); 2308 boolean[] hit = pi.getOmit(); 2309 while (pi.hasNext()) { 2310 a.copyItemsFromAxes(pos, hit, u); 2311 int i = shift % s; 2312 if (i < 0) 2313 i += s; 2314 for (int j = 0; j < s; j++) { 2315 v.setObjectAbs(i, u.getObjectAbs(j*is)); 2316 i += is; 2317 if (i >= s) { 2318 i %= s; 2319 } 2320 } 2321 r.setItemsOnAxes(pos, hit, v.getBuffer()); 2322 } 2323 } 2324 return (T) r; 2325 } 2326 2327 /** 2328 * Roll the specified axis backwards until it lies in given position 2329 * @param a 2330 * @param axis The rolled axis (index in shape array). Other axes are left unchanged in relative positions 2331 * @param start The position with it right of the destination of the rolled axis 2332 * @return dataset with rolled axis 2333 */ 2334 public static <T extends Dataset> T rollAxis(final T a, int axis, int start) { 2335 int r = a.getRank(); 2336 axis = a.checkAxis(axis); 2337 if (start < 0) 2338 start += r; 2339 if (start < 0 || start > r) { 2340 throw new IllegalArgumentException("Start is out of range: it should be >= 0 and <= " + r); 2341 } 2342 if (axis < start) 2343 start--; 2344 2345 if (axis == start) 2346 return a; 2347 2348 ArrayList<Integer> axes = new ArrayList<Integer>(); 2349 for (int i = 0; i < r; i++) { 2350 if (i != axis) { 2351 axes.add(i); 2352 } 2353 } 2354 axes.add(start, axis); 2355 int[] aa = new int[r]; 2356 for (int i = 0; i < r; i++) { 2357 aa[i] = axes.get(i); 2358 } 2359 return (T) a.getTransposedView(aa); 2360 } 2361 2362 private static SliceND createFlippedSlice(final Dataset a, int axis) { 2363 int[] shape = a.getShapeRef(); 2364 SliceND slice = new SliceND(shape); 2365 slice.flip(axis); 2366 return slice; 2367 } 2368 2369 /** 2370 * Flip items in left/right direction, column-wise, or along second axis 2371 * @param a dataset must be at least 2D 2372 * @return view of flipped dataset 2373 */ 2374 public static <T extends Dataset> T flipLeftRight(final T a) { 2375 if (a.getRank() < 2) { 2376 throw new IllegalArgumentException("Dataset must be at least 2D"); 2377 } 2378 return (T) a.getSliceView(createFlippedSlice(a, 1)); 2379 } 2380 2381 /** 2382 * Flip items in up/down direction, row-wise, or along first axis 2383 * @param a dataset 2384 * @return view of flipped dataset 2385 */ 2386 public static <T extends Dataset> T flipUpDown(final T a) { 2387 return (T) a.getSliceView(createFlippedSlice(a, 0)); 2388 } 2389 2390 /** 2391 * Rotate items in first two dimension by 90 degrees anti-clockwise 2392 * @param a dataset must be at least 2D 2393 * @return view of flipped dataset 2394 */ 2395 public static <T extends Dataset> T rotate90(final T a) { 2396 return rotate90(a, 1); 2397 } 2398 2399 /** 2400 * Rotate items in first two dimension by 90 degrees anti-clockwise 2401 * @param a dataset must be at least 2D 2402 * @param k number of 90-degree rotations 2403 * @return view of flipped dataset 2404 */ 2405 public static <T extends Dataset> T rotate90(final T a, int k) { 2406 k = k % 4; 2407 while (k < 0) { 2408 k += 4; 2409 } 2410 int r = a.getRank(); 2411 if (r < 2) { 2412 throw new IllegalArgumentException("Dataset must be at least 2D"); 2413 } 2414 switch (k) { 2415 case 1: case 3: 2416 int[] axes = new int[r]; 2417 axes[0] = 1; 2418 axes[1] = 0; 2419 for (int i = 2; i < r; i++) { 2420 axes[i] = i; 2421 } 2422 Dataset t = a.getTransposedView(axes); 2423 return (T) t.getSliceView(createFlippedSlice(t, k == 1 ? 0 : 1)); 2424 case 2: 2425 SliceND s = createFlippedSlice(a, 0); 2426 s.flip(1); 2427 return (T) a.getSliceView(s); 2428 default: 2429 case 0: 2430 return a; 2431 } 2432 } 2433 2434 /** 2435 * Select content according where condition is true. All inputs are broadcasted to a maximum shape 2436 * @param condition boolean dataset 2437 * @param x 2438 * @param y 2439 * @return dataset where content is x or y depending on whether condition is true or otherwise 2440 */ 2441 public static Dataset select(BooleanDataset condition, Object x, Object y) { 2442 Object[] all = new Object[] {condition, x, y}; 2443 Dataset[] dAll = BroadcastUtils.convertAndBroadcast(all); 2444 condition = (BooleanDataset) dAll[0]; 2445 Dataset dx = dAll[1]; 2446 Dataset dy = dAll[2]; 2447 int dt = DTypeUtils.getBestDType(dx.getDType(),dy.getDType()); 2448 int ds = Math.max(dx.getElementsPerItem(), dy.getElementsPerItem()); 2449 2450 @SuppressWarnings("deprecation") 2451 Dataset r = DatasetFactory.zeros(ds, condition.getShapeRef(), dt); 2452 IndexIterator iter = condition.getIterator(true); 2453 final int[] pos = iter.getPos(); 2454 int i = 0; 2455 while (iter.hasNext()) { 2456 r.setObjectAbs(i++, condition.getElementBooleanAbs(iter.index) ? dx.getObject(pos) : dy.getObject(pos)); 2457 } 2458 return r; 2459 } 2460 2461 /** 2462 * Select content from choices where condition is true, otherwise use default. All inputs are broadcasted to a maximum shape 2463 * @param conditions array of boolean datasets 2464 * @param choices array of datasets or objects 2465 * @param def default value (can be a dataset) 2466 * @return dataset 2467 */ 2468 public static Dataset select(BooleanDataset[] conditions, Object[] choices, Object def) { 2469 final int n = conditions.length; 2470 if (choices.length != n) { 2471 throw new IllegalArgumentException("Choices list is not same length as conditions list"); 2472 } 2473 Object[] all = new Object[2*n]; 2474 System.arraycopy(conditions, 0, all, 0, n); 2475 System.arraycopy(choices, 0, all, n, n); 2476 Dataset[] dAll = BroadcastUtils.convertAndBroadcast(all); 2477 conditions = new BooleanDataset[n]; 2478 Dataset[] dChoices = new Dataset[n]; 2479 System.arraycopy(dAll, 0, conditions, 0, n); 2480 System.arraycopy(dAll, n, dChoices, 0, n); 2481 int dt = -1; 2482 int ds = -1; 2483 for (int i = 0; i < n; i++) { 2484 Dataset a = dChoices[i]; 2485 int t = a.getDType(); 2486 if (t > dt) 2487 dt = t; 2488 int s = a.getElementsPerItem(); 2489 if (s > ds) 2490 ds = s; 2491 } 2492 if (dt < 0 || ds < 1) { 2493 throw new IllegalArgumentException("Dataset types of choices are invalid"); 2494 } 2495 2496 @SuppressWarnings("deprecation") 2497 Dataset r = DatasetFactory.zeros(ds, conditions[0].getShapeRef(), dt); 2498 Dataset d = DatasetFactory.createFromObject(def).getBroadcastView(r.getShapeRef()); 2499 PositionIterator iter = new PositionIterator(r.getShapeRef()); 2500 final int[] pos = iter.getPos(); 2501 int i = 0; 2502 while (iter.hasNext()) { 2503 int j = 0; 2504 for (; j < n; j++) { 2505 if (conditions[j].get(pos)) { 2506 r.setObjectAbs(i++, dChoices[j].getObject(pos)); 2507 break; 2508 } 2509 } 2510 if (j == n) { 2511 r.setObjectAbs(i++, d.getObject(pos)); 2512 } 2513 } 2514 return r; 2515 } 2516 2517 /** 2518 * Choose content from choices where condition is true, otherwise use default. All inputs are broadcasted to a maximum shape 2519 * @param index integer dataset (ideally, items should be in [0, n) range, if there are n choices) 2520 * @param choices array of datasets or objects 2521 * @param throwAIOOBE if true, throw array index out of bound exception 2522 * @param clip true to clip else wrap indices out of bounds; only used when throwAOOBE is false 2523 * @return dataset 2524 */ 2525 public static Dataset choose(IntegerDataset index, Object[] choices, boolean throwAIOOBE, boolean clip) { 2526 final int n = choices.length; 2527 Object[] all = new Object[n + 1]; 2528 System.arraycopy(choices, 0, all, 0, n); 2529 all[n] = index; 2530 Dataset[] dChoices = BroadcastUtils.convertAndBroadcast(all); 2531 int dt = -1; 2532 int ds = -1; 2533 int mr = -1; 2534 for (int i = 0; i < n; i++) { 2535 Dataset a = dChoices[i]; 2536 int r = a.getRank(); 2537 if (r > mr) 2538 mr = r; 2539 int t = a.getDType(); 2540 if (t > dt) 2541 dt = t; 2542 int s = a.getElementsPerItem(); 2543 if (s > ds) 2544 ds = s; 2545 } 2546 if (dt < 0 || ds < 1) { 2547 throw new IllegalArgumentException("Dataset types of choices are invalid"); 2548 } 2549 index = (IntegerDataset) dChoices[n]; 2550 dChoices[n] = null; 2551 2552 @SuppressWarnings("deprecation") 2553 Dataset r = DatasetFactory.zeros(ds, index.getShape(), dt); 2554 IndexIterator iter = index.getIterator(true); 2555 final int[] pos = iter.getPos(); 2556 int i = 0; 2557 while (iter.hasNext()) { 2558 int j = index.getAbs(iter.index); 2559 if (j < 0) { 2560 if (throwAIOOBE) 2561 throw new ArrayIndexOutOfBoundsException(j); 2562 if (clip) { 2563 j = 0; 2564 } else { 2565 j %= n; 2566 j += n; // as remainder still negative 2567 } 2568 } 2569 if (j >= n) { 2570 if (throwAIOOBE) 2571 throw new ArrayIndexOutOfBoundsException(j); 2572 if (clip) { 2573 j = n - 1; 2574 } else { 2575 j %= n; 2576 } 2577 } 2578 Dataset c = dChoices[j]; 2579 r.setObjectAbs(i++, c.getObject(pos)); 2580 } 2581 return r; 2582 } 2583 2584 /** 2585 * Calculate positions in given shape from a dataset of 1-D indexes 2586 * @param indices 2587 * @param shape 2588 * @return list of positions as integer datasets 2589 */ 2590 public static List<IntegerDataset> calcPositionsFromIndexes(Dataset indices, int[] shape) { 2591 int rank = shape.length; 2592 List<IntegerDataset> posns = new ArrayList<IntegerDataset>(); 2593 int[] iShape = indices.getShapeRef(); 2594 for (int i = 0; i < rank; i++) { 2595 posns.add(new IntegerDataset(iShape)); 2596 } 2597 IndexIterator it = indices.getIterator(true); 2598 int[] pos = it.getPos(); 2599 while (it.hasNext()) { 2600 int n = indices.getInt(pos); 2601 int[] p = ShapeUtils.getNDPositionFromShape(n, shape); 2602 for (int i = 0; i < rank; i++) { 2603 posns.get(i).setItem(p[i], pos); 2604 } 2605 } 2606 return posns; 2607 } 2608 2609 2610 /** 2611 * Calculate indexes in given shape from datasets of position 2612 * @param positions as a list of datasets where each holds the position in a dimension 2613 * @param shape 2614 * @param mode either null, zero-length, unit length or length of rank of shape where 2615 * 0 = raise exception, 1 = wrap, 2 = clip 2616 * @return indexes as an integer dataset 2617 */ 2618 public static IntegerDataset calcIndexesFromPositions(List<? extends Dataset> positions, int[] shape, int... mode) { 2619 int rank = shape.length; 2620 if (positions.size() != rank) { 2621 throw new IllegalArgumentException("Number of position datasets must be equal to rank of shape"); 2622 } 2623 2624 if (mode == null || mode.length == 0) { 2625 mode = new int[rank]; 2626 } else if (mode.length == 1) { 2627 int m = mode[0]; 2628 mode = new int[rank]; 2629 Arrays.fill(mode, m); 2630 } else if (mode.length != rank) { 2631 throw new IllegalArgumentException("Mode length greater than one must match rank of shape"); 2632 } 2633 for (int i = 0; i < rank; i++) { 2634 int m = mode[i]; 2635 if (m < 0 || m > 2) { 2636 throw new IllegalArgumentException("Unknown mode value - it must be 0, 1, or 2"); 2637 } 2638 } 2639 2640 Dataset p = positions.get(0); 2641 IntegerDataset indexes = new IntegerDataset(p.getShapeRef()); 2642 IndexIterator it = p.getIterator(true); 2643 int[] iPos = it.getPos(); 2644 int[] tPos = new int[rank]; 2645 while (it.hasNext()) { 2646 for (int i = 0; i < rank; i++) { 2647 p = positions.get(i); 2648 int j = p.getInt(iPos); 2649 int d = shape[i]; 2650 if (mode[i] == 0) { 2651 if (j < 0 || j >= d) { 2652 throw new ArrayIndexOutOfBoundsException("Position value exceeds dimension in shape"); 2653 } 2654 } else if (mode[i] == 1) { 2655 while (j < 0) 2656 j += d; 2657 while (j >= d) 2658 j -= d; 2659 } else { 2660 if (j < 0) 2661 j = 0; 2662 if (j >= d) 2663 j = d - 1; 2664 } 2665 tPos[i] = j; 2666 } 2667 indexes.set(ShapeUtils.getFlat1DIndex(shape, tPos), iPos); 2668 } 2669 2670 return indexes; 2671 } 2672 2673 /** 2674 * Serialize dataset by flattening it. Discards metadata 2675 * @param data 2676 * @return some java array 2677 */ 2678 public static Serializable serializeDataset(final IDataset data) { 2679 Dataset d = convertToDataset(data.getSliceView()); 2680 d.clearMetadata(null); 2681 return d.flatten().getBuffer(); 2682 } 2683 2684 /** 2685 * Extract values where condition is non-zero. This is similar to Dataset#getByBoolean but supports broadcasting 2686 * @param data 2687 * @param condition should be broadcastable to data 2688 * @return 1-D dataset of values 2689 */ 2690 @SuppressWarnings("deprecation") 2691 public static Dataset extract(final IDataset data, final IDataset condition) { 2692 Dataset a = convertToDataset(data.getSliceView()); 2693 Dataset b = cast(condition.getSliceView(), Dataset.BOOL); 2694 2695 try { 2696 return a.getByBoolean(b); 2697 } catch (IllegalArgumentException e) { 2698 final int length = ((Number) b.sum()).intValue(); 2699 2700 BroadcastPairIterator it = new BroadcastPairIterator(a, b, null, false); 2701 int size = ShapeUtils.calcSize(it.getShape()); 2702 Dataset c; 2703 if (length < size) { 2704 int[] ashape = it.getFirstShape(); 2705 int[] bshape = it.getSecondShape(); 2706 int r = ashape.length; 2707 size = length; 2708 for (int i = 0; i < r; i++) { 2709 int s = ashape[i]; 2710 if (s > 1 && bshape[i] == 1) { 2711 size *= s; 2712 } 2713 } 2714 } 2715 c = DatasetFactory.zeros(new int[] {size}, a.getDType()); 2716 2717 int i = 0; 2718 if (it.isOutputDouble()) { 2719 while (it.hasNext()) { 2720 if (it.bLong != 0) { 2721 c.setObjectAbs(i++, it.aDouble); 2722 } 2723 } 2724 } else { 2725 while (it.hasNext()) { 2726 if (it.bLong != 0) { 2727 c.setObjectAbs(i++, it.aLong); 2728 } 2729 } 2730 } 2731 2732 return c; 2733 } 2734 } 2735}