mlir/attributes/
symbol_ref.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright 2024-2025, Giordano Salvador
// SPDX-License-Identifier: BSD-3-Clause

#![allow(dead_code)]

use mlir_sys::MlirAttribute;
use mlir_sys::mlirFlatSymbolRefAttrGet;
use mlir_sys::mlirFlatSymbolRefAttrGetValue;
use mlir_sys::mlirSymbolRefAttrGet;
use mlir_sys::mlirSymbolRefAttrGetLeafReference;
use mlir_sys::mlirSymbolRefAttrGetNestedReference;
use mlir_sys::mlirSymbolRefAttrGetNumNestedReferences;
use mlir_sys::mlirSymbolRefAttrGetRootReference;
use mlir_sys::mlirSymbolRefAttrGetTypeID;

use std::cmp;
use std::fmt;

use crate::attributes;
use crate::do_unsafe;
use crate::exit_code;
use crate::ir;

use attributes::IAttribute;
use exit_code::ExitCode;
use exit_code::exit;
use ir::Attribute;
use ir::Context;
use ir::StringRef;
use ir::TypeID;

/// Newtype wrapper around a raw `MlirAttribute` handle that is treated as an
/// MLIR symbol reference attribute.
///
/// The wrapper itself performs no validation; `SymbolRef::from_checked` is the
/// constructor in this file that verifies the attribute kind first.
#[derive(Clone)]
pub struct SymbolRef(MlirAttribute);

impl SymbolRef {
    pub fn new(context: &Context, sym: &StringRef, refs: &[Attribute]) -> Self {
        let r: Vec<MlirAttribute> = refs.iter().map(|a| *a.get()).collect();
        Self::from(do_unsafe!(mlirSymbolRefAttrGet(
            *context.get(),
            *sym.get(),
            refs.len() as isize,
            r.as_ptr(),
        )))
    }

    pub fn new_flat(context: &Context, sym: &StringRef) -> Self {
        Self::from(do_unsafe!(mlirFlatSymbolRefAttrGet(
            *context.get(),
            *sym.get()
        )))
    }

    pub fn from_checked(attr_: MlirAttribute) -> Self {
        let attr = Attribute::from(attr_);
        if !attr.is_symbol_ref() {
            eprintln!(
                "Cannot coerce attribute to symbol reference attribute: {}",
                attr
            );
            exit(ExitCode::IRError);
        }
        Self::from(attr_)
    }

    pub fn get(&self) -> &MlirAttribute {
        &self.0
    }

    pub fn get_leaf(&self) -> StringRef {
        StringRef::from(do_unsafe!(mlirSymbolRefAttrGetLeafReference(*self.get())))
    }

    pub fn get_mut(&mut self) -> &mut MlirAttribute {
        &mut self.0
    }

    pub fn get_nested_reference(&self, i: isize) -> Attribute {
        if i >= self.num_nested_references() || i < 0 {
            eprintln!(
                "Index '{}' out of bounds for nested reference of symbol ref",
                i
            );
            exit(ExitCode::IRError);
        }
        Attribute::from(do_unsafe!(mlirSymbolRefAttrGetNestedReference(
            *self.get(),
            i
        )))
    }

    pub fn get_root(&self) -> StringRef {
        StringRef::from(do_unsafe!(mlirSymbolRefAttrGetRootReference(*self.get())))
    }

    pub fn get_type_id() -> TypeID {
        TypeID::from(do_unsafe!(mlirSymbolRefAttrGetTypeID()))
    }

    /// If this is a flat symbol reference, return the referenced symbol.
    pub fn get_value(&self) -> Option<StringRef> {
        if self.as_attribute().is_flat_symbol_ref() {
            Some(StringRef::from(do_unsafe!(mlirFlatSymbolRefAttrGetValue(
                *self.get()
            ))))
        } else {
            None
        }
    }

    pub fn num_nested_references(&self) -> isize {
        do_unsafe!(mlirSymbolRefAttrGetNumNestedReferences(*self.get()))
    }
}

impl fmt::Display for SymbolRef {
    /// Writes the symbol returned by `get_value`, or nothing at all when
    /// `get_value` yields `None` (i.e. the reference is not flat).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rendered = self
            .get_value()
            .map(|symbol| symbol.to_string())
            .unwrap_or_default();
        write!(f, "{}", rendered)
    }
}

impl From<MlirAttribute> for SymbolRef {
    fn from(attr: MlirAttribute) -> Self {
        Self(attr)
    }
}

impl From<Attribute> for SymbolRef {
    /// Converts an owned `Attribute` by delegating to the `&Attribute`
    /// conversion; no kind check is performed here.
    fn from(owned: Attribute) -> Self {
        let borrowed: &Attribute = &owned;
        <Self as From<&Attribute>>::from(borrowed)
    }
}

impl From<&Attribute> for SymbolRef {
    fn from(attr: &Attribute) -> Self {
        Self::from(*attr.get())
    }
}

impl IAttribute for SymbolRef {
    /// Borrows the wrapped raw attribute handle.
    fn get(&self) -> &MlirAttribute {
        &self.0
    }

    /// Mutably borrows the wrapped raw attribute handle.
    fn get_mut(&mut self) -> &mut MlirAttribute {
        &mut self.0
    }
}

impl cmp::PartialEq for SymbolRef {
    /// Two symbol references are equal when their `as_attribute()` views
    /// compare equal.
    fn eq(&self, rhs: &Self) -> bool {
        let lhs_attr = self.as_attribute();
        let rhs_attr = rhs.as_attribute();
        lhs_attr == rhs_attr
    }
}