mlir/
types.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright 2024, Giordano Salvador
// SPDX-License-Identifier: BSD-3-Clause

#![allow(dead_code)]

use mlir_sys::MlirType;
use mlir_sys::mlirTypeEqual;

use std::cmp;

use crate::do_unsafe;
use crate::ir;

use ir::Context;
use ir::Type;

use float::Float as FloatType;
use index::Index;
use integer::Integer as IntegerType;
use shaped::Shaped;

pub mod complex;
pub mod float;
pub mod function;
pub mod index;
pub mod integer;
pub mod memref;
pub mod none;
pub mod opaque;
pub mod ranked_tensor;
pub mod shaped;
pub mod tuple;
pub mod unit;
pub mod unranked_memref;
pub mod unranked_tensor;
pub mod vector;

/// Width query for MLIR types that carry one.
///
/// The default implementation inspects the kind of the underlying type:
/// index, integer and float types report their own width, shaped types
/// report the width of their element type, and every other kind yields
/// `None`.
///
/// NOTE(review): the unit of the width (presumably bits) is decided by the
/// per-kind `get_width` helpers, which are defined outside this file — confirm.
pub trait GetWidth: IType {
    /// Returns the width of this type, or `None` when it has none.
    fn get_width(&self) -> Option<usize> {
        let ty = self.as_type();
        let raw = *self.get();

        if ty.is_index() {
            return Some(Index::from(raw).get_width());
        }
        if ty.is_integer() {
            return Some(IntegerType::from(raw).get_width());
        }
        if ty.is_float() {
            return Some(FloatType::from(raw).get_width());
        }
        if ty.is_shaped() {
            // Delegate to the element type; this may itself be `None`.
            return Shaped::from(raw).get_element_type().get_width();
        }
        None
    }
}

/// Common interface for wrappers around a raw `MlirType` handle.
pub trait IType {
    /// Borrows the wrapped raw handle.
    fn get(&self) -> &MlirType;

    /// Mutably borrows the wrapped raw handle.
    fn get_mut(&mut self) -> &mut MlirType;

    /// Copies the raw handle into a generic [`Type`] wrapper.
    fn as_type(&self) -> Type {
        let raw: MlirType = *self.get();
        Type::from(raw)
    }

    /// Returns the [`Context`] reported by [`Type::get_context`] for this type.
    fn get_context(&self) -> Context {
        let ty = self.as_type();
        ty.get_context()
    }
}

/// Promotion check from `Self` to another type `Target`.
pub trait IsPromotableTo<Target> {
    /// Returns `true` when `self` can be promoted to `other`.
    fn is_promotable_to(&self, other: &Target) -> bool;
}

/// Lets type-erased `dyn IType` values use the default
/// [`GetWidth::get_width`]. The `'static` bound is the same one the compiler
/// infers for a bare `dyn IType` in an impl header; it is written out here
/// for clarity.
impl GetWidth for dyn IType + 'static {}

/// Equality of type-erased types, decided by MLIR's `mlirTypeEqual` on the
/// two wrapped raw handles.
impl cmp::PartialEq for dyn IType {
    fn eq(&self, other: &Self) -> bool {
        let lhs = *self.get();
        let rhs = *other.get();
        // SAFETY: NOTE(review): assumes both wrappers hold valid `MlirType`
        // handles, as the `IType` implementors are expected to guarantee — confirm.
        do_unsafe!(mlirTypeEqual(lhs, rhs))
    }
}

/// Promotion check from a type-erased `dyn IType` to any concrete `IType`.
///
/// Promotion is only considered between two integer types or between two
/// float types. In those cases the decision is delegated to the
/// `IsPromotableTo` impl of the matching concrete wrapper (`IntegerType` or
/// `FloatType`). Every other pairing, including mixed integer/float, is
/// reported as not promotable.
impl<T: IType> IsPromotableTo<T> for dyn IType {
    fn is_promotable_to(&self, other: &T) -> bool {
        let source = self.as_type();
        let target = other.as_type();

        if source.is_integer() && target.is_integer() {
            return IntegerType::from(*source.get())
                .is_promotable_to(&IntegerType::from(*target.get()));
        }
        if source.is_float() && target.is_float() {
            return FloatType::from(*source.get())
                .is_promotable_to(&FloatType::from(*target.get()));
        }
        false
    }
}