package org.pytorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Locale;

/* loaded from: classes2.dex */
/**
 * Java-side representation of a tensor: a shape, a memory format and a typed NIO buffer holding the
 * element data. One concrete nested subclass exists per supported element dtype.
 */
public abstract class Tensor {
    /** Memory layout of this tensor; its {@code jniCode} is exposed via {@link #memoryFormatJniCode()}. */
    final MemoryFormat a;

    // NOTE(review): annotated @DoNotStrip, presumably because native code accesses this field by
    // name (it is assigned in nativeNewTensor) — do not rename.
    @DoNotStrip
    private HybridData mHybridData;

    /** Copy of the shape passed at construction; validated to be non-null with non-negative dims. */
    @DoNotStrip
    final long[] shape;

    /** Tensor with {@link DType#FLOAT32} elements backed by a {@link FloatBuffer}. */
    public static class Tensor_float32 extends Tensor {
        private final FloatBuffer b;

        public Tensor_float32(FloatBuffer floatBuffer, long[] jArr, MemoryFormat memoryFormat) {
            super(jArr, memoryFormat, (byte) 0);
            this.b = floatBuffer;
        }

        @Override // org.pytorch.Tensor
        public final DType a() {
            return DType.FLOAT32;
        }

        /**
         * Copies the whole buffer (from position 0) into a new float array.
         * Note: this rewinds and then advances the shared buffer's position.
         */
        @Override // org.pytorch.Tensor
        public final float[] b() {
            this.b.rewind();
            float[] fArr = new float[this.b.remaining()];
            this.b.get(fArr);
            return fArr;
        }

        @Override // org.pytorch.Tensor
        Buffer getRawDataBuffer() {
            return this.b;
        }

        @Override
        public String toString() {
            return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(this.shape));
        }
    }

    /** Tensor with {@link DType#FLOAT64} elements backed by a {@link DoubleBuffer}. */
    static class Tensor_float64 extends Tensor {
        private final DoubleBuffer b;

        private Tensor_float64(DoubleBuffer doubleBuffer, long[] jArr, MemoryFormat memoryFormat) {
            super(jArr, memoryFormat, (byte) 0);
            this.b = doubleBuffer;
        }

        /* synthetic */ Tensor_float64(DoubleBuffer doubleBuffer, long[] jArr, MemoryFormat memoryFormat, byte b) {
            this(doubleBuffer, jArr, memoryFormat);
        }

        @Override // org.pytorch.Tensor
        public final DType a() {
            return DType.FLOAT64;
        }

        @Override // org.pytorch.Tensor
        Buffer getRawDataBuffer() {
            return this.b;
        }

        @Override
        public String toString() {
            return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(this.shape));
        }
    }

    /** Tensor with {@link DType#INT32} elements backed by an {@link IntBuffer}. */
    static class Tensor_int32 extends Tensor {
        private final IntBuffer b;

        private Tensor_int32(IntBuffer intBuffer, long[] jArr, MemoryFormat memoryFormat) {
            super(jArr, memoryFormat, (byte) 0);
            this.b = intBuffer;
        }

        /* synthetic */ Tensor_int32(IntBuffer intBuffer, long[] jArr, MemoryFormat memoryFormat, byte b) {
            this(intBuffer, jArr, memoryFormat);
        }

        @Override // org.pytorch.Tensor
        public final DType a() {
            return DType.INT32;
        }

        @Override // org.pytorch.Tensor
        Buffer getRawDataBuffer() {
            return this.b;
        }

        @Override
        public String toString() {
            return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(this.shape));
        }
    }

    /** Tensor with {@link DType#INT64} elements backed by a {@link LongBuffer}. */
    static class Tensor_int64 extends Tensor {
        private final LongBuffer b;

        private Tensor_int64(LongBuffer longBuffer, long[] jArr, MemoryFormat memoryFormat) {
            super(jArr, memoryFormat, (byte) 0);
            this.b = longBuffer;
        }

        /* synthetic */ Tensor_int64(LongBuffer longBuffer, long[] jArr, MemoryFormat memoryFormat, byte b) {
            this(longBuffer, jArr, memoryFormat);
        }

        @Override // org.pytorch.Tensor
        public final DType a() {
            return DType.INT64;
        }

        @Override // org.pytorch.Tensor
        Buffer getRawDataBuffer() {
            return this.b;
        }

        @Override
        public String toString() {
            return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(this.shape));
        }
    }

    /** Tensor with {@link DType#INT8} elements backed by a {@link ByteBuffer}. */
    static class Tensor_int8 extends Tensor {
        private final ByteBuffer b;

        private Tensor_int8(ByteBuffer byteBuffer, long[] jArr, MemoryFormat memoryFormat) {
            super(jArr, memoryFormat, (byte) 0);
            this.b = byteBuffer;
        }

        /* synthetic */ Tensor_int8(ByteBuffer byteBuffer, long[] jArr, MemoryFormat memoryFormat, byte b) {
            this(byteBuffer, jArr, memoryFormat);
        }

        @Override // org.pytorch.Tensor
        public final DType a() {
            return DType.INT8;
        }

        @Override // org.pytorch.Tensor
        Buffer getRawDataBuffer() {
            return this.b;
        }

        @Override
        public String toString() {
            return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(this.shape));
        }
    }

    /** Tensor with {@link DType#UINT8} elements backed by a {@link ByteBuffer}. */
    static class Tensor_uint8 extends Tensor {
        private final ByteBuffer b;

        private Tensor_uint8(ByteBuffer byteBuffer, long[] jArr, MemoryFormat memoryFormat) {
            super(jArr, memoryFormat, (byte) 0);
            this.b = byteBuffer;
        }

        /* synthetic */ Tensor_uint8(ByteBuffer byteBuffer, long[] jArr, MemoryFormat memoryFormat, byte b) {
            this(byteBuffer, jArr, memoryFormat);
        }

        @Override // org.pytorch.Tensor
        public final DType a() {
            return DType.UINT8;
        }

        @Override // org.pytorch.Tensor
        Buffer getRawDataBuffer() {
            return this.b;
        }

        @Override
        public String toString() {
            return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(this.shape));
        }
    }

    /**
     * Validates {@code jArr}, stores a private copy of it as the shape, and records the memory format.
     *
     * @throws IllegalArgumentException if {@code jArr} is null or contains a negative dimension
     */
    private Tensor(long[] jArr, MemoryFormat memoryFormat) {
        b(jArr);
        // Copy so later mutation of the caller's array cannot change this tensor's shape.
        this.shape = Arrays.copyOf(jArr, jArr.length);
        this.a = memoryFormat;
    }

    /* synthetic */ Tensor(long[] jArr, MemoryFormat memoryFormat, byte b) {
        this(jArr, memoryFormat);
    }

    /**
     * Returns the number of elements of a tensor with the given shape: the product of all of its
     * dimensions (1 for an empty shape).
     *
     * @param jArr tensor shape
     * @return product of all dimensions
     * @throws IllegalArgumentException if {@code jArr} is null or contains a negative dimension
     */
    public static long a(long[] jArr) {
        b(jArr);
        // Multiply every dimension into a long accumulator. The previous version only visited the
        // first four dimensions (failing for rank < 4, ignoring extras) and truncated through int.
        long numel = 1L;
        for (long dim : jArr) {
            numel *= dim;
        }
        return numel;
    }

    /**
     * Throws an {@link IllegalArgumentException} with a {@link Locale#US}-formatted message when the
     * condition does not hold.
     *
     * @param z condition that must be true
     * @param str {@link String#format} pattern for the failure message
     * @param objArr arguments for {@code str}
     */
    public static void a(boolean z, String str, Object... objArr) {
        if (!z) {
            throw new IllegalArgumentException(String.format(Locale.US, str, objArr));
        }
    }

    /**
     * Validates a shape array.
     *
     * @throws IllegalArgumentException if {@code jArr} is null or any element is negative
     */
    public static void b(long[] jArr) {
        a(jArr != null, "Shape must be not null", new Object[0]);
        for (long j : jArr) {
            a(j >= 0, "Shape elements must be non negative", new Object[0]);
        }
    }

    /**
     * Wraps a byte buffer in the {@code Tensor} subclass matching the dtype code {@code i} and
     * attaches {@code hybridData}.
     *
     * <p>NOTE(review): private and {@code @DoNotStrip}, so presumably invoked only from native code.
     *
     * @param byteBuffer element data; viewed as float/int/long/double buffers for the wider dtypes
     * @param jArr tensor shape
     * @param i dtype JNI code (compared against {@code DType.*.jniCode})
     * @param i2 memory-format JNI code; any unrecognised value falls back to CONTIGUOUS
     * @param hybridData native peer handle stored in {@code mHybridData}
     * @throws IllegalArgumentException if {@code i} matches no known dtype, or the shape is invalid
     */
    @DoNotStrip
    private static Tensor nativeNewTensor(ByteBuffer byteBuffer, long[] jArr, int i, int i2, HybridData hybridData) {
        MemoryFormat memoryFormat = MemoryFormat.CONTIGUOUS;
        if (MemoryFormat.CHANNELS_LAST.jniCode == i2) {
            memoryFormat = MemoryFormat.CHANNELS_LAST;
        } else if (MemoryFormat.CHANNELS_LAST_3D.jniCode == i2) {
            memoryFormat = MemoryFormat.CHANNELS_LAST_3D;
        }

        Tensor tensor;
        if (DType.FLOAT32.jniCode == i) {
            tensor = new Tensor_float32(byteBuffer.asFloatBuffer(), jArr, memoryFormat);
        } else if (DType.INT32.jniCode == i) {
            tensor = new Tensor_int32(byteBuffer.asIntBuffer(), jArr, memoryFormat);
        } else if (DType.INT64.jniCode == i) {
            tensor = new Tensor_int64(byteBuffer.asLongBuffer(), jArr, memoryFormat);
        } else if (DType.FLOAT64.jniCode == i) {
            tensor = new Tensor_float64(byteBuffer.asDoubleBuffer(), jArr, memoryFormat);
        } else if (DType.UINT8.jniCode == i) {
            tensor = new Tensor_uint8(byteBuffer, jArr, memoryFormat);
        } else if (DType.INT8.jniCode == i) {
            tensor = new Tensor_int8(byteBuffer, jArr, memoryFormat);
        } else {
            // Previously the exception was constructed but never thrown, leading to an NPE below.
            throw new IllegalArgumentException("Unknown Tensor dtype");
        }
        tensor.mHybridData = hybridData;
        return tensor;
    }

    /** Returns the element dtype of this tensor. */
    public abstract DType a();

    /**
     * Returns the tensor data as a float array. Only {@link Tensor_float32} supports this.
     *
     * @throws IllegalStateException for every other tensor type
     */
    public float[] b() {
        throw new IllegalStateException("Tensor of type " + getClass().getSimpleName() + " cannot return data as float array.");
    }

    /** Returns the JNI code of this tensor's dtype. */
    @DoNotStrip
    int dtypeJniCode() {
        return a().jniCode;
    }

    /**
     * Returns the buffer backing this tensor's data. Every concrete subclass in this file overrides
     * this; the base implementation always throws.
     *
     * @throws IllegalStateException if not overridden by the concrete subclass
     */
    @DoNotStrip
    Buffer getRawDataBuffer() {
        throw new IllegalStateException("Tensor of type " + getClass().getSimpleName() + " cannot return raw data buffer.");
    }

    /** Returns the JNI code of this tensor's memory format. */
    @DoNotStrip
    int memoryFormatJniCode() {
        return this.a.jniCode;
    }
}
